├── .DS_Store ├── .gitignore ├── LICENSE ├── NLI └── deploy_NLI.py ├── README.md ├── figure └── framework.png ├── play.py └── question_generator ├── LICENSE.txt ├── README.md ├── deploy.py ├── setup.py └── text2text ├── __init__.py ├── biunilm ├── __init__.py ├── loader_utils.py └── seq2seq_loader.py ├── pytorch_pretrained_bert ├── __init__.py ├── __main__.py ├── file_utils.py ├── loss.py ├── modeling.py └── tokenization.py └── text_generator.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/AIH/e50cb8ab714afa0b791865da5233626dd6398cfe/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ICTNLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/NLI/deploy_NLI.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from flask import Flask, jsonify, request, render_template
4 | from flask_cors import CORS
5 | from transformers import *
6 | app = Flask(__name__)
7 | CORS(app)
8 | 
9 | tokenizer = AutoTokenizer.from_pretrained("MNLI/")
10 | model = AutoModelForSequenceClassification.from_pretrained("MNLI/")
11 | model.to("cuda")
12 | model.eval()
13 | sm = torch.nn.Softmax(dim=-1)
14 | 
15 | def score(c1, c2):
16 |     c1_t = torch.LongTensor(tokenizer(c1, padding="longest")["input_ids"]).cuda()
17 |     c2_t = torch.LongTensor(tokenizer(c2, padding="longest")["input_ids"]).cuda()
18 |     text = torch.cat([c2_t, c1_t], dim=-1).unsqueeze(0)
19 |     return sm(model(text)[0])
20 | 
21 | @app.route("/NLI", methods=["POST"])
22 | def nli():
23 |     body = request.get_data()
24 |     body = json.loads(body)
25 |     res = body["res"]
26 |     res_gold = body["res_gold"]
27 |     results = score(res, res_gold).cpu().tolist()[0]
28 | 
29 |     return jsonify({"nli_score": results})
30 | 
31 | 
32 | if __name__ == "__main__":
33 |     app.run(host='0.0.0.0', port=8085)
34 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Addressing Inquiries about History: An Efficient and Practical Framework for Evaluating Open-domain Chatbot Consistency
2 | This repository contains the code for the Findings of ACL 2021 paper *Addressing Inquiries about History: An Efficient and Practical Framework for Evaluating Open-domain Chatbot Consistency*.
3 | 
4 | Our paper and code will be public soon!
5 | 
6 | ## Overview
7 | 
8 | We propose *Addressing Inquiries about History* (**AIH**), an efficient and practical framework for open-domain chatbot consistency evaluation. **AIH** contains two stages: (1) during the inquiry stage, questions about the facts and opinions mentioned in the conversation history are inserted into the conversation between chatbots; (2) during the contradiction recognition stage, the responses to the inserted questions are collected, and automatic models or human judges can be adopted to decide whether the responses are consistent with the dialogue history.
9 | 
10 | ![Overview of the Addressing Inquiries about History framework.](figure/framework.png)
11 | 
12 | 
13 | 
14 | ## How to Run?
15 | 
16 | ### Entity Extraction & Question Generation
17 | 
18 | Contradictions usually occur when chatting about facts or opinions, and such contradictions are particularly jarring in human-bot interaction. Therefore, we first extract the named entities from the second chatbot's responses using the Stanza toolkit.
19 | 
20 | Then, we generate questions based on the extracted entities. We employ text2text (https://github.com/artitw/text2text) as the question generator.
21 | 
22 | Please refer to `question_generator/deploy.py` for deploying entity extraction and question generation.
23 | 
24 | ### Deploy Chatbots
25 | 
26 | Deploy several chatbots that can be accessed through an interface like http://127.0.0.1:8082/interact. The chatbots should receive contexts and return responses. Note that the responses to the generated questions should not be added to the contexts.
27 | 
28 | 
29 | 
30 | ### Automatic Evaluation
31 | 
32 | `NLI/`: Automatic evaluation using a Natural Language Inference model (https://huggingface.co/roberta-large-mnli).
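Once the service in `NLI/deploy_NLI.py` is running, a consistency score can be requested as follows. This is a minimal client sketch: the port and the `res`/`res_gold` JSON keys follow the deployment script, the example sentences are made up, and the three returned probabilities are the softmax over the MNLI classes (typically contradiction / neutral / entailment for `roberta-large-mnli`, depending on the fine-tuned checkpoint in `MNLI/`).

```python
import json

import requests

# Hypothetical response pair: the answer to an inserted question vs. the earlier response it should agree with.
payload = {"res": "I have two dogs at home.", "res_gold": "I don't have any pets."}
r = requests.post("http://127.0.0.1:8085/NLI", data=json.dumps(payload))
print(r.json()["nli_score"])  # e.g. [p_contradiction, p_neutral, p_entailment]
```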
33 | 34 | Please refer to `NLI/deploy_NLI.py` for deploying automatic evaluation. 35 | 36 | ### Bot Play 37 | 38 | After deploying chatbots and entity extraction & question generation, start bot-play. 39 | 40 | Please refer to `play.py` for deploying bot play. The results will be saved at `results/`. 41 | 42 | 43 | 44 | ## Citation 45 | 46 | Please cite our paper if you use **AIH** in your work. 47 | 48 | ```bibtex 49 | @inproceedings{li2021addressing, 50 | title={Addressing Inquiries about History: An Efficient and Practical Framework for Evaluating Open-domain Chatbot Consistency}, 51 | author={Li, Zekang and Zhang, Jinchao and Fei, Zhengcong and Feng, Yang and Zhou, Jie}, 52 | booktitle={Findings of Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics}, 53 | year={2021} 54 | } 55 | ``` 56 | 57 | -------------------------------------------------------------------------------- /figure/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/AIH/e50cb8ab714afa0b791865da5233626dd6398cfe/figure/framework.png -------------------------------------------------------------------------------- /play.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import random 4 | import os 5 | import stanza 6 | 7 | nlp = stanza.Pipeline(lang='en', processors='tokenize') 8 | 9 | 10 | agent_pool = {"plato": "http://127.0.0.1:8082/interact", "blender": "http://127.0.0.1:8080/interact", "dialoflow": "http://127.0.0.1:8089/interact", "dialogpt": "http://127.0.0.1:8086/interact"} 11 | 12 | data = {} 13 | 14 | start_utterance = "hello" 15 | 16 | 17 | class PlatoAgent: 18 | def __init__(self, userid): 19 | self.userid = userid 20 | self.url = agent_pool["plato"] 21 | 22 | def act(self, text, replace=False): 23 | data = json.dumps({"userID": self.userid, "text": text, "replace": replace}) 24 | r = requests.post(self.url, data=data) 25 | text = json.loads(r.text)['body']['utterance'] 26 | return text 27 | 28 | class DialoFlowAgent: 29 | def __init__(self, userid): 30 | self.userid = userid 31 | self.url = agent_pool["dialoflow"] 32 | 33 | def act(self, text, replace=False): 34 | data = json.dumps({"userID": self.userid, "text": text, "replace": replace}) 35 | r = requests.post(self.url, data=data) 36 | text = json.loads(r.text)['body']['utterance'] 37 | return text 38 | 39 | class DialoGPTAgent: 40 | def __init__(self, userid): 41 | self.userid = userid 42 | self.url = agent_pool["dialogpt"] 43 | 44 | def act(self, text, replace=False): 45 | data = json.dumps({"userID": self.userid, "text": text, "replace": replace}) 46 | r = requests.post(self.url, data=data) 47 | text = json.loads(r.text)['body']['utterance'] 48 | return text 49 | 50 | 51 | 52 | class BlenderAgent: 53 | def __init__(self, userid): 54 | self.userid = userid 55 | self.url = agent_pool["blender"] 56 | 57 | def act(self, text, replace=False): 58 | if replace: 59 | data = text+self.userid + "*" 60 | else: 61 | data = text+self.userid 62 | r = requests.post(self.url, data=data.encode("utf-8")) 63 | text = json.loads(r.text)["text"] 64 | return text 65 | 66 | def gen_q(text): 67 | data = json.dumps({"text": text}) 68 | url = "http://127.0.0.1:8084/gen" 69 | r = requests.post(url, data=data) 70 | text = json.loads(r.text)['body']['text'] 71 | return text 72 | 73 | def nli(res, res_gold): 74 | data = json.dumps({"res": res, "res_gold": res_gold}) 75 | url = 
"http://127.0.0.1:8085/NLI" 76 | r = requests.post(url, data=data) 77 | score = json.loads(r.text)['nli_score'] 78 | return score 79 | 80 | PLAY_NUM = 1000 81 | TURN = 15 82 | METHOD = "GEN" 83 | agent_name_pool = list(agent_pool.keys()) 84 | for i in range(PLAY_NUM): 85 | userid = random.randrange(100000, 999997) 86 | userid1 = str(userid+1) 87 | userid2 = str(userid+2) 88 | if userid in data.keys(): 89 | continue 90 | else: 91 | data[userid] = [] 92 | agent1_name = random.choice(agent_name_pool) 93 | if agent1_name == "plato": 94 | agent1 = PlatoAgent(userid1) 95 | elif agent1_name == "blender": 96 | agent1 = BlenderAgent(userid1) 97 | elif agent1_name == "dialogpt": 98 | agent1 = DialoGPTAgent(userid1) 99 | elif agent1_name == "dialoflow": 100 | agent1 = DialoFlowAgent(userid1) 101 | agent2_name = random.choice(agent_name_pool) 102 | if agent2_name == "plato": 103 | agent2 = PlatoAgent(userid2) 104 | elif agent2_name == "blender": 105 | agent2 = BlenderAgent(userid2) 106 | elif agent2_name == "dialogpt": 107 | agent2 = DialoGPTAgent(userid2) 108 | elif agent2_name == "dialoflow": 109 | agent2 = DialoFlowAgent(userid2) 110 | 111 | r1 = None 112 | r2 = None 113 | questions = [] 114 | for j in range(TURN): 115 | if j == 0: 116 | r1 = start_utterance 117 | data[userid].append(r1+"\n") 118 | else: 119 | r1 = agent1.act(r2) 120 | print(r1) 121 | data[userid].append(r1+"\n") 122 | r2 = agent2.act(r1) 123 | print(r2) 124 | data[userid].append(r2+"\n") 125 | doc = nlp(r2) 126 | clean_r2 = [] 127 | for k in doc.sentences: 128 | if "?" not in k.text: 129 | clean_r2.append(k.text) 130 | if len(clean_r2) == 0: 131 | continue 132 | q = gen_q(" ".join(clean_r2)) 133 | print(" ".join(clean_r2), q) 134 | if len(q) > 0 and j > 0: 135 | q = random.choice(q) 136 | if len(q[0].strip()) > 0: 137 | data[userid].append("\t" + METHOD + ": " + q[0] + "\n") 138 | temp_r2 = agent2.act(q[0], replace=True) 139 | score = " ".join([str(x) for x in nli(temp_r2, q[1])]) 140 | print(METHOD, temp_r2, score) 141 | data[userid].append("\t" + METHOD + ": " + temp_r2 + "\t" + q[1] + "\t" + score + "\n") 142 | 143 | if not os.path.exists(METHOD + "/" + agent1_name + "_" + agent2_name): 144 | os.mkdir(METHOD + "/" + agent1_name + "_" + agent2_name) 145 | with open(METHOD + "/" + agent1_name + "_" + agent2_name + '/' + str(userid), "w") as f: 146 | f.writelines(data[userid]) 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /question_generator/LICENSE.txt: -------------------------------------------------------------------------------- 1 | I am providing code in the repository to you under an open source license. Because this is my personal repository, the license you receive to my code is from me and not my employer (Facebook). 2 | 3 | The MIT License (MIT) 4 | 5 | Copyright (c) Artit Wangperawong 6 | Copyright (c) Microsoft Corporation 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 
17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | -------------------------------------------------------------------------------- /question_generator/README.md: -------------------------------------------------------------------------------- 1 | # Text2Text: generate questions and summaries for your texts 2 | Input your text and get questions and summaries in return! 3 | 4 | ### Citation 5 | To cite this work, use the following BibTeX citation. 6 | 7 | ``` 8 | @misc{text2text@2020, 9 | author={Wangperawong, Artit}, 10 | title={Text2Text: generate questions and summaries for your texts}, 11 | year={2020}, 12 | publisher = {GitHub}, 13 | journal = {GitHub repository}, 14 | howpublished = {\url{https://github.com/artitw/text2text}}, 15 | url = {https://github.com/artitw/text2text} 16 | } 17 | ``` 18 | 19 | ## Requirements 20 | * pytorch 21 | * [pytorch-extension](https://github.com/artitw/apex) 22 | * numpy 23 | * many GBs of memory 24 | 25 | ## Installation 26 | ### A PyTorch Extension (APEX) 27 | ``` 28 | export CUDA_HOME=/usr/local/cuda-10.1 29 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" pytorch-extension 30 | ``` 31 | 32 | ### Text2Text 33 | ``` 34 | pip install text2text 35 | ``` 36 | 37 | ## Examples 38 | ### Colab demo 39 | 40 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LE_ifTpOGO5QJCKNQYtZe6c_tjbwnulR) 41 | 42 | ### Demo Video 43 | Text2Text demo 45 | 46 | ### Obtain some texts 47 | ``` 48 | notre_dame_str = "As at most other universities, Notre Dame's students run a number of news media outlets. The nine student - run outlets include three newspapers, both a radio and television station, and several magazines and journals. Begun as a one - page journal in September 1876, the Scholastic magazine is issued twice monthly and claims to be the oldest continuous collegiate publication in the United States. The other magazine, The Juggler, is released twice a year and focuses on student literature and artwork. The Dome yearbook is published annually. The newspapers have varying publication interests, with The Observer published daily and mainly reporting university and other news, and staffed by students from both Notre Dame and Saint Mary's College. Unlike Scholastic and The Dome, The Observer is an independent publication and does not have a faculty advisor or any editorial oversight from the University. In 1987, when some students believed that The Observer began to show a conservative bias, a liberal newspaper, Common Sense was published. Likewise, in 2003, when other students believed that the paper showed a liberal bias, the conservative paper Irish Rover went into production. Neither paper is published as often as The Observer; however, all three are distributed to all students. Finally, in Spring 2008 an undergraduate journal for political science research, Beyond Politics, made its debut." 49 | 50 | bacteria_str = "Bacteria are a type of biological cell. They constitute a large domain of prokaryotic microorganisms. 
Typically a few micrometres in length, bacteria have a number of shapes, ranging from spheres to rods and spirals. Bacteria were among the first life forms to appear on Earth, and are present in most of its habitats." 51 | 52 | bio_str = "Biology is the science that studies life. What exactly is life? This may sound like a silly question with an obvious answer, but it is not easy to define life. For example, a branch of biology called virology studies viruses, which exhibit some of the characteristics of living entities but lack others. It turns out that although viruses can attack living organisms, cause diseases, and even reproduce, they do not meet the criteria that biologists use to define life." 53 | ``` 54 | 55 | ### Question Generation 56 | ``` 57 | from text2text.text_generator import TextGenerator 58 | qg = TextGenerator(output_type="question") 59 | 60 | qg.predict([ 61 | bio_str, 62 | bio_str, 63 | bio_str, 64 | bio_str, 65 | bio_str, 66 | "I will go to school today to take my math exam.", 67 | "I will go to school today to take my math exam.", 68 | "Tomorrow is my cousin's birthday. He will turn 24 years old.", 69 | notre_dame_str, 70 | bacteria_str, 71 | bacteria_str, 72 | bacteria_str, 73 | "I will go to school today to take my math exam. [SEP] school", 74 | "I will go to school today to take my math exam. [SEP] exam", 75 | "I will go to school today to take my math exam. [SEP] math", 76 | ]) 77 | ``` 78 | #### Generated Questions 79 | Note that the last three answers were controlled by specifying the `[SEP]` token in the input above. 80 | ``` 81 | [('What is biology the science that studies?', 'life'), 82 | ('What is the study of life?', 'studies'), 83 | ('What would you find the question " life "?', 'sound'), 84 | ('What can viruses do to living organisms?', 'attack'), 85 | ('What is the study of life?', 'studies'), 86 | ('Where will I go to to take my math exam?', 'school'), 87 | ('Where will I go to to take my math exam?', 'school'), 88 | ("What will my cousin's birthday?", 'turn'), 89 | ('What type of oversight does The Observer not have?', 'editorial'), 90 | ('What shape can bacteria be found in?', 'rods'), 91 | ('What is the typical length of bacteria?', 'micrometres'), 92 | ('What is the typical length of bacteria?', 'micrometres'), 93 | ('Where will I go to to take my math exam?', 'school'), 94 | ('What will I take after school?', 'exam'), 95 | ('What exam will I take?', 'math')] 96 | ``` 97 | 98 | ### Summary Generation 99 | ``` 100 | from text2text import TextGenerator 101 | sg = TextGenerator(output_type="summary") 102 | sg.predict([notre_dame_str, bacteria_str, bio_str]) 103 | ``` 104 | #### Generated Summaries 105 | ``` 106 | ["Notre Dame's students run nine student - run outlets . [X_SEP] Scholastic magazine claims to be the oldest continuous collegiate publication in the United States . [X_SEP] The Observer is an independent publication .", 107 | 'Bacteria were among the first life forms to appear on Earth .', 108 | 'biology is the science that studies life .'] 109 | ``` 110 | 111 | ## Questions? 112 | For questions or help using Text2Text, please submit a GitHub issue. 
113 | 114 | ## Acknowledgements 115 | This package is based on [UniLM](https://github.com/microsoft/unilm) 116 | -------------------------------------------------------------------------------- /question_generator/deploy.py: -------------------------------------------------------------------------------- 1 | import json 2 | from flask import Flask, jsonify, request, render_template 3 | from flask_cors import CORS 4 | import stanza 5 | import truecase 6 | nlp = stanza.Pipeline(lang='en', processors='tokenize,ner,pos') 7 | 8 | app = Flask(__name__) 9 | CORS(app) 10 | 11 | from text2text.text_generator import TextGenerator 12 | qg = TextGenerator(output_type="question") 13 | 14 | def extract(features, select_list=None): 15 | upos = "" 16 | xpos = "" 17 | results_upos = [] 18 | results_xpos = [] 19 | temp_upos_text = [] 20 | temp_xpos_text = [] 21 | for feature in features.to_dict(): 22 | for i in feature: 23 | if upos != i["upos"]: 24 | results_upos.append((" ".join(temp_upos_text), upos)) 25 | temp_upos_text = [i["text"]] 26 | upos = i["upos"] 27 | else: 28 | temp_upos_text.append(i["text"]) 29 | if xpos != i["xpos"]: 30 | results_xpos.append((" ".join(temp_xpos_text), xpos)) 31 | temp_xpos_text = [i["text"]] 32 | xpos = i["xpos"] 33 | else: 34 | temp_xpos_text.append(i["text"]) 35 | 36 | if select_list is not None: 37 | r_list = [i.to_dict()["text"] for i in features.entities] 38 | for i in results_xpos: 39 | if i[1] in select_list: 40 | if i[0] not in " ".join(r_list): 41 | r_list.append(i[0]) 42 | return r_list 43 | else: 44 | return results_xpos 45 | 46 | @app.route("/gen", methods=["POST"]) 47 | def gen(): 48 | body = request.get_data() 49 | body = json.loads(body) 50 | text = truecase.get_true_case(body["text"]) 51 | select_list = ["CD", "NNP"] 52 | doc = nlp(text) 53 | ents = extract(doc, select_list) 54 | #ents = [ent.text for sent in doc.sentences for ent in sent.ents] 55 | print(ents) 56 | candidates = [text + " [SEP] " + i for i in ents] 57 | q = qg.predict(candidates) 58 | results = [] 59 | for i in q: 60 | if i != "\n": 61 | results.append((i[0].replace(" I ", " you ").replace(" i ", " you ").replace(" my ", " your "), text)) 62 | return jsonify({"body": {"text": results}}) 63 | 64 | if __name__ == "__main__": 65 | app.run(host='0.0.0.0', port=8084) 66 | -------------------------------------------------------------------------------- /question_generator/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="text2text", 8 | version="0.0.9", 9 | author="Artit Wangperawong", 10 | author_email="artitw@gmail.com", 11 | description="Text2Text: generate questions and summaries for your texts", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/artitw/text2text", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | keywords='bert nlp nlg text generation question summary summarization data science machine learning', 22 | install_requires=[ 23 | 'torch', 24 | 'tqdm', 25 | 'numpy', 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /question_generator/text2text/__init__.py: -------------------------------------------------------------------------------- 
1 | from .text_generator import TextGenerator -------------------------------------------------------------------------------- /question_generator/text2text/biunilm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/AIH/e50cb8ab714afa0b791865da5233626dd6398cfe/question_generator/text2text/biunilm/__init__.py -------------------------------------------------------------------------------- /question_generator/text2text/biunilm/loader_utils.py: -------------------------------------------------------------------------------- 1 | from random import randint, shuffle 2 | from random import random as rand 3 | import numpy as np 4 | 5 | import torch 6 | import torch.utils.data 7 | 8 | 9 | def get_random_word(vocab_words): 10 | i = randint(0, len(vocab_words)-1) 11 | return vocab_words[i] 12 | 13 | 14 | def batch_list_to_batch_tensors(batch): 15 | batch_tensors = [] 16 | for x in zip(*batch): 17 | if x[0] is None: 18 | batch_tensors.append(None) 19 | elif isinstance(x[0], torch.Tensor): 20 | batch_tensors.append(torch.stack(x)) 21 | else: 22 | batch_tensors.append(torch.tensor(x, dtype=torch.long)) 23 | return batch_tensors 24 | 25 | 26 | class TrieNode(object): 27 | def __init__(self): 28 | self.children = {} 29 | self.is_leaf = False 30 | 31 | def try_get_children(self, key): 32 | if key not in self.children: 33 | self.children[key] = TrieNode() 34 | return self.children[key] 35 | 36 | 37 | class TrieTree(object): 38 | def __init__(self): 39 | self.root = TrieNode() 40 | 41 | def add(self, tokens): 42 | r = self.root 43 | for token in tokens: 44 | r = r.try_get_children(token) 45 | r.is_leaf = True 46 | 47 | def get_pieces(self, tokens, offset): 48 | pieces = [] 49 | r = self.root 50 | token_id = 0 51 | last_valid = 0 52 | match_count = 0 53 | while last_valid < len(tokens): 54 | if token_id < len(tokens) and tokens[token_id] in r.children: 55 | r = r.children[tokens[token_id]] 56 | match_count += 1 57 | if r.is_leaf: 58 | last_valid = token_id 59 | token_id += 1 60 | else: 61 | pieces.append( 62 | list(range(token_id - match_count + offset, last_valid + 1 + offset))) 63 | last_valid += 1 64 | token_id = last_valid 65 | r = self.root 66 | match_count = 0 67 | 68 | return pieces 69 | 70 | 71 | def _get_word_split_index(tokens, st, end): 72 | split_idx = [] 73 | i = st 74 | while i < end: 75 | if (not tokens[i].startswith('##')) or (i == st): 76 | split_idx.append(i) 77 | i += 1 78 | split_idx.append(end) 79 | return split_idx 80 | 81 | 82 | def _expand_whole_word(tokens, st, end): 83 | new_st, new_end = st, end 84 | while (new_st >= 0) and tokens[new_st].startswith('##'): 85 | new_st -= 1 86 | while (new_end < len(tokens)) and tokens[new_end].startswith('##'): 87 | new_end += 1 88 | return new_st, new_end 89 | 90 | 91 | class Pipeline(): 92 | """ Pre-process Pipeline Class : callable """ 93 | 94 | def __init__(self): 95 | super().__init__() 96 | self.skipgram_prb = None 97 | self.skipgram_size = None 98 | self.pre_whole_word = None 99 | self.mask_whole_word = None 100 | self.word_subsample_prb = None 101 | self.sp_prob = None 102 | self.pieces_dir = None 103 | self.vocab_words = None 104 | self.pieces_threshold = 10 105 | self.trie = None 106 | self.call_count = 0 107 | self.offline_mode = False 108 | self.skipgram_size_geo_list = None 109 | self.span_same_mask = False 110 | 111 | def init_skipgram_size_geo_list(self, p): 112 | if p > 0: 113 | g_list = [] 114 | t = p 115 | for _ in range(self.skipgram_size): 116 | 
g_list.append(t) 117 | t *= (1-p) 118 | s = sum(g_list) 119 | self.skipgram_size_geo_list = [x/s for x in g_list] 120 | 121 | def create_trie_tree(self, pieces_dir): 122 | print("sp_prob = {}".format(self.sp_prob)) 123 | print("pieces_threshold = {}".format(self.pieces_threshold)) 124 | if pieces_dir is not None: 125 | self.trie = TrieTree() 126 | pieces_files = [pieces_dir] 127 | for token in self.vocab_words: 128 | self.trie.add([token]) 129 | for piece_file in pieces_files: 130 | print("Load piece file: {}".format(piece_file)) 131 | with open(piece_file, mode='r', encoding='utf-8') as reader: 132 | for line in reader: 133 | parts = line.split('\t') 134 | if int(parts[-1]) < self.pieces_threshold: 135 | pass 136 | tokens = [] 137 | for part in parts[:-1]: 138 | tokens.extend(part.split(' ')) 139 | self.trie.add(tokens) 140 | 141 | def __call__(self, instance): 142 | raise NotImplementedError 143 | 144 | # pre_whole_word: tokenize to words before masking 145 | # post whole word (--mask_whole_word): expand to words after masking 146 | def get_masked_pos(self, tokens, n_pred, add_skipgram=False, mask_segment=None, protect_range=None): 147 | if self.pieces_dir is not None and self.trie is None: 148 | self.create_trie_tree(self.pieces_dir) 149 | if self.pre_whole_word: 150 | if self.trie is not None: 151 | pieces = self.trie.get_pieces(tokens, 0) 152 | 153 | new_pieces = [] 154 | for piece in pieces: 155 | if len(new_pieces) > 0 and tokens[piece[0]].startswith("##"): 156 | new_pieces[-1].extend(piece) 157 | else: 158 | new_pieces.append(piece) 159 | del pieces 160 | pieces = new_pieces 161 | 162 | pre_word_split = list(_[-1] for _ in pieces) 163 | pre_word_split.append(len(tokens)) 164 | else: 165 | pre_word_split = _get_word_split_index(tokens, 0, len(tokens)) 166 | index2piece = None 167 | else: 168 | pre_word_split = list(range(0, len(tokens)+1)) 169 | 170 | if self.trie is not None: 171 | pieces = self.trie.get_pieces(tokens, 0) 172 | 173 | index2piece = {} 174 | for piece in pieces: 175 | for index in piece: 176 | index2piece[index] = (piece[0], piece[-1]) 177 | else: 178 | index2piece = None 179 | 180 | span_list = list(zip(pre_word_split[:-1], pre_word_split[1:])) 181 | 182 | # candidate positions of masked tokens 183 | cand_pos = [] 184 | special_pos = set() 185 | if mask_segment: 186 | for i, sp in enumerate(span_list): 187 | sp_st, sp_end = sp 188 | if (sp_end-sp_st == 1) and tokens[sp_st].endswith('SEP]'): 189 | segment_index = i 190 | break 191 | for i, sp in enumerate(span_list): 192 | sp_st, sp_end = sp 193 | if (sp_end-sp_st == 1) and (tokens[sp_st].endswith('CLS]') or tokens[sp_st].endswith('SEP]')): 194 | special_pos.add(i) 195 | else: 196 | if mask_segment: 197 | if ((i < segment_index) and ('a' in mask_segment)) or ((i > segment_index) and ('b' in mask_segment)): 198 | cand_pos.append(i) 199 | else: 200 | cand_pos.append(i) 201 | shuffle(cand_pos) 202 | 203 | masked_pos = set() 204 | for i_span in cand_pos: 205 | if len(masked_pos) >= n_pred: 206 | break 207 | cand_st, cand_end = span_list[i_span] 208 | if len(masked_pos)+cand_end-cand_st > n_pred: 209 | continue 210 | if any(p in masked_pos for p in range(cand_st, cand_end)): 211 | continue 212 | 213 | n_span = 1 214 | if index2piece is not None: 215 | p_start, p_end = index2piece[i_span] 216 | if p_start < p_end and (rand() < self.sp_prob): 217 | # n_span = p_end - p_start + 1 218 | st_span, end_span = p_start, p_end + 1 219 | else: 220 | st_span, end_span = i_span, i_span + 1 221 | else: 222 | rand_skipgram_size = 0 223 
| # ngram 224 | if self.skipgram_size_geo_list: 225 | # sampling ngram size from geometric distribution 226 | rand_skipgram_size = np.random.choice( 227 | len(self.skipgram_size_geo_list), 1, p=self.skipgram_size_geo_list)[0] + 1 228 | else: 229 | if add_skipgram and (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): 230 | rand_skipgram_size = min( 231 | randint(2, self.skipgram_size), len(span_list)-i_span) 232 | for n in range(2, rand_skipgram_size+1): 233 | tail_st, tail_end = span_list[i_span+n-1] 234 | if (tail_end-tail_st == 1) and (tail_st in special_pos): 235 | break 236 | if len(masked_pos)+tail_end-cand_st > n_pred: 237 | break 238 | n_span = n 239 | st_span, end_span = i_span, i_span + n_span 240 | 241 | if self.mask_whole_word: 242 | # pre_whole_word==False: position index of span_list is the same as tokens 243 | st_span, end_span = _expand_whole_word( 244 | tokens, st_span, end_span) 245 | 246 | # subsampling according to frequency 247 | if self.word_subsample_prb: 248 | skip_pos = set() 249 | if self.pre_whole_word: 250 | w_span_list = span_list[st_span:end_span] 251 | else: 252 | split_idx = _get_word_split_index( 253 | tokens, st_span, end_span) 254 | w_span_list = list( 255 | zip(split_idx[:-1], split_idx[1:])) 256 | for i, sp in enumerate(w_span_list): 257 | sp_st, sp_end = sp 258 | if sp_end-sp_st == 1: 259 | w_cat = tokens[sp_st] 260 | else: 261 | w_cat = ''.join(tokens[sp_st:sp_end]) 262 | if (w_cat in self.word_subsample_prb) and (rand() < self.word_subsample_prb[w_cat]): 263 | for k in range(sp_st, sp_end): 264 | skip_pos.add(k) 265 | else: 266 | skip_pos = None 267 | 268 | for sp in range(st_span, end_span): 269 | for mp in range(span_list[sp][0], span_list[sp][1]): 270 | if not(skip_pos and (mp in skip_pos)) and (mp not in special_pos) and not(protect_range and (protect_range[0] <= mp < protect_range[1])): 271 | masked_pos.add(mp) 272 | 273 | if len(masked_pos) < n_pred: 274 | shuffle(cand_pos) 275 | for pos in cand_pos: 276 | if len(masked_pos) >= n_pred: 277 | break 278 | if pos not in masked_pos: 279 | masked_pos.add(pos) 280 | masked_pos = list(masked_pos) 281 | if len(masked_pos) > n_pred: 282 | # shuffle(masked_pos) 283 | masked_pos = masked_pos[:n_pred] 284 | return masked_pos 285 | 286 | def replace_masked_tokens(self, tokens, masked_pos): 287 | if self.span_same_mask: 288 | masked_pos = sorted(list(masked_pos)) 289 | prev_pos, prev_rand = None, None 290 | for pos in masked_pos: 291 | if self.span_same_mask and (pos-1 == prev_pos): 292 | t_rand = prev_rand 293 | else: 294 | t_rand = rand() 295 | if t_rand < 0.8: # 80% 296 | tokens[pos] = '[MASK]' 297 | elif t_rand < 0.9: # 10% 298 | tokens[pos] = get_random_word(self.vocab_words) 299 | prev_pos, prev_rand = pos, t_rand 300 | -------------------------------------------------------------------------------- /question_generator/text2text/biunilm/seq2seq_loader.py: -------------------------------------------------------------------------------- 1 | from random import randint, shuffle, choice 2 | from random import random as rand 3 | import math 4 | import torch 5 | 6 | from .loader_utils import get_random_word, batch_list_to_batch_tensors, Pipeline 7 | 8 | # Input file format : 9 | # 1. One sentence per line. These should ideally be actual sentences, 10 | # not entire paragraphs or arbitrary spans of text. (Because we use 11 | # the sentence boundaries for the "next sentence prediction" task). 12 | # 2. Blank lines between documents. 
Document boundaries are needed 13 | # so that the "next sentence prediction" task doesn't span between documents. 14 | 15 | 16 | def truncate_tokens_pair(tokens_a, tokens_b, max_len, max_len_a=0, max_len_b=0, trunc_seg=None, always_truncate_tail=False): 17 | num_truncated_a = [0, 0] 18 | num_truncated_b = [0, 0] 19 | while True: 20 | if len(tokens_a) + len(tokens_b) <= max_len: 21 | break 22 | if (max_len_a > 0) and len(tokens_a) > max_len_a: 23 | trunc_tokens = tokens_a 24 | num_truncated = num_truncated_a 25 | elif (max_len_b > 0) and len(tokens_b) > max_len_b: 26 | trunc_tokens = tokens_b 27 | num_truncated = num_truncated_b 28 | elif trunc_seg: 29 | # truncate the specified segment 30 | if trunc_seg == 'a': 31 | trunc_tokens = tokens_a 32 | num_truncated = num_truncated_a 33 | else: 34 | trunc_tokens = tokens_b 35 | num_truncated = num_truncated_b 36 | else: 37 | # truncate the longer segment 38 | if len(tokens_a) > len(tokens_b): 39 | trunc_tokens = tokens_a 40 | num_truncated = num_truncated_a 41 | else: 42 | trunc_tokens = tokens_b 43 | num_truncated = num_truncated_b 44 | # whether always truncate source sequences 45 | if (not always_truncate_tail) and (rand() < 0.5): 46 | del trunc_tokens[0] 47 | num_truncated[0] += 1 48 | else: 49 | trunc_tokens.pop() 50 | num_truncated[1] += 1 51 | return num_truncated_a, num_truncated_b 52 | 53 | 54 | class Seq2SeqDataset(torch.utils.data.Dataset): 55 | """ Load sentence pair (sequential or random order) from corpus """ 56 | 57 | def __init__(self, file_src, file_tgt, batch_size, tokenizer, max_len, file_oracle=None, short_sampling_prob=0.1, sent_reverse_order=False, bi_uni_pipeline=[]): 58 | super().__init__() 59 | self.tokenizer = tokenizer # tokenize function 60 | self.max_len = max_len # maximum length of tokens 61 | self.short_sampling_prob = short_sampling_prob 62 | self.bi_uni_pipeline = bi_uni_pipeline 63 | self.batch_size = batch_size 64 | self.sent_reverse_order = sent_reverse_order 65 | 66 | # read the file into memory 67 | self.ex_list = [] 68 | if file_oracle is None: 69 | with open(file_src, "r", encoding='utf-8') as f_src, open(file_tgt, "r", encoding='utf-8') as f_tgt: 70 | for src, tgt in zip(f_src, f_tgt): 71 | src_tk = tokenizer.tokenize(src.strip()) 72 | tgt_tk = tokenizer.tokenize(tgt.strip()) 73 | assert len(src_tk) > 0 74 | assert len(tgt_tk) > 0 75 | self.ex_list.append((src_tk, tgt_tk)) 76 | else: 77 | with open(file_src, "r", encoding='utf-8') as f_src, \ 78 | open(file_tgt, "r", encoding='utf-8') as f_tgt, \ 79 | open(file_oracle, "r", encoding='utf-8') as f_orc: 80 | for src, tgt, orc in zip(f_src, f_tgt, f_orc): 81 | src_tk = tokenizer.tokenize(src.strip()) 82 | tgt_tk = tokenizer.tokenize(tgt.strip()) 83 | s_st, labl = orc.split('\t') 84 | s_st = [int(x) for x in s_st.split()] 85 | labl = [int(x) for x in labl.split()] 86 | self.ex_list.append((src_tk, tgt_tk, s_st, labl)) 87 | print('Load {0} documents'.format(len(self.ex_list))) 88 | 89 | def __len__(self): 90 | return len(self.ex_list) 91 | 92 | def __getitem__(self, idx): 93 | instance = self.ex_list[idx] 94 | proc = choice(self.bi_uni_pipeline) 95 | instance = proc(instance) 96 | return instance 97 | 98 | def __iter__(self): # iterator to load data 99 | for __ in range(math.ceil(len(self.ex_list) / float(self.batch_size))): 100 | batch = [] 101 | for __ in range(self.batch_size): 102 | idx = randint(0, len(self.ex_list)-1) 103 | batch.append(self.__getitem__(idx)) 104 | # To Tensor 105 | yield batch_list_to_batch_tensors(batch) 106 | 107 | 108 | class 
Preprocess4Seq2seq(Pipeline): 109 | """ Pre-processing steps for pretraining transformer """ 110 | 111 | def __init__(self, max_pred, mask_prob, vocab_words, indexer, max_len=512, skipgram_prb=0, skipgram_size=0, block_mask=False, mask_whole_word=False, new_segment_ids=False, truncate_config={}, mask_source_words=False, mode="s2s", has_oracle=False, num_qkv=0, s2s_special_token=False, s2s_add_segment=False, s2s_share_segment=False, pos_shift=False): 112 | super().__init__() 113 | self.max_len = max_len 114 | self.max_pred = max_pred # max tokens of prediction 115 | self.mask_prob = mask_prob # masking probability 116 | self.vocab_words = vocab_words # vocabulary (sub)words 117 | self.indexer = indexer # function from token to token index 118 | self.max_len = max_len 119 | self._tril_matrix = torch.tril(torch.ones( 120 | (max_len, max_len), dtype=torch.long)) 121 | self.skipgram_prb = skipgram_prb 122 | self.skipgram_size = skipgram_size 123 | self.mask_whole_word = mask_whole_word 124 | self.new_segment_ids = new_segment_ids 125 | self.always_truncate_tail = truncate_config.get( 126 | 'always_truncate_tail', False) 127 | self.max_len_a = truncate_config.get('max_len_a', None) 128 | self.max_len_b = truncate_config.get('max_len_b', None) 129 | self.trunc_seg = truncate_config.get('trunc_seg', None) 130 | self.task_idx = 3 # relax projection layer for different tasks 131 | self.mask_source_words = mask_source_words 132 | assert mode in ("s2s", "l2r") 133 | self.mode = mode 134 | self.has_oracle = has_oracle 135 | self.num_qkv = num_qkv 136 | self.s2s_special_token = s2s_special_token 137 | self.s2s_add_segment = s2s_add_segment 138 | self.s2s_share_segment = s2s_share_segment 139 | self.pos_shift = pos_shift 140 | 141 | def __call__(self, instance): 142 | tokens_a, tokens_b = instance[:2] 143 | 144 | if self.pos_shift: 145 | tokens_b = ['[S2S_SOS]'] + tokens_b 146 | 147 | # -3 for special tokens [CLS], [SEP], [SEP] 148 | num_truncated_a, _ = truncate_tokens_pair(tokens_a, tokens_b, self.max_len - 3, max_len_a=self.max_len_a, 149 | max_len_b=self.max_len_b, trunc_seg=self.trunc_seg, always_truncate_tail=self.always_truncate_tail) 150 | 151 | # Add Special Tokens 152 | if self.s2s_special_token: 153 | tokens = ['[S2S_CLS]'] + tokens_a + \ 154 | ['[S2S_SEP]'] + tokens_b + ['[SEP]'] 155 | else: 156 | tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]'] 157 | 158 | if self.new_segment_ids: 159 | if self.mode == "s2s": 160 | if self.s2s_add_segment: 161 | if self.s2s_share_segment: 162 | segment_ids = [0] + [1] * \ 163 | (len(tokens_a)+1) + [5]*(len(tokens_b)+1) 164 | else: 165 | segment_ids = [4] + [6] * \ 166 | (len(tokens_a)+1) + [5]*(len(tokens_b)+1) 167 | else: 168 | segment_ids = [4] * (len(tokens_a)+2) + \ 169 | [5]*(len(tokens_b)+1) 170 | else: 171 | segment_ids = [2] * (len(tokens)) 172 | else: 173 | segment_ids = [0]*(len(tokens_a)+2) + [1]*(len(tokens_b)+1) 174 | 175 | if self.pos_shift: 176 | n_pred = min(self.max_pred, len(tokens_b)) 177 | masked_pos = [len(tokens_a)+2+i for i in range(len(tokens_b))] 178 | masked_weights = [1]*n_pred 179 | masked_ids = self.indexer(tokens_b[1:]+['[SEP]']) 180 | else: 181 | # For masked Language Models 182 | # the number of prediction is sometimes less than max_pred when sequence is short 183 | effective_length = len(tokens_b) 184 | if self.mask_source_words: 185 | effective_length += len(tokens_a) 186 | n_pred = min(self.max_pred, max( 187 | 1, int(round(effective_length*self.mask_prob)))) 188 | # candidate positions of masked tokens 189 
| cand_pos = [] 190 | special_pos = set() 191 | for i, tk in enumerate(tokens): 192 | # only mask tokens_b (target sequence) 193 | # we will mask [SEP] as an ending symbol 194 | if (i >= len(tokens_a)+2) and (tk != '[CLS]'): 195 | cand_pos.append(i) 196 | elif self.mask_source_words and (i < len(tokens_a)+2) and (tk != '[CLS]') and (not tk.startswith('[SEP')): 197 | cand_pos.append(i) 198 | else: 199 | special_pos.add(i) 200 | shuffle(cand_pos) 201 | 202 | masked_pos = set() 203 | max_cand_pos = max(cand_pos) 204 | for pos in cand_pos: 205 | if len(masked_pos) >= n_pred: 206 | break 207 | if pos in masked_pos: 208 | continue 209 | 210 | def _expand_whole_word(st, end): 211 | new_st, new_end = st, end 212 | while (new_st >= 0) and tokens[new_st].startswith('##'): 213 | new_st -= 1 214 | while (new_end < len(tokens)) and tokens[new_end].startswith('##'): 215 | new_end += 1 216 | return new_st, new_end 217 | 218 | if (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): 219 | # ngram 220 | cur_skipgram_size = randint(2, self.skipgram_size) 221 | if self.mask_whole_word: 222 | st_pos, end_pos = _expand_whole_word( 223 | pos, pos + cur_skipgram_size) 224 | else: 225 | st_pos, end_pos = pos, pos + cur_skipgram_size 226 | else: 227 | # directly mask 228 | if self.mask_whole_word: 229 | st_pos, end_pos = _expand_whole_word(pos, pos + 1) 230 | else: 231 | st_pos, end_pos = pos, pos + 1 232 | 233 | for mp in range(st_pos, end_pos): 234 | if (0 < mp <= max_cand_pos) and (mp not in special_pos): 235 | masked_pos.add(mp) 236 | else: 237 | break 238 | 239 | masked_pos = list(masked_pos) 240 | if len(masked_pos) > n_pred: 241 | shuffle(masked_pos) 242 | masked_pos = masked_pos[:n_pred] 243 | 244 | masked_tokens = [tokens[pos] for pos in masked_pos] 245 | for pos in masked_pos: 246 | if rand() < 0.8: # 80% 247 | tokens[pos] = '[MASK]' 248 | elif rand() < 0.5: # 10% 249 | tokens[pos] = get_random_word(self.vocab_words) 250 | # when n_pred < max_pred, we only calculate loss within n_pred 251 | masked_weights = [1]*len(masked_tokens) 252 | 253 | # Token Indexing 254 | masked_ids = self.indexer(masked_tokens) 255 | # Token Indexing 256 | input_ids = self.indexer(tokens) 257 | 258 | # Zero Padding 259 | n_pad = self.max_len - len(input_ids) 260 | input_ids.extend([0]*n_pad) 261 | segment_ids.extend([0]*n_pad) 262 | 263 | if self.num_qkv > 1: 264 | mask_qkv = [0]*(len(tokens_a)+2) + [1] * (len(tokens_b)+1) 265 | mask_qkv.extend([0]*n_pad) 266 | else: 267 | mask_qkv = None 268 | 269 | input_mask = torch.zeros(self.max_len, self.max_len, dtype=torch.long) 270 | if self.mode == "s2s": 271 | input_mask[:, :len(tokens_a)+2].fill_(1) 272 | second_st, second_end = len( 273 | tokens_a)+2, len(tokens_a)+len(tokens_b)+3 274 | input_mask[second_st:second_end, second_st:second_end].copy_( 275 | self._tril_matrix[:second_end-second_st, :second_end-second_st]) 276 | else: 277 | st, end = 0, len(tokens_a) + len(tokens_b) + 3 278 | input_mask[st:end, st:end].copy_(self._tril_matrix[:end, :end]) 279 | 280 | # Zero Padding for masked target 281 | if self.max_pred > n_pred: 282 | n_pad = self.max_pred - n_pred 283 | if masked_ids is not None: 284 | masked_ids.extend([0]*n_pad) 285 | if masked_pos is not None: 286 | masked_pos.extend([0]*n_pad) 287 | if masked_weights is not None: 288 | masked_weights.extend([0]*n_pad) 289 | 290 | oracle_pos = None 291 | oracle_weights = None 292 | oracle_labels = None 293 | if self.has_oracle: 294 | s_st, labls = instance[2:] 295 | oracle_pos = [] 296 | 
oracle_labels = [] 297 | for st, lb in zip(s_st, labls): 298 | st = st - num_truncated_a[0] 299 | if st > 0 and st < len(tokens_a): 300 | oracle_pos.append(st) 301 | oracle_labels.append(lb) 302 | oracle_pos = oracle_pos[:20] 303 | oracle_labels = oracle_labels[:20] 304 | oracle_weights = [1] * len(oracle_pos) 305 | if len(oracle_pos) < 20: 306 | x_pad = 20 - len(oracle_pos) 307 | oracle_pos.extend([0] * x_pad) 308 | oracle_labels.extend([0] * x_pad) 309 | oracle_weights.extend([0] * x_pad) 310 | 311 | return (input_ids, segment_ids, input_mask, mask_qkv, masked_ids, 312 | masked_pos, masked_weights, -1, self.task_idx, 313 | oracle_pos, oracle_weights, oracle_labels) 314 | 315 | return (input_ids, segment_ids, input_mask, mask_qkv, masked_ids, masked_pos, masked_weights, -1, self.task_idx) 316 | 317 | 318 | class Preprocess4Seq2seqDecoder(Pipeline): 319 | """ Pre-processing steps for pretraining transformer """ 320 | 321 | def __init__(self, vocab_words, indexer, max_len=512, max_tgt_length=128, new_segment_ids=False, mode="s2s", num_qkv=0, s2s_special_token=False, s2s_add_segment=False, s2s_share_segment=False, pos_shift=False): 322 | super().__init__() 323 | self.max_len = max_len 324 | self.vocab_words = vocab_words # vocabulary (sub)words 325 | self.indexer = indexer # function from token to token index 326 | self.max_len = max_len 327 | self._tril_matrix = torch.tril(torch.ones( 328 | (max_len, max_len), dtype=torch.long)) 329 | self.new_segment_ids = new_segment_ids 330 | self.task_idx = 3 # relax projection layer for different tasks 331 | assert mode in ("s2s", "l2r") 332 | self.mode = mode 333 | self.max_tgt_length = max_tgt_length 334 | self.num_qkv = num_qkv 335 | self.s2s_special_token = s2s_special_token 336 | self.s2s_add_segment = s2s_add_segment 337 | self.s2s_share_segment = s2s_share_segment 338 | self.pos_shift = pos_shift 339 | 340 | def __call__(self, instance): 341 | tokens_a, max_a_len = instance 342 | 343 | # Add Special Tokens 344 | if self.s2s_special_token: 345 | padded_tokens_a = ['[S2S_CLS]'] + tokens_a + ['[S2S_SEP]'] 346 | else: 347 | padded_tokens_a = ['[CLS]'] + tokens_a + ['[SEP]'] 348 | assert len(padded_tokens_a) <= max_a_len + 2 349 | if max_a_len + 2 > len(padded_tokens_a): 350 | padded_tokens_a += ['[PAD]'] * \ 351 | (max_a_len + 2 - len(padded_tokens_a)) 352 | assert len(padded_tokens_a) == max_a_len + 2 353 | max_len_in_batch = min(self.max_tgt_length + 354 | max_a_len + 2, self.max_len) 355 | tokens = padded_tokens_a 356 | if self.new_segment_ids: 357 | if self.mode == "s2s": 358 | _enc_seg1 = 0 if self.s2s_share_segment else 4 359 | if self.s2s_add_segment: 360 | if self.s2s_share_segment: 361 | segment_ids = [ 362 | 0] + [1]*(len(padded_tokens_a)-1) + [5]*(max_len_in_batch - len(padded_tokens_a)) 363 | else: 364 | segment_ids = [ 365 | 4] + [6]*(len(padded_tokens_a)-1) + [5]*(max_len_in_batch - len(padded_tokens_a)) 366 | else: 367 | segment_ids = [4]*(len(padded_tokens_a)) + \ 368 | [5]*(max_len_in_batch - len(padded_tokens_a)) 369 | else: 370 | segment_ids = [2]*max_len_in_batch 371 | else: 372 | segment_ids = [0]*(len(padded_tokens_a)) \ 373 | + [1]*(max_len_in_batch - len(padded_tokens_a)) 374 | 375 | if self.num_qkv > 1: 376 | mask_qkv = [0]*(len(padded_tokens_a)) + [1] * \ 377 | (max_len_in_batch - len(padded_tokens_a)) 378 | else: 379 | mask_qkv = None 380 | 381 | position_ids = [] 382 | for i in range(len(tokens_a) + 2): 383 | position_ids.append(i) 384 | for i in range(len(tokens_a) + 2, max_a_len + 2): 385 | position_ids.append(0) 386 | 
for i in range(max_a_len + 2, max_len_in_batch): 387 | position_ids.append(i - (max_a_len + 2) + len(tokens_a) + 2) 388 | 389 | # Token Indexing 390 | input_ids = self.indexer(tokens) 391 | 392 | # Zero Padding 393 | input_mask = torch.zeros( 394 | max_len_in_batch, max_len_in_batch, dtype=torch.long) 395 | if self.mode == "s2s": 396 | input_mask[:, :len(tokens_a)+2].fill_(1) 397 | else: 398 | st, end = 0, len(tokens_a) + 2 399 | input_mask[st:end, st:end].copy_( 400 | self._tril_matrix[:end, :end]) 401 | input_mask[end:, :len(tokens_a)+2].fill_(1) 402 | second_st, second_end = len(padded_tokens_a), max_len_in_batch 403 | 404 | input_mask[second_st:second_end, second_st:second_end].copy_( 405 | self._tril_matrix[:second_end-second_st, :second_end-second_st]) 406 | 407 | return (input_ids, segment_ids, position_ids, input_mask, mask_qkv, self.task_idx) 408 | -------------------------------------------------------------------------------- /question_generator/text2text/pytorch_pretrained_bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/AIH/e50cb8ab714afa0b791865da5233626dd6398cfe/question_generator/text2text/pytorch_pretrained_bert/__init__.py -------------------------------------------------------------------------------- /question_generator/text2text/pytorch_pretrained_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | try: 5 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 6 | except ModuleNotFoundError: 7 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 8 | "In that case, it requires TensorFlow to be installed. Please see " 9 | "https://www.tensorflow.org/install/ for installation instructions.") 10 | raise 11 | 12 | if len(sys.argv) != 5: 13 | # pylint: disable=line-too-long 14 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 15 | else: 16 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 17 | TF_CONFIG = sys.argv.pop() 18 | TF_CHECKPOINT = sys.argv.pop() 19 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /question_generator/text2text/pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 
98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | 192 | if not os.path.exists(cache_path): 193 | # Download to temporary file, then copy to cache dir once finished. 194 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
195 | with tempfile.NamedTemporaryFile() as temp_file: 196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 197 | 198 | # GET file object 199 | if url.startswith("s3://"): 200 | s3_get(url, temp_file) 201 | else: 202 | http_get(url, temp_file) 203 | 204 | # we are copying the file before closing it, so flush to avoid truncation 205 | temp_file.flush() 206 | # shutil.copyfileobj() starts at the current position, so go to the start 207 | temp_file.seek(0) 208 | 209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 210 | with open(cache_path, 'wb') as cache_file: 211 | shutil.copyfileobj(temp_file, cache_file) 212 | 213 | logger.info("creating metadata file for %s", cache_path) 214 | meta = {'url': url, 'etag': etag} 215 | meta_path = cache_path + '.json' 216 | with open(meta_path, 'w') as meta_file: 217 | json.dump(meta, meta_file) 218 | 219 | logger.info("removing temp file %s", temp_file.name) 220 | 221 | return cache_path 222 | 223 | 224 | def read_set_from_file(filename: str) -> Set[str]: 225 | ''' 226 | Extract a de-duped collection (set) of text from a file. 227 | Expected file format is one item per line. 228 | ''' 229 | collection = set() 230 | with open(filename, 'r', encoding='utf-8') as file_: 231 | for line in file_: 232 | collection.add(line.rstrip()) 233 | return collection 234 | 235 | 236 | def get_file_extension(path: str, dot=True, lower: bool = True): 237 | ext = os.path.splitext(path)[1] 238 | ext = ext if dot else ext[1:] 239 | return ext.lower() if lower else ext 240 | -------------------------------------------------------------------------------- /question_generator/text2text/pytorch_pretrained_bert/loss.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.nn.modules.loss import _Loss 10 | 11 | 12 | class LabelSmoothingLoss(_Loss): 13 | """ 14 | With label smoothing, 15 | KL-divergence between q_{smoothed ground truth prob.}(w) 16 | and p_{prob. computed by model}(w) is minimized. 
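# --- Editor's aside: a minimal sketch (toy, hypothetical sizes; not part of the
# original file) of the smoothed target distribution that forward() below builds
# row by row: 1 - eps on the gold index, eps / (V - 2) on the remaining indices,
# and 0 on ignore_index, so every non-ignored row still sums to 1.
import torch
V, eps, ignore_index, gold = 6, 0.1, 0, 3
q = torch.full((V,), eps / (V - 2))   # the "smoothing_value" spread over the vocab
q[ignore_index] = 0.0                 # no probability mass on the padding index
q[gold] = 1.0 - eps                   # the "confidence" placed on the gold token
assert abs(q.sum().item() - 1.0) < 1e-6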
17 | """ 18 | 19 | def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None, reduction='mean'): 20 | assert 0.0 < label_smoothing <= 1.0 21 | self.ignore_index = ignore_index 22 | super(LabelSmoothingLoss, self).__init__( 23 | size_average=size_average, reduce=reduce, reduction=reduction) 24 | 25 | assert label_smoothing > 0 26 | assert tgt_vocab_size > 0 27 | 28 | smoothing_value = label_smoothing / (tgt_vocab_size - 2) 29 | one_hot = torch.full((tgt_vocab_size,), smoothing_value) 30 | one_hot[self.ignore_index] = 0 31 | self.register_buffer('one_hot', one_hot.unsqueeze(0)) 32 | self.confidence = 1.0 - label_smoothing 33 | self.tgt_vocab_size = tgt_vocab_size 34 | 35 | def forward(self, output, target): 36 | """ 37 | output (FloatTensor): batch_size * num_pos * n_classes 38 | target (LongTensor): batch_size * num_pos 39 | """ 40 | assert self.tgt_vocab_size == output.size(2) 41 | batch_size, num_pos = target.size(0), target.size(1) 42 | output = output.view(-1, self.tgt_vocab_size) 43 | target = target.view(-1) 44 | model_prob = self.one_hot.repeat(target.size(0), 1) 45 | model_prob.scatter_(1, target.unsqueeze(1), self.confidence) 46 | model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) 47 | 48 | return F.kl_div(output, model_prob, reduction='none').view(batch_size, num_pos, -1).sum(2) 49 | -------------------------------------------------------------------------------- /question_generator/text2text/pytorch_pretrained_bert/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """PyTorch BERT model.""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import copy 10 | import json 11 | import math 12 | import logging 13 | import tarfile 14 | import tempfile 15 | import shutil 16 | import numpy as np 17 | from scipy.stats import truncnorm 18 | 19 | import torch 20 | from torch import nn 21 | from torch.nn import CrossEntropyLoss, MSELoss 22 | import torch.nn.functional as F 23 | 24 | from .file_utils import cached_path 25 | from .loss import LabelSmoothingLoss 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | PRETRAINED_MODEL_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", 37 | } 38 | CONFIG_NAME = 'bert_config.json' 39 | WEIGHTS_NAME = 'pytorch_model.bin' 40 | 41 | 42 | def gelu(x): 43 | """Implementation of the gelu activation function. 
44 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 45 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 46 | """ 47 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 48 | 49 | 50 | def swish(x): 51 | return x * torch.sigmoid(x) 52 | 53 | 54 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 55 | 56 | 57 | class BertConfig(object): 58 | """Configuration class to store the configuration of a `BertModel`. 59 | """ 60 | 61 | def __init__(self, 62 | vocab_size_or_config_json_file, 63 | hidden_size=768, 64 | num_hidden_layers=12, 65 | num_attention_heads=12, 66 | intermediate_size=3072, 67 | hidden_act="gelu", 68 | hidden_dropout_prob=0.1, 69 | attention_probs_dropout_prob=0.1, 70 | max_position_embeddings=512, 71 | type_vocab_size=2, 72 | relax_projection=0, 73 | new_pos_ids=False, 74 | initializer_range=0.02, 75 | task_idx=None, 76 | fp32_embedding=False, 77 | ffn_type=0, 78 | label_smoothing=None, 79 | num_qkv=0, 80 | seg_emb=False): 81 | """Constructs BertConfig. 82 | 83 | Args: 84 | vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`. 85 | hidden_size: Size of the encoder layers and the pooler layer. 86 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 87 | num_attention_heads: Number of attention heads for each attention layer in 88 | the Transformer encoder. 89 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 90 | layer in the Transformer encoder. 91 | hidden_act: The non-linear activation function (function or string) in the 92 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 93 | hidden_dropout_prob: The dropout probability for all fully connected 94 | layers in the embeddings, encoder, and pooler. 95 | attention_probs_dropout_prob: The dropout ratio for the attention 96 | probabilities. 97 | max_position_embeddings: The maximum sequence length that this model might 98 | ever be used with. Typically set this to something large just in case 99 | (e.g., 512 or 1024 or 2048). 100 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 101 | `BertModel`. 102 | initializer_range: The stdev of the truncated_normal_initializer for 103 | initializing all weight matrices.
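# --- Editor's aside (illustrative usage, not from the original source; the ----
# values shown are arbitrary, and json is already imported at the top of this
# module): a small config built from keyword arguments round-trips through its
# JSON form.
small_config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=256,
                          num_hidden_layers=4, num_attention_heads=4,
                          intermediate_size=1024)
assert BertConfig.from_dict(json.loads(small_config.to_json_string())).hidden_size == 256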
104 | """ 105 | if isinstance(vocab_size_or_config_json_file, str): 106 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 107 | json_config = json.loads(reader.read()) 108 | for key, value in json_config.items(): 109 | self.__dict__[key] = value 110 | elif isinstance(vocab_size_or_config_json_file, int): 111 | self.vocab_size = vocab_size_or_config_json_file 112 | self.hidden_size = hidden_size 113 | self.num_hidden_layers = num_hidden_layers 114 | self.num_attention_heads = num_attention_heads 115 | self.hidden_act = hidden_act 116 | self.intermediate_size = intermediate_size 117 | self.hidden_dropout_prob = hidden_dropout_prob 118 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 119 | self.max_position_embeddings = max_position_embeddings 120 | self.type_vocab_size = type_vocab_size 121 | self.relax_projection = relax_projection 122 | self.new_pos_ids = new_pos_ids 123 | self.initializer_range = initializer_range 124 | self.task_idx = task_idx 125 | self.fp32_embedding = fp32_embedding 126 | self.ffn_type = ffn_type 127 | self.label_smoothing = label_smoothing 128 | self.num_qkv = num_qkv 129 | self.seg_emb = seg_emb 130 | else: 131 | raise ValueError("First argument must be either a vocabulary size (int)" 132 | "or the path to a pretrained model config file (str)") 133 | 134 | @classmethod 135 | def from_dict(cls, json_object): 136 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 137 | config = BertConfig(vocab_size_or_config_json_file=-1) 138 | for key, value in json_object.items(): 139 | config.__dict__[key] = value 140 | return config 141 | 142 | @classmethod 143 | def from_json_file(cls, json_file): 144 | """Constructs a `BertConfig` from a json file of parameters.""" 145 | with open(json_file, "r", encoding='utf-8') as reader: 146 | text = reader.read() 147 | return cls.from_dict(json.loads(text)) 148 | 149 | def __repr__(self): 150 | return str(self.to_json_string()) 151 | 152 | def to_dict(self): 153 | """Serializes this instance to a Python dictionary.""" 154 | output = copy.deepcopy(self.__dict__) 155 | return output 156 | 157 | def to_json_string(self): 158 | """Serializes this instance to a JSON string.""" 159 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 160 | 161 | 162 | try: 163 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 164 | except ImportError: 165 | print("Better speed can be achieved with apex installed from https://www.github.com/artitw/apex.") 166 | 167 | class BertLayerNorm(nn.Module): 168 | def __init__(self, hidden_size, eps=1e-5): 169 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
170 | """ 171 | super(BertLayerNorm, self).__init__() 172 | self.weight = nn.Parameter(torch.ones(hidden_size)) 173 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 174 | self.variance_epsilon = eps 175 | 176 | def forward(self, x): 177 | u = x.mean(-1, keepdim=True) 178 | s = (x - u).pow(2).mean(-1, keepdim=True) 179 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 180 | return self.weight * x + self.bias 181 | 182 | 183 | class PositionalEmbedding(nn.Module): 184 | def __init__(self, demb): 185 | super(PositionalEmbedding, self).__init__() 186 | 187 | self.demb = demb 188 | 189 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 190 | self.register_buffer('inv_freq', inv_freq) 191 | 192 | def forward(self, pos_seq, bsz=None): 193 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 194 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 195 | 196 | if bsz is not None: 197 | return pos_emb[:, None, :].expand(-1, bsz, -1) 198 | else: 199 | return pos_emb[:, None, :] 200 | 201 | 202 | class BertEmbeddings(nn.Module): 203 | """Construct the embeddings from word, position and token_type embeddings. 204 | """ 205 | 206 | def __init__(self, config): 207 | super(BertEmbeddings, self).__init__() 208 | self.word_embeddings = nn.Embedding( 209 | config.vocab_size, config.hidden_size) 210 | self.token_type_embeddings = nn.Embedding( 211 | config.type_vocab_size, config.hidden_size) 212 | if hasattr(config, 'fp32_embedding'): 213 | self.fp32_embedding = config.fp32_embedding 214 | else: 215 | self.fp32_embedding = False 216 | 217 | if hasattr(config, 'new_pos_ids') and config.new_pos_ids: 218 | self.num_pos_emb = 4 219 | else: 220 | self.num_pos_emb = 1 221 | self.position_embeddings = nn.Embedding( 222 | config.max_position_embeddings, config.hidden_size*self.num_pos_emb) 223 | 224 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 225 | # any TensorFlow checkpoint file 226 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) 227 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 228 | 229 | def forward(self, input_ids, token_type_ids=None, position_ids=None, task_idx=None): 230 | seq_length = input_ids.size(1) 231 | if position_ids is None: 232 | position_ids = torch.arange( 233 | seq_length, dtype=torch.long, device=input_ids.device) 234 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 235 | if token_type_ids is None: 236 | token_type_ids = torch.zeros_like(input_ids) 237 | 238 | words_embeddings = self.word_embeddings(input_ids) 239 | position_embeddings = self.position_embeddings(position_ids) 240 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 241 | 242 | if self.num_pos_emb > 1: 243 | num_batch = position_embeddings.size(0) 244 | num_pos = position_embeddings.size(1) 245 | position_embeddings = position_embeddings.view( 246 | num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] 247 | 248 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 249 | if self.fp32_embedding: 250 | embeddings = embeddings.half() 251 | embeddings = self.LayerNorm(embeddings) 252 | embeddings = self.dropout(embeddings) 253 | return embeddings 254 | 255 | 256 | class BertSelfAttention(nn.Module): 257 | def __init__(self, config): 258 | super(BertSelfAttention, self).__init__() 259 | if config.hidden_size % config.num_attention_heads != 0: 260 | raise ValueError( 261 | "The hidden size (%d) is not a multiple of the 
number of attention " 262 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 263 | self.num_attention_heads = config.num_attention_heads 264 | self.attention_head_size = int( 265 | config.hidden_size / config.num_attention_heads) 266 | self.all_head_size = self.num_attention_heads * self.attention_head_size 267 | 268 | if hasattr(config, 'num_qkv') and (config.num_qkv > 1): 269 | self.num_qkv = config.num_qkv 270 | else: 271 | self.num_qkv = 1 272 | 273 | self.query = nn.Linear( 274 | config.hidden_size, self.all_head_size*self.num_qkv) 275 | self.key = nn.Linear(config.hidden_size, 276 | self.all_head_size*self.num_qkv) 277 | self.value = nn.Linear( 278 | config.hidden_size, self.all_head_size*self.num_qkv) 279 | 280 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 281 | 282 | self.uni_debug_flag = True if os.getenv( 283 | 'UNI_DEBUG_FLAG', '') else False 284 | if self.uni_debug_flag: 285 | self.register_buffer('debug_attention_probs', 286 | torch.zeros((512, 512))) 287 | if hasattr(config, 'seg_emb') and config.seg_emb: 288 | self.b_q_s = nn.Parameter(torch.zeros( 289 | 1, self.num_attention_heads, 1, self.attention_head_size)) 290 | self.seg_emb = nn.Embedding( 291 | config.type_vocab_size, self.all_head_size) 292 | else: 293 | self.b_q_s = None 294 | self.seg_emb = None 295 | 296 | def transpose_for_scores(self, x, mask_qkv=None): 297 | if self.num_qkv > 1: 298 | sz = x.size()[:-1] + (self.num_qkv, 299 | self.num_attention_heads, self.all_head_size) 300 | # (batch, pos, num_qkv, head, head_hid) 301 | x = x.view(*sz) 302 | if mask_qkv is None: 303 | x = x[:, :, 0, :, :] 304 | elif isinstance(mask_qkv, int): 305 | x = x[:, :, mask_qkv, :, :] 306 | else: 307 | # mask_qkv: (batch, pos) 308 | if mask_qkv.size(1) > sz[1]: 309 | mask_qkv = mask_qkv[:, :sz[1]] 310 | # -> x: (batch, pos, head, head_hid) 311 | x = x.gather(2, mask_qkv.view(sz[0], sz[1], 1, 1, 1).expand( 312 | sz[0], sz[1], 1, sz[3], sz[4])).squeeze(2) 313 | else: 314 | sz = x.size()[:-1] + (self.num_attention_heads, 315 | self.attention_head_size) 316 | # (batch, pos, head, head_hid) 317 | x = x.view(*sz) 318 | # (batch, head, pos, head_hid) 319 | return x.permute(0, 2, 1, 3) 320 | 321 | def forward(self, hidden_states, attention_mask, history_states=None, mask_qkv=None, seg_ids=None): 322 | if history_states is None: 323 | mixed_query_layer = self.query(hidden_states) 324 | mixed_key_layer = self.key(hidden_states) 325 | mixed_value_layer = self.value(hidden_states) 326 | else: 327 | x_states = torch.cat((history_states, hidden_states), dim=1) 328 | mixed_query_layer = self.query(hidden_states) 329 | mixed_key_layer = self.key(x_states) 330 | mixed_value_layer = self.value(x_states) 331 | 332 | query_layer = self.transpose_for_scores(mixed_query_layer, mask_qkv) 333 | key_layer = self.transpose_for_scores(mixed_key_layer, mask_qkv) 334 | value_layer = self.transpose_for_scores(mixed_value_layer, mask_qkv) 335 | 336 | # Take the dot product between "query" and "key" to get the raw attention scores. 
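# --- Editor's aside (descriptive, not part of the original file): the matmul --
# below is equivalent to einsum('bhid,bhjd->bhij', query_layer, key_layer)
# scaled by 1/sqrt(attention_head_size), i.e. a (batch, head, query_pos,
# key_pos) score matrix; the additive attention_mask applied a few lines
# further down pushes disallowed positions to about -10000 before the softmax.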
337 | # (batch, head, pos, pos) 338 | attention_scores = torch.matmul( 339 | query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) 340 | 341 | if self.seg_emb is not None: 342 | seg_rep = self.seg_emb(seg_ids) 343 | # (batch, pos, head, head_hid) 344 | seg_rep = seg_rep.view(seg_rep.size(0), seg_rep.size( 345 | 1), self.num_attention_heads, self.attention_head_size) 346 | qs = torch.einsum('bnih,bjnh->bnij', 347 | query_layer+self.b_q_s, seg_rep) 348 | attention_scores = attention_scores + qs 349 | 350 | # attention_scores = attention_scores / math.sqrt(self.attention_head_size) 351 | 352 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 353 | attention_scores = attention_scores + attention_mask 354 | 355 | # Normalize the attention scores to probabilities. 356 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 357 | 358 | if self.uni_debug_flag: 359 | _pos = attention_probs.size(-1) 360 | self.debug_attention_probs[:_pos, :_pos].copy_( 361 | attention_probs[0].mean(0).view(_pos, _pos)) 362 | 363 | # This is actually dropping out entire tokens to attend to, which might 364 | # seem a bit unusual, but is taken from the original Transformer paper. 365 | attention_probs = self.dropout(attention_probs) 366 | 367 | context_layer = torch.matmul(attention_probs, value_layer) 368 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 369 | new_context_layer_shape = context_layer.size()[ 370 | :-2] + (self.all_head_size,) 371 | context_layer = context_layer.view(*new_context_layer_shape) 372 | return context_layer 373 | 374 | 375 | class BertSelfOutput(nn.Module): 376 | def __init__(self, config): 377 | super(BertSelfOutput, self).__init__() 378 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 379 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) 380 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 381 | 382 | def forward(self, hidden_states, input_tensor): 383 | hidden_states = self.dense(hidden_states) 384 | hidden_states = self.dropout(hidden_states) 385 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 386 | return hidden_states 387 | 388 | 389 | class BertAttention(nn.Module): 390 | def __init__(self, config): 391 | super(BertAttention, self).__init__() 392 | self.self = BertSelfAttention(config) 393 | self.output = BertSelfOutput(config) 394 | 395 | def forward(self, input_tensor, attention_mask, history_states=None, mask_qkv=None, seg_ids=None): 396 | self_output = self.self( 397 | input_tensor, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) 398 | attention_output = self.output(self_output, input_tensor) 399 | return attention_output 400 | 401 | 402 | class BertIntermediate(nn.Module): 403 | def __init__(self, config): 404 | super(BertIntermediate, self).__init__() 405 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 406 | self.intermediate_act_fn = ACT2FN[config.hidden_act] \ 407 | if isinstance(config.hidden_act, str) else config.hidden_act 408 | 409 | def forward(self, hidden_states): 410 | hidden_states = self.dense(hidden_states) 411 | hidden_states = self.intermediate_act_fn(hidden_states) 412 | return hidden_states 413 | 414 | 415 | class BertOutput(nn.Module): 416 | def __init__(self, config): 417 | super(BertOutput, self).__init__() 418 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 419 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) 420 | self.dropout = 
nn.Dropout(config.hidden_dropout_prob) 421 | 422 | def forward(self, hidden_states, input_tensor): 423 | hidden_states = self.dense(hidden_states) 424 | hidden_states = self.dropout(hidden_states) 425 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 426 | return hidden_states 427 | 428 | 429 | class TransformerFFN(nn.Module): 430 | def __init__(self, config): 431 | super(TransformerFFN, self).__init__() 432 | self.ffn_type = config.ffn_type 433 | assert self.ffn_type in (1, 2) 434 | if self.ffn_type in (1, 2): 435 | self.wx0 = nn.Linear(config.hidden_size, config.hidden_size) 436 | if self.ffn_type in (2,): 437 | self.wx1 = nn.Linear(config.hidden_size, config.hidden_size) 438 | if self.ffn_type in (1, 2): 439 | self.output = nn.Linear(config.hidden_size, config.hidden_size) 440 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) 441 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 442 | 443 | def forward(self, x): 444 | if self.ffn_type in (1, 2): 445 | x0 = self.wx0(x) 446 | if self.ffn_type == 1: 447 | x1 = x 448 | elif self.ffn_type == 2: 449 | x1 = self.wx1(x) 450 | out = self.output(x0 * x1) 451 | out = self.dropout(out) 452 | out = self.LayerNorm(out + x) 453 | return out 454 | 455 | 456 | class BertLayer(nn.Module): 457 | def __init__(self, config): 458 | super(BertLayer, self).__init__() 459 | self.attention = BertAttention(config) 460 | self.ffn_type = config.ffn_type 461 | if self.ffn_type: 462 | self.ffn = TransformerFFN(config) 463 | else: 464 | self.intermediate = BertIntermediate(config) 465 | self.output = BertOutput(config) 466 | 467 | def forward(self, hidden_states, attention_mask, history_states=None, mask_qkv=None, seg_ids=None): 468 | attention_output = self.attention( 469 | hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) 470 | if self.ffn_type: 471 | layer_output = self.ffn(attention_output) 472 | else: 473 | intermediate_output = self.intermediate(attention_output) 474 | layer_output = self.output(intermediate_output, attention_output) 475 | return layer_output 476 | 477 | 478 | class BertEncoder(nn.Module): 479 | def __init__(self, config): 480 | super(BertEncoder, self).__init__() 481 | layer = BertLayer(config) 482 | self.layer = nn.ModuleList([copy.deepcopy(layer) 483 | for _ in range(config.num_hidden_layers)]) 484 | 485 | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, seg_ids=None): 486 | # history embedding and encoded layer must be simultanously given 487 | assert (prev_embedding is None) == (prev_encoded_layers is None) 488 | 489 | all_encoder_layers = [] 490 | if (prev_embedding is not None) and (prev_encoded_layers is not None): 491 | history_states = prev_embedding 492 | for i, layer_module in enumerate(self.layer): 493 | hidden_states = layer_module( 494 | hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) 495 | if output_all_encoded_layers: 496 | all_encoder_layers.append(hidden_states) 497 | if prev_encoded_layers is not None: 498 | history_states = prev_encoded_layers[i] 499 | else: 500 | for layer_module in self.layer: 501 | hidden_states = layer_module( 502 | hidden_states, attention_mask, mask_qkv=mask_qkv, seg_ids=seg_ids) 503 | if output_all_encoded_layers: 504 | all_encoder_layers.append(hidden_states) 505 | if not output_all_encoded_layers: 506 | all_encoder_layers.append(hidden_states) 507 | return all_encoder_layers 508 
| 509 | 510 | class BertPooler(nn.Module): 511 | def __init__(self, config): 512 | super(BertPooler, self).__init__() 513 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 514 | self.activation = nn.Tanh() 515 | 516 | def forward(self, hidden_states): 517 | # We "pool" the model by simply taking the hidden state corresponding 518 | # to the first token. 519 | first_token_tensor = hidden_states[:, 0] 520 | pooled_output = self.dense(first_token_tensor) 521 | pooled_output = self.activation(pooled_output) 522 | return pooled_output 523 | 524 | 525 | class BertPredictionHeadTransform(nn.Module): 526 | def __init__(self, config): 527 | super(BertPredictionHeadTransform, self).__init__() 528 | self.transform_act_fn = ACT2FN[config.hidden_act] \ 529 | if isinstance(config.hidden_act, str) else config.hidden_act 530 | hid_size = config.hidden_size 531 | if hasattr(config, 'relax_projection') and (config.relax_projection > 1): 532 | hid_size *= config.relax_projection 533 | self.dense = nn.Linear(config.hidden_size, hid_size) 534 | self.LayerNorm = BertLayerNorm(hid_size, eps=1e-5) 535 | 536 | def forward(self, hidden_states): 537 | hidden_states = self.dense(hidden_states) 538 | hidden_states = self.transform_act_fn(hidden_states) 539 | hidden_states = self.LayerNorm(hidden_states) 540 | return hidden_states 541 | 542 | 543 | class BertLMPredictionHead(nn.Module): 544 | def __init__(self, config, bert_model_embedding_weights): 545 | super(BertLMPredictionHead, self).__init__() 546 | self.transform = BertPredictionHeadTransform(config) 547 | 548 | # The output weights are the same as the input embeddings, but there is 549 | # an output-only bias for each token. 550 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 551 | bert_model_embedding_weights.size(0), 552 | bias=False) 553 | self.decoder.weight = bert_model_embedding_weights 554 | self.bias = nn.Parameter(torch.zeros( 555 | bert_model_embedding_weights.size(0))) 556 | if hasattr(config, 'relax_projection') and (config.relax_projection > 1): 557 | self.relax_projection = config.relax_projection 558 | else: 559 | self.relax_projection = 0 560 | self.fp32_embedding = config.fp32_embedding 561 | 562 | def convert_to_type(tensor): 563 | if self.fp32_embedding: 564 | return tensor.half() 565 | else: 566 | return tensor 567 | self.type_converter = convert_to_type 568 | self.converted = False 569 | 570 | def forward(self, hidden_states, task_idx=None): 571 | if not self.converted: 572 | self.converted = True 573 | if self.fp32_embedding: 574 | self.transform.half() 575 | hidden_states = self.transform(self.type_converter(hidden_states)) 576 | if self.relax_projection > 1: 577 | num_batch = hidden_states.size(0) 578 | num_pos = hidden_states.size(1) 579 | # (batch, num_pos, relax_projection*hid) -> (batch, num_pos, relax_projection, hid) -> (batch, num_pos, hid) 580 | hidden_states = hidden_states.view( 581 | num_batch, num_pos, self.relax_projection, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] 582 | if self.fp32_embedding: 583 | hidden_states = F.linear(self.type_converter(hidden_states), self.type_converter( 584 | self.decoder.weight), self.type_converter(self.bias)) 585 | else: 586 | hidden_states = self.decoder(hidden_states) + self.bias 587 | return hidden_states 588 | 589 | 590 | class BertOnlyMLMHead(nn.Module): 591 | def __init__(self, config, bert_model_embedding_weights): 592 | super(BertOnlyMLMHead, self).__init__() 593 | self.predictions = BertLMPredictionHead( 594 | config, 
bert_model_embedding_weights) 595 | 596 | def forward(self, sequence_output): 597 | prediction_scores = self.predictions(sequence_output) 598 | return prediction_scores 599 | 600 | 601 | class BertOnlyNSPHead(nn.Module): 602 | def __init__(self, config): 603 | super(BertOnlyNSPHead, self).__init__() 604 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 605 | 606 | def forward(self, pooled_output): 607 | seq_relationship_score = self.seq_relationship(pooled_output) 608 | return seq_relationship_score 609 | 610 | 611 | class BertPreTrainingHeads(nn.Module): 612 | def __init__(self, config, bert_model_embedding_weights, num_labels=2): 613 | super(BertPreTrainingHeads, self).__init__() 614 | self.predictions = BertLMPredictionHead( 615 | config, bert_model_embedding_weights) 616 | self.seq_relationship = nn.Linear(config.hidden_size, num_labels) 617 | 618 | def forward(self, sequence_output, pooled_output, task_idx=None): 619 | prediction_scores = self.predictions(sequence_output, task_idx) 620 | if pooled_output is None: 621 | seq_relationship_score = None 622 | else: 623 | seq_relationship_score = self.seq_relationship(pooled_output) 624 | return prediction_scores, seq_relationship_score 625 | 626 | 627 | class PreTrainedBertModel(nn.Module): 628 | """ An abstract class to handle weights initialization and 629 | a simple interface for dowloading and loading pretrained models. 630 | """ 631 | 632 | def __init__(self, config, *inputs, **kwargs): 633 | super(PreTrainedBertModel, self).__init__() 634 | if not isinstance(config, BertConfig): 635 | raise ValueError( 636 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " 637 | "To create a model from a Google pretrained model use " 638 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 639 | self.__class__.__name__, self.__class__.__name__ 640 | )) 641 | self.config = config 642 | 643 | def init_bert_weights(self, module): 644 | """ Initialize the weights. 645 | """ 646 | if isinstance(module, (nn.Linear, nn.Embedding)): 647 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 648 | elif isinstance(module, BertLayerNorm): 649 | module.bias.data.zero_() 650 | module.weight.data.fill_(1.0) 651 | if isinstance(module, nn.Linear) and module.bias is not None: 652 | module.bias.data.zero_() 653 | 654 | @classmethod 655 | def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): 656 | """ 657 | Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. 658 | Download and cache the pre-trained model file if needed. 659 | 660 | Params: 661 | pretrained_model_name: either: 662 | - a str with the name of a pre-trained model to load selected in the list of: 663 | . `bert-base-uncased` 664 | . `bert-large-uncased` 665 | . `bert-base-cased` 666 | . `bert-base-multilingual` 667 | . `bert-base-chinese` 668 | - a path or url to a pretrained model archive containing: 669 | . `bert_config.json` a configuration file for the model 670 | . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance 671 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 
672 | state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models 673 | *inputs, **kwargs: additional input for the specific Bert class 674 | (ex: num_labels for BertForSequenceClassification) 675 | """ 676 | if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: 677 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] 678 | else: 679 | archive_file = pretrained_model_name 680 | # redirect to the cache, if necessary 681 | try: 682 | resolved_archive_file = cached_path( 683 | archive_file, cache_dir=cache_dir) 684 | except FileNotFoundError: 685 | logger.error( 686 | "Model name '{}' was not found in model name list ({}). " 687 | "We assumed '{}' was a path or url but couldn't find any file " 688 | "associated to this path or url.".format( 689 | pretrained_model_name, 690 | ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), 691 | archive_file)) 692 | return None 693 | if resolved_archive_file == archive_file: 694 | logger.info("loading archive file {}".format(archive_file)) 695 | else: 696 | logger.info("loading archive file {} from cache at {}".format( 697 | archive_file, resolved_archive_file)) 698 | tempdir = None 699 | if os.path.isdir(resolved_archive_file): 700 | serialization_dir = resolved_archive_file 701 | else: 702 | # Extract archive to temp dir 703 | tempdir = tempfile.mkdtemp() 704 | logger.info("extracting archive file {} to temp dir {}".format( 705 | resolved_archive_file, tempdir)) 706 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 707 | archive.extractall(tempdir) 708 | serialization_dir = tempdir 709 | # Load config 710 | if ('config_path' in kwargs) and kwargs['config_path']: 711 | config_file = kwargs['config_path'] 712 | else: 713 | config_file = os.path.join(serialization_dir, CONFIG_NAME) 714 | config = BertConfig.from_json_file(config_file) 715 | 716 | # define new type_vocab_size (there might be different numbers of segment ids) 717 | if 'type_vocab_size' in kwargs: 718 | config.type_vocab_size = kwargs['type_vocab_size'] 719 | # define new relax_projection 720 | if ('relax_projection' in kwargs) and kwargs['relax_projection']: 721 | config.relax_projection = kwargs['relax_projection'] 722 | # new position embedding 723 | if ('new_pos_ids' in kwargs) and kwargs['new_pos_ids']: 724 | config.new_pos_ids = kwargs['new_pos_ids'] 725 | # define new relax_projection 726 | if ('task_idx' in kwargs) and kwargs['task_idx']: 727 | config.task_idx = kwargs['task_idx'] 728 | # define new max position embedding for length expansion 729 | if ('max_position_embeddings' in kwargs) and kwargs['max_position_embeddings']: 730 | config.max_position_embeddings = kwargs['max_position_embeddings'] 731 | # use fp32 for embeddings 732 | if ('fp32_embedding' in kwargs) and kwargs['fp32_embedding']: 733 | config.fp32_embedding = kwargs['fp32_embedding'] 734 | # type of FFN in transformer blocks 735 | if ('ffn_type' in kwargs) and kwargs['ffn_type']: 736 | config.ffn_type = kwargs['ffn_type'] 737 | # label smoothing 738 | if ('label_smoothing' in kwargs) and kwargs['label_smoothing']: 739 | config.label_smoothing = kwargs['label_smoothing'] 740 | # dropout 741 | if ('hidden_dropout_prob' in kwargs) and kwargs['hidden_dropout_prob']: 742 | config.hidden_dropout_prob = kwargs['hidden_dropout_prob'] 743 | if ('attention_probs_dropout_prob' in kwargs) and kwargs['attention_probs_dropout_prob']: 744 | config.attention_probs_dropout_prob = kwargs['attention_probs_dropout_prob'] 745 | # different QKV 746 | if 
('num_qkv' in kwargs) and kwargs['num_qkv']: 747 | config.num_qkv = kwargs['num_qkv'] 748 | # segment embedding for self-attention 749 | if ('seg_emb' in kwargs) and kwargs['seg_emb']: 750 | config.seg_emb = kwargs['seg_emb'] 751 | # initialize word embeddings 752 | _word_emb_map = None 753 | if ('word_emb_map' in kwargs) and kwargs['word_emb_map']: 754 | _word_emb_map = kwargs['word_emb_map'] 755 | 756 | logger.info("Model config {}".format(config)) 757 | 758 | # clean the arguments in kwargs 759 | for arg_clean in ('config_path', 'type_vocab_size', 'relax_projection', 'new_pos_ids', 'task_idx', 'max_position_embeddings', 'fp32_embedding', 'ffn_type', 'label_smoothing', 'hidden_dropout_prob', 'attention_probs_dropout_prob', 'num_qkv', 'seg_emb', 'word_emb_map'): 760 | if arg_clean in kwargs: 761 | del kwargs[arg_clean] 762 | 763 | # Instantiate model. 764 | model = cls(config, *inputs, **kwargs) 765 | if state_dict is None: 766 | weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) 767 | state_dict = torch.load(weights_path) 768 | 769 | old_keys = [] 770 | new_keys = [] 771 | for key in state_dict.keys(): 772 | new_key = None 773 | if 'gamma' in key: 774 | new_key = key.replace('gamma', 'weight') 775 | if 'beta' in key: 776 | new_key = key.replace('beta', 'bias') 777 | if new_key: 778 | old_keys.append(key) 779 | new_keys.append(new_key) 780 | for old_key, new_key in zip(old_keys, new_keys): 781 | state_dict[new_key] = state_dict.pop(old_key) 782 | 783 | # initialize new segment embeddings 784 | _k = 'bert.embeddings.token_type_embeddings.weight' 785 | if (_k in state_dict) and (config.type_vocab_size != state_dict[_k].shape[0]): 786 | logger.info("config.type_vocab_size != state_dict[bert.embeddings.token_type_embeddings.weight] ({0} != {1})".format( 787 | config.type_vocab_size, state_dict[_k].shape[0])) 788 | if config.type_vocab_size > state_dict[_k].shape[0]: 789 | # state_dict[_k].data = state_dict[_k].data.resize_(config.type_vocab_size, state_dict[_k].shape[1]) 790 | state_dict[_k].resize_( 791 | config.type_vocab_size, state_dict[_k].shape[1]) 792 | # L2R 793 | if config.type_vocab_size >= 3: 794 | state_dict[_k].data[2, :].copy_(state_dict[_k].data[0, :]) 795 | # R2L 796 | if config.type_vocab_size >= 4: 797 | state_dict[_k].data[3, :].copy_(state_dict[_k].data[0, :]) 798 | # S2S 799 | if config.type_vocab_size >= 6: 800 | state_dict[_k].data[4, :].copy_(state_dict[_k].data[0, :]) 801 | state_dict[_k].data[5, :].copy_(state_dict[_k].data[1, :]) 802 | if config.type_vocab_size >= 7: 803 | state_dict[_k].data[6, :].copy_(state_dict[_k].data[1, :]) 804 | elif config.type_vocab_size < state_dict[_k].shape[0]: 805 | state_dict[_k].data = state_dict[_k].data[:config.type_vocab_size, :] 806 | 807 | _k = 'bert.embeddings.position_embeddings.weight' 808 | n_config_pos_emb = 4 if config.new_pos_ids else 1 809 | if (_k in state_dict) and (n_config_pos_emb*config.hidden_size != state_dict[_k].shape[1]): 810 | logger.info("n_config_pos_emb*config.hidden_size != state_dict[bert.embeddings.position_embeddings.weight] ({0}*{1} != {2})".format( 811 | n_config_pos_emb, config.hidden_size, state_dict[_k].shape[1])) 812 | assert state_dict[_k].shape[1] % config.hidden_size == 0 813 | n_state_pos_emb = int(state_dict[_k].shape[1]/config.hidden_size) 814 | assert (n_state_pos_emb == 1) != (n_config_pos_emb == 815 | 1), "!!!!n_state_pos_emb == 1 xor n_config_pos_emb == 1!!!!" 
816 | if n_state_pos_emb == 1: 817 | state_dict[_k].data = state_dict[_k].data.unsqueeze(1).repeat( 818 | 1, n_config_pos_emb, 1).reshape((config.max_position_embeddings, n_config_pos_emb*config.hidden_size)) 819 | elif n_config_pos_emb == 1: 820 | if hasattr(config, 'task_idx') and (config.task_idx is not None) and (0 <= config.task_idx <= 3): 821 | _task_idx = config.task_idx 822 | else: 823 | _task_idx = 0 824 | state_dict[_k].data = state_dict[_k].data.view( 825 | config.max_position_embeddings, n_state_pos_emb, config.hidden_size).select(1, _task_idx) 826 | 827 | # initialize new position embeddings 828 | _k = 'bert.embeddings.position_embeddings.weight' 829 | if _k in state_dict and config.max_position_embeddings != state_dict[_k].shape[0]: 830 | logger.info("config.max_position_embeddings != state_dict[bert.embeddings.position_embeddings.weight] ({0} - {1})".format( 831 | config.max_position_embeddings, state_dict[_k].shape[0])) 832 | if config.max_position_embeddings > state_dict[_k].shape[0]: 833 | old_size = state_dict[_k].shape[0] 834 | # state_dict[_k].data = state_dict[_k].data.resize_(config.max_position_embeddings, state_dict[_k].shape[1]) 835 | state_dict[_k].resize_( 836 | config.max_position_embeddings, state_dict[_k].shape[1]) 837 | start = old_size 838 | while start < config.max_position_embeddings: 839 | chunk_size = min( 840 | old_size, config.max_position_embeddings - start) 841 | state_dict[_k].data[start:start+chunk_size, 842 | :].copy_(state_dict[_k].data[:chunk_size, :]) 843 | start += chunk_size 844 | elif config.max_position_embeddings < state_dict[_k].shape[0]: 845 | state_dict[_k].data = state_dict[_k].data[:config.max_position_embeddings, :] 846 | 847 | # initialize relax projection 848 | _k = 'cls.predictions.transform.dense.weight' 849 | n_config_relax = 1 if (config.relax_projection < 850 | 1) else config.relax_projection 851 | if (_k in state_dict) and (n_config_relax*config.hidden_size != state_dict[_k].shape[0]): 852 | logger.info("n_config_relax*config.hidden_size != state_dict[cls.predictions.transform.dense.weight] ({0}*{1} != {2})".format( 853 | n_config_relax, config.hidden_size, state_dict[_k].shape[0])) 854 | assert state_dict[_k].shape[0] % config.hidden_size == 0 855 | n_state_relax = int(state_dict[_k].shape[0]/config.hidden_size) 856 | assert (n_state_relax == 1) != (n_config_relax == 857 | 1), "!!!!n_state_relax == 1 xor n_config_relax == 1!!!!" 
858 | if n_state_relax == 1: 859 | _k = 'cls.predictions.transform.dense.weight' 860 | state_dict[_k].data = state_dict[_k].data.unsqueeze(0).repeat( 861 | n_config_relax, 1, 1).reshape((n_config_relax*config.hidden_size, config.hidden_size)) 862 | for _k in ('cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias'): 863 | state_dict[_k].data = state_dict[_k].data.unsqueeze( 864 | 0).repeat(n_config_relax, 1).view(-1) 865 | elif n_config_relax == 1: 866 | if hasattr(config, 'task_idx') and (config.task_idx is not None) and (0 <= config.task_idx <= 3): 867 | _task_idx = config.task_idx 868 | else: 869 | _task_idx = 0 870 | _k = 'cls.predictions.transform.dense.weight' 871 | state_dict[_k].data = state_dict[_k].data.view( 872 | n_state_relax, config.hidden_size, config.hidden_size).select(0, _task_idx) 873 | for _k in ('cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias'): 874 | state_dict[_k].data = state_dict[_k].data.view( 875 | n_state_relax, config.hidden_size).select(0, _task_idx) 876 | 877 | # initialize QKV 878 | _all_head_size = config.num_attention_heads * \ 879 | int(config.hidden_size / config.num_attention_heads) 880 | n_config_num_qkv = 1 if (config.num_qkv < 1) else config.num_qkv 881 | for qkv_name in ('query', 'key', 'value'): 882 | _k = 'bert.encoder.layer.0.attention.self.{0}.weight'.format( 883 | qkv_name) 884 | if (_k in state_dict) and (n_config_num_qkv*_all_head_size != state_dict[_k].shape[0]): 885 | logger.info("n_config_num_qkv*_all_head_size != state_dict[_k] ({0}*{1} != {2})".format( 886 | n_config_num_qkv, _all_head_size, state_dict[_k].shape[0])) 887 | for layer_idx in range(config.num_hidden_layers): 888 | _k = 'bert.encoder.layer.{0}.attention.self.{1}.weight'.format( 889 | layer_idx, qkv_name) 890 | assert state_dict[_k].shape[0] % _all_head_size == 0 891 | n_state_qkv = int(state_dict[_k].shape[0]/_all_head_size) 892 | assert (n_state_qkv == 1) != (n_config_num_qkv == 893 | 1), "!!!!n_state_qkv == 1 xor n_config_num_qkv == 1!!!!" 
894 | if n_state_qkv == 1: 895 | _k = 'bert.encoder.layer.{0}.attention.self.{1}.weight'.format( 896 | layer_idx, qkv_name) 897 | state_dict[_k].data = state_dict[_k].data.unsqueeze(0).repeat( 898 | n_config_num_qkv, 1, 1).reshape((n_config_num_qkv*_all_head_size, _all_head_size)) 899 | _k = 'bert.encoder.layer.{0}.attention.self.{1}.bias'.format( 900 | layer_idx, qkv_name) 901 | state_dict[_k].data = state_dict[_k].data.unsqueeze( 902 | 0).repeat(n_config_num_qkv, 1).view(-1) 903 | elif n_config_num_qkv == 1: 904 | if hasattr(config, 'task_idx') and (config.task_idx is not None) and (0 <= config.task_idx <= 3): 905 | _task_idx = config.task_idx 906 | else: 907 | _task_idx = 0 908 | assert _task_idx != 3, "[INVALID] _task_idx=3: n_config_num_qkv=1 (should be 2)" 909 | if _task_idx == 0: 910 | _qkv_idx = 0 911 | else: 912 | _qkv_idx = 1 913 | _k = 'bert.encoder.layer.{0}.attention.self.{1}.weight'.format( 914 | layer_idx, qkv_name) 915 | state_dict[_k].data = state_dict[_k].data.view( 916 | n_state_qkv, _all_head_size, _all_head_size).select(0, _qkv_idx) 917 | _k = 'bert.encoder.layer.{0}.attention.self.{1}.bias'.format( 918 | layer_idx, qkv_name) 919 | state_dict[_k].data = state_dict[_k].data.view( 920 | n_state_qkv, _all_head_size).select(0, _qkv_idx) 921 | 922 | if _word_emb_map: 923 | _k = 'bert.embeddings.word_embeddings.weight' 924 | for _tgt, _src in _word_emb_map: 925 | state_dict[_k].data[_tgt, :].copy_( 926 | state_dict[_k].data[_src, :]) 927 | 928 | missing_keys = [] 929 | unexpected_keys = [] 930 | error_msgs = [] 931 | # copy state_dict so _load_from_state_dict can modify it 932 | metadata = getattr(state_dict, '_metadata', None) 933 | state_dict = state_dict.copy() 934 | if metadata is not None: 935 | state_dict._metadata = metadata 936 | 937 | def load(module, prefix=''): 938 | local_metadata = {} if metadata is None else metadata.get( 939 | prefix[:-1], {}) 940 | module._load_from_state_dict( 941 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 942 | for name, child in module._modules.items(): 943 | if child is not None: 944 | load(child, prefix + name + '.') 945 | load(model, prefix='' if hasattr(model, 'bert') else 'bert.') 946 | model.missing_keys = missing_keys 947 | if len(missing_keys) > 0: 948 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 949 | model.__class__.__name__, missing_keys)) 950 | if len(unexpected_keys) > 0: 951 | logger.info("Weights from pretrained model not used in {}: {}".format( 952 | model.__class__.__name__, unexpected_keys)) 953 | if len(error_msgs) > 0: 954 | logger.info('\n'.join(error_msgs)) 955 | if tempdir: 956 | # Clean up temp dir 957 | shutil.rmtree(tempdir) 958 | return model 959 | 960 | 961 | class BertModel(PreTrainedBertModel): 962 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 963 | 964 | Params: 965 | config: a BertConfig class instance with the configuration to build a new model 966 | 967 | Inputs: 968 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 969 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 970 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 971 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 972 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 973 | a `sentence B` token (see BERT paper for more details). 
974 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 975 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 976 | input sequence length in the current batch. It's the mask that we typically use for attention when 977 | a batch has varying length sentences. 978 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 979 | 980 | Outputs: Tuple of (encoded_layers, pooled_output) 981 | `encoded_layers`: controlled by `output_all_encoded_layers` argument: 982 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end 983 | of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each 984 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], 985 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding 986 | to the last attention block of shape [batch_size, sequence_length, hidden_size], 987 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a 988 | classifier pretrained on top of the hidden state associated with the first token of the 989 | input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper). 990 | 991 | """ 992 | 993 | def __init__(self, config): 994 | super(BertModel, self).__init__(config) 995 | self.embeddings = BertEmbeddings(config) 996 | self.encoder = BertEncoder(config) 997 | self.pooler = BertPooler(config) 998 | self.apply(self.init_bert_weights) 999 | 1000 | def rescale_some_parameters(self): 1001 | for layer_id, layer in enumerate(self.encoder.layer): 1002 | layer.attention.output.dense.weight.data.div_( 1003 | math.sqrt(2.0*(layer_id + 1))) 1004 | layer.output.dense.weight.data.div_(math.sqrt(2.0*(layer_id + 1))) 1005 | 1006 | def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask): 1007 | if attention_mask is None: 1008 | attention_mask = torch.ones_like(input_ids) 1009 | if token_type_ids is None: 1010 | token_type_ids = torch.zeros_like(input_ids) 1011 | 1012 | # We create a 3D attention mask from a 2D tensor mask. 1013 | # Sizes are [batch_size, 1, 1, to_seq_length] 1014 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 1015 | # this attention mask is simpler than the triangular masking of causal attention 1016 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 1017 | if attention_mask.dim() == 2: 1018 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 1019 | elif attention_mask.dim() == 3: 1020 | extended_attention_mask = attention_mask.unsqueeze(1) 1021 | else: 1022 | raise NotImplementedError 1023 | 1024 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 1025 | # masked positions, this operation will create a tensor which is 0.0 for 1026 | # positions we want to attend and -10000.0 for masked positions. 1027 | # Since we are adding it to the raw scores before the softmax, this is 1028 | # effectively the same as removing these entirely.
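# --- Editor's aside (toy example with hypothetical values, not part of the ----
# original file): a padding mask [1, 1, 0] becomes an additive bias of
# [0, 0, -10000] broadcastable over (batch, num_heads, from_seq, to_seq).
toy_mask = torch.tensor([[1, 1, 0]])
toy_bias = (1.0 - toy_mask[:, None, None, :].float()) * -10000.0
assert toy_bias.shape == (1, 1, 1, 3) and toy_bias[0, 0, 0, 2].item() == -10000.0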
1029 | extended_attention_mask = extended_attention_mask.to( 1030 | dtype=next(self.parameters()).dtype) # fp16 compatibility 1031 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 1032 | return extended_attention_mask 1033 | 1034 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, mask_qkv=None, task_idx=None): 1035 | extended_attention_mask = self.get_extended_attention_mask( 1036 | input_ids, token_type_ids, attention_mask) 1037 | 1038 | embedding_output = self.embeddings( 1039 | input_ids, token_type_ids, task_idx=task_idx) 1040 | encoded_layers = self.encoder(embedding_output, extended_attention_mask, 1041 | output_all_encoded_layers=output_all_encoded_layers, mask_qkv=mask_qkv, seg_ids=token_type_ids) 1042 | sequence_output = encoded_layers[-1] 1043 | pooled_output = self.pooler(sequence_output) 1044 | if not output_all_encoded_layers: 1045 | encoded_layers = encoded_layers[-1] 1046 | return encoded_layers, pooled_output 1047 | 1048 | 1049 | class BertModelIncr(BertModel): 1050 | def __init__(self, config): 1051 | super(BertModelIncr, self).__init__(config) 1052 | 1053 | def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True, prev_embedding=None, 1054 | prev_encoded_layers=None, mask_qkv=None, task_idx=None): 1055 | extended_attention_mask = self.get_extended_attention_mask( 1056 | input_ids, token_type_ids, attention_mask) 1057 | 1058 | embedding_output = self.embeddings( 1059 | input_ids, token_type_ids, position_ids, task_idx=task_idx) 1060 | encoded_layers = self.encoder(embedding_output, 1061 | extended_attention_mask, 1062 | output_all_encoded_layers=output_all_encoded_layers, 1063 | prev_embedding=prev_embedding, 1064 | prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv, seg_ids=token_type_ids) 1065 | sequence_output = encoded_layers[-1] 1066 | pooled_output = self.pooler(sequence_output) 1067 | if not output_all_encoded_layers: 1068 | encoded_layers = encoded_layers[-1] 1069 | return embedding_output, encoded_layers, pooled_output 1070 | 1071 | 1072 | class BertForPreTraining(PreTrainedBertModel): 1073 | """BERT model with pre-training heads. 1074 | This module comprises the BERT model followed by the two pre-training heads: 1075 | - the masked language modeling head, and 1076 | - the next sentence classification head. 1077 | Params: 1078 | config: a BertConfig class instance with the configuration to build a new model. 1079 | Inputs: 1080 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1081 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1082 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1083 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1084 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1085 | a `sentence B` token (see BERT paper for more details). 1086 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1087 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1088 | input sequence length in the current batch. It's the mask that we typically use for attention when 1089 | a batch has varying length sentences. 
1090 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 1091 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 1092 | is only computed for the labels set in [0, ..., vocab_size] 1093 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] 1094 | with indices selected in [0, 1]. 1095 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 1096 | Outputs: 1097 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 1098 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 1099 | sentence classification loss. 1100 | if `masked_lm_labels` or `next_sentence_label` is `None`: 1101 | Outputs a tuple comprising 1102 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 1103 | - the next sentence classification logits of shape [batch_size, 2]. 1104 | Example usage: 1105 | ```python 1106 | # Already been converted into WordPiece token ids 1107 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1108 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1109 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1110 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1111 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1112 | model = BertForPreTraining(config) 1113 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 1114 | ``` 1115 | """ 1116 | 1117 | def __init__(self, config): 1118 | super(BertForPreTraining, self).__init__(config) 1119 | self.bert = BertModel(config) 1120 | self.cls = BertPreTrainingHeads( 1121 | config, self.bert.embeddings.word_embeddings.weight) 1122 | self.apply(self.init_bert_weights) 1123 | 1124 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, mask_qkv=None, task_idx=None): 1125 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 1126 | output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) 1127 | prediction_scores, seq_relationship_score = self.cls( 1128 | sequence_output, pooled_output) 1129 | 1130 | if masked_lm_labels is not None and next_sentence_label is not None: 1131 | loss_fct = CrossEntropyLoss(ignore_index=-1) 1132 | masked_lm_loss = loss_fct( 1133 | prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 1134 | next_sentence_loss = loss_fct( 1135 | seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 1136 | total_loss = masked_lm_loss + next_sentence_loss 1137 | return total_loss 1138 | else: 1139 | return prediction_scores, seq_relationship_score 1140 | 1141 | 1142 | class BertPreTrainingPairTransform(nn.Module): 1143 | def __init__(self, config): 1144 | super(BertPreTrainingPairTransform, self).__init__() 1145 | self.dense = nn.Linear(config.hidden_size*2, config.hidden_size) 1146 | self.transform_act_fn = ACT2FN[config.hidden_act] \ 1147 | if isinstance(config.hidden_act, str) else config.hidden_act 1148 | # self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) 1149 | 1150 | def forward(self, pair_x, pair_y): 1151 | hidden_states = torch.cat([pair_x, pair_y], dim=-1) 1152 | hidden_states = self.dense(hidden_states) 1153 | hidden_states = self.transform_act_fn(hidden_states) 1154 | # hidden_states = 
self.LayerNorm(hidden_states) 1155 | return hidden_states 1156 | 1157 | 1158 | class BertPreTrainingPairRel(nn.Module): 1159 | def __init__(self, config, num_rel=0): 1160 | super(BertPreTrainingPairRel, self).__init__() 1161 | self.R_xy = BertPreTrainingPairTransform(config) 1162 | self.rel_emb = nn.Embedding(num_rel, config.hidden_size) 1163 | 1164 | def forward(self, pair_x, pair_y, pair_r, pair_pos_neg_mask): 1165 | # (batch, num_pair, hidden) 1166 | xy = self.R_xy(pair_x, pair_y) 1167 | r = self.rel_emb(pair_r) 1168 | _batch, _num_pair, _hidden = xy.size() 1169 | pair_score = (xy * r).sum(-1) 1170 | # torch.bmm(xy.view(-1, 1, _hidden),r.view(-1, _hidden, 1)).view(_batch, _num_pair) 1171 | # .mul_(-1.0): objective to loss 1172 | return F.logsigmoid(pair_score * pair_pos_neg_mask.type_as(pair_score)).mul_(-1.0) 1173 | 1174 | 1175 | class BertForPreTrainingLossMask(PreTrainedBertModel): 1176 | """refer to BertForPreTraining""" 1177 | 1178 | def __init__(self, config, num_labels=2, num_rel=0, num_sentlvl_labels=0, no_nsp=False): 1179 | super(BertForPreTrainingLossMask, self).__init__(config) 1180 | self.bert = BertModel(config) 1181 | self.cls = BertPreTrainingHeads( 1182 | config, self.bert.embeddings.word_embeddings.weight, num_labels=num_labels) 1183 | self.num_sentlvl_labels = num_sentlvl_labels 1184 | self.cls2 = None 1185 | if self.num_sentlvl_labels > 0: 1186 | self.secondary_pred_proj = nn.Embedding( 1187 | num_sentlvl_labels, config.hidden_size) 1188 | self.cls2 = BertPreTrainingHeads( 1189 | config, self.secondary_pred_proj.weight, num_labels=num_sentlvl_labels) 1190 | self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') 1191 | if no_nsp: 1192 | self.crit_next_sent = None 1193 | else: 1194 | self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1) 1195 | self.num_labels = num_labels 1196 | self.num_rel = num_rel 1197 | if self.num_rel > 0: 1198 | self.crit_pair_rel = BertPreTrainingPairRel( 1199 | config, num_rel=num_rel) 1200 | if hasattr(config, 'label_smoothing') and config.label_smoothing: 1201 | self.crit_mask_lm_smoothed = LabelSmoothingLoss( 1202 | config.label_smoothing, config.vocab_size, ignore_index=0, reduction='none') 1203 | else: 1204 | self.crit_mask_lm_smoothed = None 1205 | self.apply(self.init_bert_weights) 1206 | self.bert.rescale_some_parameters() 1207 | 1208 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, 1209 | next_sentence_label=None, masked_pos=None, masked_weights=None, task_idx=None, pair_x=None, 1210 | pair_x_mask=None, pair_y=None, pair_y_mask=None, pair_r=None, pair_pos_neg_mask=None, 1211 | pair_loss_mask=None, masked_pos_2=None, masked_weights_2=None, masked_labels_2=None, 1212 | num_tokens_a=None, num_tokens_b=None, mask_qkv=None): 1213 | if token_type_ids is None and attention_mask is None: 1214 | task_0 = (task_idx == 0) 1215 | task_1 = (task_idx == 1) 1216 | task_2 = (task_idx == 2) 1217 | task_3 = (task_idx == 3) 1218 | 1219 | sequence_length = input_ids.shape[-1] 1220 | index_matrix = torch.arange(sequence_length).view( 1221 | 1, sequence_length).to(input_ids.device) 1222 | 1223 | num_tokens = num_tokens_a + num_tokens_b 1224 | 1225 | base_mask = (index_matrix < num_tokens.view(-1, 1) 1226 | ).type_as(input_ids) 1227 | segment_a_mask = ( 1228 | index_matrix < num_tokens_a.view(-1, 1)).type_as(input_ids) 1229 | 1230 | token_type_ids = ( 1231 | task_idx + 1 + task_3.type_as(task_idx)).view(-1, 1) * base_mask 1232 | token_type_ids = token_type_ids - segment_a_mask * \ 1233 | (task_0 | 
task_3).type_as(segment_a_mask).view(-1, 1) 1234 | 1235 | index_matrix = index_matrix.view(1, 1, sequence_length) 1236 | index_matrix_t = index_matrix.view(1, sequence_length, 1) 1237 | 1238 | tril = index_matrix <= index_matrix_t 1239 | 1240 | attention_mask_task_0 = ( 1241 | index_matrix < num_tokens.view(-1, 1, 1)) & (index_matrix_t < num_tokens.view(-1, 1, 1)) 1242 | attention_mask_task_1 = tril & attention_mask_task_0 1243 | attention_mask_task_2 = torch.transpose( 1244 | tril, dim0=-2, dim1=-1) & attention_mask_task_0 1245 | attention_mask_task_3 = ( 1246 | (index_matrix < num_tokens_a.view(-1, 1, 1)) | tril) & attention_mask_task_0 1247 | 1248 | attention_mask = (attention_mask_task_0 & task_0.view(-1, 1, 1)) | \ 1249 | (attention_mask_task_1 & task_1.view(-1, 1, 1)) | \ 1250 | (attention_mask_task_2 & task_2.view(-1, 1, 1)) | \ 1251 | (attention_mask_task_3 & task_3.view(-1, 1, 1)) 1252 | attention_mask = attention_mask.type_as(input_ids) 1253 | sequence_output, pooled_output = self.bert( 1254 | input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) 1255 | 1256 | def gather_seq_out_by_pos(seq, pos): 1257 | return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1))) 1258 | 1259 | def gather_seq_out_by_pos_average(seq, pos, mask): 1260 | # pos/mask: (batch, num_pair, max_token_num) 1261 | batch_size, max_token_num = pos.size(0), pos.size(-1) 1262 | # (batch, num_pair, max_token_num, seq.size(-1)) 1263 | pos_vec = torch.gather(seq, 1, pos.view(batch_size, -1).unsqueeze( 1264 | 2).expand(-1, -1, seq.size(-1))).view(batch_size, -1, max_token_num, seq.size(-1)) 1265 | # (batch, num_pair, seq.size(-1)) 1266 | mask = mask.type_as(pos_vec) 1267 | pos_vec_masked_sum = ( 1268 | pos_vec * mask.unsqueeze(3).expand_as(pos_vec)).sum(2) 1269 | return pos_vec_masked_sum / mask.sum(2, keepdim=True).expand_as(pos_vec_masked_sum) 1270 | 1271 | def loss_mask_and_normalize(loss, mask): 1272 | mask = mask.type_as(loss) 1273 | loss = loss * mask 1274 | denominator = torch.sum(mask) + 1e-5 1275 | return (loss / denominator).sum() 1276 | 1277 | if masked_lm_labels is None: 1278 | if masked_pos is None: 1279 | prediction_scores, seq_relationship_score = self.cls( 1280 | sequence_output, pooled_output, task_idx=task_idx) 1281 | else: 1282 | sequence_output_masked = gather_seq_out_by_pos( 1283 | sequence_output, masked_pos) 1284 | prediction_scores, seq_relationship_score = self.cls( 1285 | sequence_output_masked, pooled_output, task_idx=task_idx) 1286 | return prediction_scores, seq_relationship_score 1287 | 1288 | # masked lm 1289 | sequence_output_masked = gather_seq_out_by_pos( 1290 | sequence_output, masked_pos) 1291 | prediction_scores_masked, seq_relationship_score = self.cls( 1292 | sequence_output_masked, pooled_output, task_idx=task_idx) 1293 | if self.crit_mask_lm_smoothed: 1294 | masked_lm_loss = self.crit_mask_lm_smoothed( 1295 | F.log_softmax(prediction_scores_masked.float(), dim=-1), masked_lm_labels) 1296 | else: 1297 | masked_lm_loss = self.crit_mask_lm( 1298 | prediction_scores_masked.transpose(1, 2).float(), masked_lm_labels) 1299 | masked_lm_loss = loss_mask_and_normalize( 1300 | masked_lm_loss.float(), masked_weights) 1301 | 1302 | # next sentence 1303 | if self.crit_next_sent is None or next_sentence_label is None: 1304 | next_sentence_loss = 0.0 1305 | else: 1306 | next_sentence_loss = self.crit_next_sent( 1307 | seq_relationship_score.view(-1, self.num_labels).float(), next_sentence_label.view(-1)) 1308 | 1309 | 
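The mask construction at the top of this `forward` is what turns a single encoder into four training objectives: task 0 uses a fully bidirectional mask, task 1 a left-to-right (causal) mask, task 2 a right-to-left mask, and task 3 a sequence-to-sequence mask in which every position may attend to all of segment A while segment B only attends to its own past. A minimal self-contained sketch of the same index arithmetic for one toy example (all sizes are made up for illustration):

```python
import torch

# Toy sizes (made up): 2 tokens in segment A, 3 in segment B, 1 padding slot.
seq_len, num_tokens_a, num_tokens_b = 6, 2, 3
num_tokens = num_tokens_a + num_tokens_b

index = torch.arange(seq_len).view(1, 1, seq_len)        # key positions
index_t = index.view(1, seq_len, 1)                      # query positions
tril = index <= index_t                                  # causal (lower-triangular) pattern

valid = (index < num_tokens) & (index_t < num_tokens)    # drop padding positions
mask_bidirectional = valid                                # task 0
mask_left_to_right = tril & valid                         # task 1
mask_right_to_left = tril.transpose(-2, -1) & valid       # task 2
mask_seq2seq = ((index < num_tokens_a) | tril) & valid    # task 3: segment A fully visible,
                                                          # segment B causal

print(mask_seq2seq.long().squeeze(0))  # rows = queries, columns = visible keys
```

In the batched code above, the four candidate masks are then combined with the boolean `task_0` ... `task_3` indicators, so a single batch can mix all four objectives.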
if self.cls2 is not None and masked_pos_2 is not None: 1310 | sequence_output_masked_2 = gather_seq_out_by_pos( 1311 | sequence_output, masked_pos_2) 1312 | prediction_scores_masked_2, _ = self.cls2( 1313 | sequence_output_masked_2, None) 1314 | masked_lm_loss_2 = self.crit_mask_lm( 1315 | prediction_scores_masked_2.transpose(1, 2).float(), masked_labels_2) 1316 | masked_lm_loss_2 = loss_mask_and_normalize( 1317 | masked_lm_loss_2.float(), masked_weights_2) 1318 | masked_lm_loss = masked_lm_loss + masked_lm_loss_2 1319 | 1320 | if pair_x is None or pair_y is None or pair_r is None or pair_pos_neg_mask is None or pair_loss_mask is None: 1321 | return masked_lm_loss, next_sentence_loss 1322 | 1323 | # pair and relation 1324 | if pair_x_mask is None or pair_y_mask is None: 1325 | pair_x_output_masked = gather_seq_out_by_pos( 1326 | sequence_output, pair_x) 1327 | pair_y_output_masked = gather_seq_out_by_pos( 1328 | sequence_output, pair_y) 1329 | else: 1330 | pair_x_output_masked = gather_seq_out_by_pos_average( 1331 | sequence_output, pair_x, pair_x_mask) 1332 | pair_y_output_masked = gather_seq_out_by_pos_average( 1333 | sequence_output, pair_y, pair_y_mask) 1334 | pair_loss = self.crit_pair_rel( 1335 | pair_x_output_masked, pair_y_output_masked, pair_r, pair_pos_neg_mask) 1336 | pair_loss = loss_mask_and_normalize( 1337 | pair_loss.float(), pair_loss_mask) 1338 | return masked_lm_loss, next_sentence_loss, pair_loss 1339 | 1340 | 1341 | class BertForExtractiveSummarization(PreTrainedBertModel): 1342 | """refer to BertForPreTraining""" 1343 | 1344 | def __init__(self, config): 1345 | super(BertForExtractiveSummarization, self).__init__(config) 1346 | self.bert = BertModel(config) 1347 | self.secondary_pred_proj = nn.Embedding(2, config.hidden_size) 1348 | self.cls2 = BertPreTrainingHeads( 1349 | config, self.secondary_pred_proj.weight, num_labels=2) 1350 | self.apply(self.init_bert_weights) 1351 | 1352 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_pos_2=None, masked_weights_2=None, task_idx=None, mask_qkv=None): 1353 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 1354 | output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) 1355 | 1356 | def gather_seq_out_by_pos(seq, pos): 1357 | return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1))) 1358 | 1359 | sequence_output_masked_2 = gather_seq_out_by_pos( 1360 | sequence_output, masked_pos_2) 1361 | prediction_scores_masked_2, _ = self.cls2( 1362 | sequence_output_masked_2, None, task_idx=task_idx) 1363 | 1364 | predicted_probs = torch.nn.functional.softmax( 1365 | prediction_scores_masked_2, dim=-1) 1366 | 1367 | return predicted_probs, masked_pos_2, masked_weights_2 1368 | 1369 | 1370 | class BertForSeq2SeqDecoder(PreTrainedBertModel): 1371 | """refer to BertForPreTraining""" 1372 | 1373 | def __init__(self, config, mask_word_id=0, num_labels=2, num_rel=0, 1374 | search_beam_size=1, length_penalty=1.0, eos_id=0, sos_id=0, 1375 | forbid_duplicate_ngrams=False, forbid_ignore_set=None, not_predict_set=None, ngram_size=3, min_len=0, mode="s2s", pos_shift=False): 1376 | super(BertForSeq2SeqDecoder, self).__init__(config) 1377 | self.bert = BertModelIncr(config) 1378 | self.cls = BertPreTrainingHeads( 1379 | config, self.bert.embeddings.word_embeddings.weight, num_labels=num_labels) 1380 | self.apply(self.init_bert_weights) 1381 | self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') 1382 | self.crit_next_sent = 
nn.CrossEntropyLoss(ignore_index=-1) 1383 | self.mask_word_id = mask_word_id 1384 | self.num_labels = num_labels 1385 | self.num_rel = num_rel 1386 | if self.num_rel > 0: 1387 | self.crit_pair_rel = BertPreTrainingPairRel( 1388 | config, num_rel=num_rel) 1389 | self.search_beam_size = search_beam_size 1390 | self.length_penalty = length_penalty 1391 | self.eos_id = eos_id 1392 | self.sos_id = sos_id 1393 | self.forbid_duplicate_ngrams = forbid_duplicate_ngrams 1394 | self.forbid_ignore_set = forbid_ignore_set 1395 | self.not_predict_set = not_predict_set 1396 | self.ngram_size = ngram_size 1397 | self.min_len = min_len 1398 | assert mode in ("s2s", "l2r") 1399 | self.mode = mode 1400 | self.pos_shift = pos_shift 1401 | 1402 | def forward(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): 1403 | if self.search_beam_size > 1: 1404 | return self.beam_search(input_ids, token_type_ids, position_ids, attention_mask, task_idx=task_idx, mask_qkv=mask_qkv) 1405 | 1406 | input_shape = list(input_ids.size()) 1407 | batch_size = input_shape[0] 1408 | input_length = input_shape[1] 1409 | output_shape = list(token_type_ids.size()) 1410 | output_length = output_shape[1] 1411 | 1412 | output_ids = [] 1413 | prev_embedding = None 1414 | prev_encoded_layers = None 1415 | curr_ids = input_ids 1416 | mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) 1417 | next_pos = input_length 1418 | if self.pos_shift: 1419 | sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id) 1420 | 1421 | while next_pos < output_length: 1422 | curr_length = list(curr_ids.size())[1] 1423 | 1424 | if self.pos_shift: 1425 | if next_pos == input_length: 1426 | x_input_ids = torch.cat((curr_ids, sos_ids), dim=1) 1427 | start_pos = 0 1428 | else: 1429 | x_input_ids = curr_ids 1430 | start_pos = next_pos 1431 | else: 1432 | start_pos = next_pos - curr_length 1433 | x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) 1434 | 1435 | curr_token_type_ids = token_type_ids[:, start_pos:next_pos+1] 1436 | curr_attention_mask = attention_mask[:, 1437 | start_pos:next_pos+1, :next_pos+1] 1438 | curr_position_ids = position_ids[:, start_pos:next_pos+1] 1439 | new_embedding, new_encoded_layers, _ = \ 1440 | self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, 1441 | output_all_encoded_layers=True, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv) 1442 | 1443 | last_hidden = new_encoded_layers[-1][:, -1:, :] 1444 | prediction_scores, _ = self.cls( 1445 | last_hidden, None, task_idx=task_idx) 1446 | if self.not_predict_set: 1447 | for token_id in self.not_predict_set: 1448 | prediction_scores[:, :, token_id].fill_(-10000.0) 1449 | _, max_ids = torch.max(prediction_scores, dim=-1) 1450 | output_ids.append(max_ids) 1451 | 1452 | if self.pos_shift: 1453 | if prev_embedding is None: 1454 | prev_embedding = new_embedding 1455 | else: 1456 | prev_embedding = torch.cat( 1457 | (prev_embedding, new_embedding), dim=1) 1458 | if prev_encoded_layers is None: 1459 | prev_encoded_layers = [x for x in new_encoded_layers] 1460 | else: 1461 | prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( 1462 | prev_encoded_layers, new_encoded_layers)] 1463 | else: 1464 | if prev_embedding is None: 1465 | prev_embedding = new_embedding[:, :-1, :] 1466 | else: 1467 | prev_embedding = torch.cat( 1468 | (prev_embedding, new_embedding[:, :-1, :]), dim=1) 1469 | if prev_encoded_layers is None: 1470 | prev_encoded_layers = [x[:, :-1, :] 
1471 | for x in new_encoded_layers] 1472 | else: 1473 | prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) 1474 | for x in zip(prev_encoded_layers, new_encoded_layers)] 1475 | curr_ids = max_ids 1476 | next_pos += 1 1477 | 1478 | return torch.cat(output_ids, dim=1) 1479 | 1480 | def beam_search(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): 1481 | input_shape = list(input_ids.size()) 1482 | batch_size = input_shape[0] 1483 | input_length = input_shape[1] 1484 | output_shape = list(token_type_ids.size()) 1485 | output_length = output_shape[1] 1486 | 1487 | output_ids = [] 1488 | prev_embedding = None 1489 | prev_encoded_layers = None 1490 | curr_ids = input_ids 1491 | mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) 1492 | next_pos = input_length 1493 | if self.pos_shift: 1494 | sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id) 1495 | 1496 | K = self.search_beam_size 1497 | 1498 | total_scores = [] 1499 | beam_masks = [] 1500 | step_ids = [] 1501 | step_back_ptrs = [] 1502 | partial_seqs = [] 1503 | forbid_word_mask = None 1504 | buf_matrix = None 1505 | 1506 | while next_pos < output_length: 1507 | curr_length = list(curr_ids.size())[1] 1508 | 1509 | if self.pos_shift: 1510 | if next_pos == input_length: 1511 | x_input_ids = torch.cat((curr_ids, sos_ids), dim=1) 1512 | start_pos = 0 1513 | else: 1514 | x_input_ids = curr_ids 1515 | start_pos = next_pos 1516 | else: 1517 | start_pos = next_pos - curr_length 1518 | x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) 1519 | 1520 | curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1] 1521 | curr_attention_mask = attention_mask[:, 1522 | start_pos:next_pos + 1, :next_pos + 1] 1523 | curr_position_ids = position_ids[:, start_pos:next_pos + 1] 1524 | new_embedding, new_encoded_layers, _ = \ 1525 | self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, 1526 | output_all_encoded_layers=True, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv) 1527 | 1528 | last_hidden = new_encoded_layers[-1][:, -1:, :] 1529 | prediction_scores, _ = self.cls( 1530 | last_hidden, None, task_idx=task_idx) 1531 | log_scores = torch.nn.functional.log_softmax( 1532 | prediction_scores, dim=-1) 1533 | if forbid_word_mask is not None: 1534 | log_scores += (forbid_word_mask * -10000.0) 1535 | if self.min_len and (next_pos-input_length+1 <= self.min_len): 1536 | log_scores[:, :, self.eos_id].fill_(-10000.0) 1537 | if self.not_predict_set: 1538 | for token_id in self.not_predict_set: 1539 | log_scores[:, :, token_id].fill_(-10000.0) 1540 | kk_scores, kk_ids = torch.topk(log_scores, k=K) 1541 | if len(total_scores) == 0: 1542 | k_ids = torch.reshape(kk_ids, [batch_size, K]) 1543 | back_ptrs = torch.zeros(batch_size, K, dtype=torch.long) 1544 | k_scores = torch.reshape(kk_scores, [batch_size, K]) 1545 | else: 1546 | last_eos = torch.reshape( 1547 | beam_masks[-1], [batch_size * K, 1, 1]) 1548 | last_seq_scores = torch.reshape( 1549 | total_scores[-1], [batch_size * K, 1, 1]) 1550 | kk_scores += last_eos * (-10000.0) + last_seq_scores 1551 | kk_scores = torch.reshape(kk_scores, [batch_size, K * K]) 1552 | k_scores, k_ids = torch.topk(kk_scores, k=K) 1553 | back_ptrs = torch.div(k_ids, K) 1554 | kk_ids = torch.reshape(kk_ids, [batch_size, K * K]) 1555 | k_ids = torch.gather(kk_ids, 1, k_ids) 1556 | step_back_ptrs.append(back_ptrs) 1557 | step_ids.append(k_ids) 1558 | beam_masks.append(torch.eq(k_ids, 
self.eos_id).float()) 1559 | total_scores.append(k_scores) 1560 | 1561 | def first_expand(x): 1562 | input_shape = list(x.size()) 1563 | expanded_shape = input_shape[:1] + [1] + input_shape[1:] 1564 | x = torch.reshape(x, expanded_shape) 1565 | repeat_count = [1, K] + [1] * (len(input_shape) - 1) 1566 | x = x.repeat(*repeat_count) 1567 | x = torch.reshape(x, [input_shape[0] * K] + input_shape[1:]) 1568 | return x 1569 | 1570 | def select_beam_items(x, ids): 1571 | id_shape = list(ids.size()) 1572 | id_rank = len(id_shape) 1573 | assert len(id_shape) == 2 1574 | x_shape = list(x.size()) 1575 | x = torch.reshape(x, [batch_size, K] + x_shape[1:]) 1576 | x_rank = len(x_shape) + 1 1577 | assert x_rank >= 2 1578 | if id_rank < x_rank: 1579 | ids = torch.reshape( 1580 | ids, id_shape + [1] * (x_rank - id_rank)) 1581 | ids = ids.expand(id_shape + x_shape[1:]) 1582 | y = torch.gather(x, 1, ids) 1583 | y = torch.reshape(y, x_shape) 1584 | return y 1585 | 1586 | is_first = (prev_embedding is None) 1587 | 1588 | if self.pos_shift: 1589 | if prev_embedding is None: 1590 | prev_embedding = first_expand(new_embedding) 1591 | else: 1592 | prev_embedding = torch.cat( 1593 | (prev_embedding, new_embedding), dim=1) 1594 | prev_embedding = select_beam_items( 1595 | prev_embedding, back_ptrs) 1596 | if prev_encoded_layers is None: 1597 | prev_encoded_layers = [first_expand( 1598 | x) for x in new_encoded_layers] 1599 | else: 1600 | prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( 1601 | prev_encoded_layers, new_encoded_layers)] 1602 | prev_encoded_layers = [select_beam_items( 1603 | x, back_ptrs) for x in prev_encoded_layers] 1604 | else: 1605 | if prev_embedding is None: 1606 | prev_embedding = first_expand(new_embedding[:, :-1, :]) 1607 | else: 1608 | prev_embedding = torch.cat( 1609 | (prev_embedding, new_embedding[:, :-1, :]), dim=1) 1610 | prev_embedding = select_beam_items( 1611 | prev_embedding, back_ptrs) 1612 | if prev_encoded_layers is None: 1613 | prev_encoded_layers = [first_expand( 1614 | x[:, :-1, :]) for x in new_encoded_layers] 1615 | else: 1616 | prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) 1617 | for x in zip(prev_encoded_layers, new_encoded_layers)] 1618 | prev_encoded_layers = [select_beam_items( 1619 | x, back_ptrs) for x in prev_encoded_layers] 1620 | 1621 | curr_ids = torch.reshape(k_ids, [batch_size * K, 1]) 1622 | 1623 | if is_first: 1624 | token_type_ids = first_expand(token_type_ids) 1625 | position_ids = first_expand(position_ids) 1626 | attention_mask = first_expand(attention_mask) 1627 | mask_ids = first_expand(mask_ids) 1628 | if mask_qkv is not None: 1629 | mask_qkv = first_expand(mask_qkv) 1630 | 1631 | if self.forbid_duplicate_ngrams: 1632 | wids = step_ids[-1].tolist() 1633 | ptrs = step_back_ptrs[-1].tolist() 1634 | if is_first: 1635 | partial_seqs = [] 1636 | for b in range(batch_size): 1637 | for k in range(K): 1638 | partial_seqs.append([wids[b][k]]) 1639 | else: 1640 | new_partial_seqs = [] 1641 | for b in range(batch_size): 1642 | for k in range(K): 1643 | new_partial_seqs.append( 1644 | partial_seqs[ptrs[b][k] + b * K] + [wids[b][k]]) 1645 | partial_seqs = new_partial_seqs 1646 | 1647 | def get_dup_ngram_candidates(seq, n): 1648 | cands = set() 1649 | if len(seq) < n: 1650 | return [] 1651 | tail = seq[-(n-1):] 1652 | if self.forbid_ignore_set and any(tk in self.forbid_ignore_set for tk in tail): 1653 | return [] 1654 | for i in range(len(seq) - (n - 1)): 1655 | mismatch = False 1656 | for j in range(n - 1): 1657 | if tail[j] != 
seq[i + j]: 1658 | mismatch = True 1659 | break 1660 | if (not mismatch) and not(self.forbid_ignore_set and (seq[i + n - 1] in self.forbid_ignore_set)): 1661 | cands.add(seq[i + n - 1]) 1662 | return list(sorted(cands)) 1663 | 1664 | if len(partial_seqs[0]) >= self.ngram_size: 1665 | dup_cands = [] 1666 | for seq in partial_seqs: 1667 | dup_cands.append( 1668 | get_dup_ngram_candidates(seq, self.ngram_size)) 1669 | if max(len(x) for x in dup_cands) > 0: 1670 | if buf_matrix is None: 1671 | vocab_size = list(log_scores.size())[-1] 1672 | buf_matrix = np.zeros( 1673 | (batch_size * K, vocab_size), dtype=float) 1674 | else: 1675 | buf_matrix.fill(0) 1676 | for bk, cands in enumerate(dup_cands): 1677 | for i, wid in enumerate(cands): 1678 | buf_matrix[bk, wid] = 1.0 1679 | forbid_word_mask = torch.tensor( 1680 | buf_matrix, dtype=log_scores.dtype) 1681 | forbid_word_mask = torch.reshape( 1682 | forbid_word_mask, [batch_size * K, 1, vocab_size]).cuda() 1683 | else: 1684 | forbid_word_mask = None 1685 | next_pos += 1 1686 | 1687 | # [(batch, beam)] 1688 | total_scores = [x.tolist() for x in total_scores] 1689 | step_ids = [x.tolist() for x in step_ids] 1690 | step_back_ptrs = [x.tolist() for x in step_back_ptrs] 1691 | # back tracking 1692 | traces = {'pred_seq': [], 'scores': [], 'wids': [], 'ptrs': []} 1693 | for b in range(batch_size): 1694 | # [(beam,)] 1695 | scores = [x[b] for x in total_scores] 1696 | wids_list = [x[b] for x in step_ids] 1697 | ptrs = [x[b] for x in step_back_ptrs] 1698 | traces['scores'].append(scores) 1699 | traces['wids'].append(wids_list) 1700 | traces['ptrs'].append(ptrs) 1701 | # first we need to find the eos frame where all symbols are eos 1702 | # any frames after the eos frame are invalid 1703 | last_frame_id = len(scores) - 1 1704 | for i, wids in enumerate(wids_list): 1705 | if all(wid == self.eos_id for wid in wids): 1706 | last_frame_id = i 1707 | break 1708 | max_score = -math.inf 1709 | frame_id = -1 1710 | pos_in_frame = -1 1711 | 1712 | for fid in range(last_frame_id + 1): 1713 | for i, wid in enumerate(wids_list[fid]): 1714 | if wid == self.eos_id or fid == last_frame_id: 1715 | s = scores[fid][i] 1716 | if self.length_penalty > 0: 1717 | s /= math.pow((5 + fid + 1) / 6.0, 1718 | self.length_penalty) 1719 | if s > max_score: 1720 | max_score = s 1721 | frame_id = fid 1722 | pos_in_frame = i 1723 | if frame_id == -1: 1724 | traces['pred_seq'].append([0]) 1725 | else: 1726 | seq = [wids_list[frame_id][pos_in_frame]] 1727 | for fid in range(frame_id, 0, -1): 1728 | pos_in_frame = ptrs[fid][pos_in_frame] 1729 | seq.append(wids_list[fid - 1][pos_in_frame]) 1730 | seq.reverse() 1731 | traces['pred_seq'].append(seq) 1732 | 1733 | def _pad_sequence(sequences, max_len, padding_value=0): 1734 | trailing_dims = sequences[0].size()[1:] 1735 | out_dims = (len(sequences), max_len) + trailing_dims 1736 | 1737 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) 1738 | for i, tensor in enumerate(sequences): 1739 | length = tensor.size(0) 1740 | # use index notation to prevent duplicate references to the tensor 1741 | out_tensor[i, :length, ...] 
= tensor 1742 | return out_tensor 1743 | 1744 | # convert to tensors for DataParallel 1745 | for k in ('pred_seq', 'scores', 'wids', 'ptrs'): 1746 | ts_list = traces[k] 1747 | if not isinstance(ts_list[0], torch.Tensor): 1748 | dt = torch.float if k == 'scores' else torch.long 1749 | ts_list = [torch.tensor(it, dtype=dt) for it in ts_list] 1750 | traces[k] = _pad_sequence( 1751 | ts_list, output_length, padding_value=0).to(input_ids.device) 1752 | 1753 | return traces 1754 | 1755 | 1756 | class BertForMaskedLM(PreTrainedBertModel): 1757 | """BERT model with the masked language modeling head. 1758 | This module comprises the BERT model followed by the masked language modeling head. 1759 | 1760 | Params: 1761 | config: a BertConfig class instance with the configuration to build a new model. 1762 | 1763 | Inputs: 1764 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1765 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1766 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1767 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1768 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1769 | a `sentence B` token (see BERT paper for more details). 1770 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1771 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1772 | input sequence length in the current batch. It's the mask that we typically use for attention when 1773 | a batch has varying length sentences. 1774 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 1775 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 1776 | is only computed for the labels set in [0, ..., vocab_size] 1777 | 1778 | Outputs: 1779 | if `masked_lm_labels` is not `None`: 1780 | Outputs the masked language modeling loss. 1781 | if `masked_lm_labels` is `None`: 1782 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
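The loss branch described above reduces to a flattened cross-entropy with `ignore_index=-1`, exactly as the `forward` below implements it. A minimal self-contained sketch with made-up shapes and label values:

```python
import torch
from torch.nn import CrossEntropyLoss

# Shapes follow the docstring: [batch_size, sequence_length, vocab_size] logits
# and [batch_size, sequence_length] labels, with -1 at positions to ignore.
batch_size, seq_len, vocab_size = 2, 3, 32000
prediction_scores = torch.randn(batch_size, seq_len, vocab_size)
masked_lm_labels = torch.LongTensor([[-1, -1, 99], [-1, 7, -1]])

loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size),
                          masked_lm_labels.view(-1))
# Only the two positions whose label is not -1 contribute to the average.
```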
1783 | 1784 | Example usage: 1785 | ```python 1786 | # Already been converted into WordPiece token ids 1787 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1788 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1789 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1790 | 1791 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1792 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1793 | 1794 | model = BertForMaskedLM(config) 1795 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) 1796 | ``` 1797 | """ 1798 | 1799 | def __init__(self, config): 1800 | super(BertForMaskedLM, self).__init__(config) 1801 | self.bert = BertModel(config) 1802 | self.cls = BertOnlyMLMHead( 1803 | config, self.bert.embeddings.word_embeddings.weight) 1804 | self.apply(self.init_bert_weights) 1805 | 1806 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, mask_qkv=None, task_idx=None): 1807 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, 1808 | output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) 1809 | prediction_scores = self.cls(sequence_output) 1810 | 1811 | if masked_lm_labels is not None: 1812 | loss_fct = CrossEntropyLoss(ignore_index=-1) 1813 | masked_lm_loss = loss_fct( 1814 | prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 1815 | return masked_lm_loss 1816 | else: 1817 | return prediction_scores 1818 | 1819 | 1820 | class BertForNextSentencePrediction(PreTrainedBertModel): 1821 | """BERT model with next sentence prediction head. 1822 | This module comprises the BERT model followed by the next sentence classification head. 1823 | 1824 | Params: 1825 | config: a BertConfig class instance with the configuration to build a new model. 1826 | 1827 | Inputs: 1828 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1829 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1830 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1831 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1832 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1833 | a `sentence B` token (see BERT paper for more details). 1834 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1835 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1836 | input sequence length in the current batch. It's the mask that we typically use for attention when 1837 | a batch has varying length sentences. 1838 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] 1839 | with indices selected in [0, 1]. 1840 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 1841 | 1842 | Outputs: 1843 | if `next_sentence_label` is not `None`: 1844 | Outputs the next sentence classification loss, i.e. the CrossEntropy loss of the 1845 | [batch_size, 2] logits against `next_sentence_label`. 1846 | if `next_sentence_label` is `None`: 1847 | Outputs the next sentence classification logits of shape [batch_size, 2].
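When `next_sentence_label` is supplied, the `forward` below reduces to a two-class cross-entropy over the logits produced from the pooled output. A minimal sketch with made-up logits and labels:

```python
import torch
from torch.nn import CrossEntropyLoss

# seq_relationship_score: [batch_size, 2] logits from the NSP head (random here).
seq_relationship_score = torch.randn(4, 2)
# 0 = sentence B is the true continuation, 1 = sentence B is a random sentence.
next_sentence_label = torch.LongTensor([0, 1, 1, 0])

loss_fct = CrossEntropyLoss(ignore_index=-1)
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2),
                              next_sentence_label.view(-1))
```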
1848 | 1849 | Example usage: 1850 | ```python 1851 | # Already been converted into WordPiece token ids 1852 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1853 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1854 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1855 | 1856 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1857 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1858 | 1859 | model = BertForNextSentencePrediction(config) 1860 | seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 1861 | ``` 1862 | """ 1863 | 1864 | def __init__(self, config): 1865 | super(BertForNextSentencePrediction, self).__init__(config) 1866 | self.bert = BertModel(config) 1867 | self.cls = BertOnlyNSPHead(config) 1868 | self.apply(self.init_bert_weights) 1869 | 1870 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, mask_qkv=None, task_idx=None): 1871 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 1872 | output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) 1873 | seq_relationship_score = self.cls(pooled_output) 1874 | 1875 | if next_sentence_label is not None: 1876 | loss_fct = CrossEntropyLoss(ignore_index=-1) 1877 | next_sentence_loss = loss_fct( 1878 | seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 1879 | return next_sentence_loss 1880 | else: 1881 | return seq_relationship_score 1882 | 1883 | 1884 | class BertForSequenceClassification(PreTrainedBertModel): 1885 | """BERT model for classification. 1886 | This module is composed of the BERT model with a linear layer on top of 1887 | the pooled output. 1888 | 1889 | Params: 1890 | `config`: a BertConfig class instance with the configuration to build a new model. 1891 | `num_labels`: the number of classes for the classifier. Default = 2. 1892 | 1893 | Inputs: 1894 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1895 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1896 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1897 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1898 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1899 | a `sentence B` token (see BERT paper for more details). 1900 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1901 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1902 | input sequence length in the current batch. It's the mask that we typically use for attention when 1903 | a batch has varying length sentences. 1904 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 1905 | with indices selected in [0, ..., num_labels]. 1906 | 1907 | Outputs: 1908 | if `labels` is not `None`: 1909 | Outputs the CrossEntropy classification loss of the output with the labels. 1910 | if `labels` is `None`: 1911 | Outputs the classification logits of shape [batch_size, num_labels]. 
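The `forward` below chooses the objective from `labels.dtype`: integer (`torch.long`) labels are treated as class indices and scored with cross-entropy, while float labels are treated as regression targets and scored with mean squared error. A minimal sketch with made-up values (the regression case assumes the classifier head was built with `num_labels=1`):

```python
import torch
from torch.nn import CrossEntropyLoss, MSELoss

# Classification: [batch_size, num_labels] logits and long class indices.
num_labels = 2
logits = torch.randn(4, num_labels)
class_labels = torch.LongTensor([0, 1, 1, 0])
clf_loss = CrossEntropyLoss()(logits.view(-1, num_labels), class_labels.view(-1))

# Regression: float targets, scored with MSE against a single output unit.
reg_logits = torch.randn(4, 1)
reg_targets = torch.tensor([0.2, 3.5, 1.0, 4.8])
reg_loss = MSELoss()(reg_logits.view(-1), reg_targets.view(-1))
```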
1912 | 1913 | Example usage: 1914 | ```python 1915 | # Already been converted into WordPiece token ids 1916 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1917 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1918 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1919 | 1920 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1921 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1922 | 1923 | num_labels = 2 1924 | 1925 | model = BertForSequenceClassification(config, num_labels) 1926 | logits = model(input_ids, token_type_ids, input_mask) 1927 | ``` 1928 | """ 1929 | 1930 | def __init__(self, config, num_labels=2): 1931 | super(BertForSequenceClassification, self).__init__(config) 1932 | self.num_labels = num_labels 1933 | self.bert = BertModel(config) 1934 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1935 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1936 | self.apply(self.init_bert_weights) 1937 | 1938 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None): 1939 | _, pooled_output = self.bert( 1940 | input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) 1941 | pooled_output = self.dropout(pooled_output) 1942 | logits = self.classifier(pooled_output) 1943 | 1944 | if labels is not None: 1945 | if labels.dtype == torch.long: 1946 | loss_fct = CrossEntropyLoss() 1947 | loss = loss_fct( 1948 | logits.view(-1, self.num_labels), labels.view(-1)) 1949 | elif labels.dtype == torch.half or labels.dtype == torch.float: 1950 | loss_fct = MSELoss() 1951 | loss = loss_fct(logits.view(-1), labels.view(-1)) 1952 | else: 1953 | print('unkown labels.dtype') 1954 | loss = None 1955 | return loss 1956 | else: 1957 | return logits 1958 | 1959 | 1960 | class BertForMultipleChoice(PreTrainedBertModel): 1961 | """BERT model for multiple choice tasks. 1962 | This module is composed of the BERT model with a linear layer on top of 1963 | the pooled output. 1964 | 1965 | Params: 1966 | `config`: a BertConfig class instance with the configuration to build a new model. 1967 | `num_choices`: the number of classes for the classifier. Default = 2. 1968 | 1969 | Inputs: 1970 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1971 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1972 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1973 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1974 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 1975 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 1976 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 1977 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1978 | input sequence length in the current batch. It's the mask that we typically use for attention when 1979 | a batch has varying length sentences. 1980 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 1981 | with indices selected in [0, ..., num_choices]. 1982 | 1983 | Outputs: 1984 | if `labels` is not `None`: 1985 | Outputs the CrossEntropy classification loss of the output with the labels. 
1986 | if `labels` is `None`: 1987 | Outputs the classification logits of shape [batch_size, num_labels]. 1988 | 1989 | Example usage: 1990 | ```python 1991 | # Already been converted into WordPiece token ids 1992 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) 1993 | input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) 1994 | token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) 1995 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1996 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1997 | 1998 | num_choices = 2 1999 | 2000 | model = BertForMultipleChoice(config, num_choices) 2001 | logits = model(input_ids, token_type_ids, input_mask) 2002 | ``` 2003 | """ 2004 | 2005 | def __init__(self, config, num_choices=2): 2006 | super(BertForMultipleChoice, self).__init__(config) 2007 | self.num_choices = num_choices 2008 | self.bert = BertModel(config) 2009 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 2010 | self.classifier = nn.Linear(config.hidden_size, 1) 2011 | self.apply(self.init_bert_weights) 2012 | 2013 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None): 2014 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 2015 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 2016 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) 2017 | _, pooled_output = self.bert( 2018 | flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) 2019 | pooled_output = self.dropout(pooled_output) 2020 | logits = self.classifier(pooled_output) 2021 | reshaped_logits = logits.view(-1, self.num_choices) 2022 | 2023 | if labels is not None: 2024 | loss_fct = CrossEntropyLoss() 2025 | loss = loss_fct(reshaped_logits, labels) 2026 | return loss 2027 | else: 2028 | return reshaped_logits 2029 | 2030 | 2031 | class BertForTokenClassification(PreTrainedBertModel): 2032 | """BERT model for token-level classification. 2033 | This module is composed of the BERT model with a linear layer on top of 2034 | the full hidden state of the last layer. 2035 | 2036 | Params: 2037 | `config`: a BertConfig class instance with the configuration to build a new model. 2038 | `num_labels`: the number of classes for the classifier. Default = 2. 2039 | 2040 | Inputs: 2041 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 2042 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 2043 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 2044 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 2045 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 2046 | a `sentence B` token (see BERT paper for more details). 2047 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 2048 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 2049 | input sequence length in the current batch. It's the mask that we typically use for attention when 2050 | a batch has varying length sentences. 2051 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 2052 | with indices selected in [0, ..., num_labels]. 
2053 | 2054 | Outputs: 2055 | if `labels` is not `None`: 2056 | Outputs the CrossEntropy classification loss of the output with the labels. 2057 | if `labels` is `None`: 2058 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. 2059 | 2060 | Example usage: 2061 | ```python 2062 | # Already been converted into WordPiece token ids 2063 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 2064 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 2065 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 2066 | 2067 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 2068 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 2069 | 2070 | num_labels = 2 2071 | 2072 | model = BertForTokenClassification(config, num_labels) 2073 | logits = model(input_ids, token_type_ids, input_mask) 2074 | ``` 2075 | """ 2076 | 2077 | def __init__(self, config, num_labels=2): 2078 | super(BertForTokenClassification, self).__init__(config) 2079 | self.num_labels = num_labels 2080 | self.bert = BertModel(config) 2081 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 2082 | self.classifier = nn.Linear(config.hidden_size, num_labels) 2083 | self.apply(self.init_bert_weights) 2084 | 2085 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None): 2086 | sequence_output, _ = self.bert( 2087 | input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx) 2088 | sequence_output = self.dropout(sequence_output) 2089 | logits = self.classifier(sequence_output) 2090 | 2091 | if labels is not None: 2092 | loss_fct = CrossEntropyLoss() 2093 | # Only keep active parts of the loss 2094 | if attention_mask is not None: 2095 | active_loss = attention_mask.view(-1) == 1 2096 | active_logits = logits.view(-1, self.num_labels)[active_loss] 2097 | active_labels = labels.view(-1)[active_loss] 2098 | loss = loss_fct(active_logits, active_labels) 2099 | else: 2100 | loss = loss_fct( 2101 | logits.view(-1, self.num_labels), labels.view(-1)) 2102 | return loss 2103 | else: 2104 | return logits 2105 | 2106 | 2107 | class BertForQuestionAnswering(PreTrainedBertModel): 2108 | """BERT model for Question Answering (span extraction). 2109 | This module is composed of the BERT model with a linear layer on top of 2110 | the sequence output that computes start_logits and end_logits 2111 | 2112 | Params: 2113 | `config`: either 2114 | - a BertConfig class instance with the configuration to build a new model, or 2115 | - a str with the name of a pre-trained model to load selected in the list of: 2116 | . `bert-base-uncased` 2117 | . `bert-large-uncased` 2118 | . `bert-base-cased` 2119 | . `bert-base-multilingual` 2120 | . `bert-base-chinese` 2121 | The pre-trained model will be downloaded and cached if needed. 2122 | 2123 | Inputs: 2124 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 2125 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 2126 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 2127 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 2128 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 2129 | a `sentence B` token (see BERT paper for more details). 
2130 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 2131 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 2132 | input sequence length in the current batch. It's the mask that we typically use for attention when 2133 | a batch has varying length sentences. 2134 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. 2135 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 2136 | into account for computing the loss. 2137 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. 2138 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 2139 | into account for computing the loss. 2140 | 2141 | Outputs: 2142 | if `start_positions` and `end_positions` are not `None`: 2143 | Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. 2144 | if `start_positions` or `end_positions` is `None`: 2145 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end 2146 | position tokens of shape [batch_size, sequence_length]. 2147 | 2148 | Example usage: 2149 | ```python 2150 | # Already been converted into WordPiece token ids 2151 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 2152 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 2153 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 2154 | 2155 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 2156 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 2157 | 2158 | model = BertForQuestionAnswering(config) 2159 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 2160 | ``` 2161 | """ 2162 | 2163 | def __init__(self, config): 2164 | super(BertForQuestionAnswering, self).__init__(config) 2165 | self.bert = BertModel(config) 2166 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 2167 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 2168 | self.apply(self.init_bert_weights) 2169 | 2170 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, task_idx=None): 2171 | sequence_output, _ = self.bert( 2172 | input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, task_idx=task_idx) 2173 | logits = self.qa_outputs(sequence_output) 2174 | start_logits, end_logits = logits.split(1, dim=-1) 2175 | start_logits = start_logits.squeeze(-1) 2176 | end_logits = end_logits.squeeze(-1) 2177 | 2178 | if start_positions is not None and end_positions is not None: 2179 | # If we are on multi-GPU, split add a dimension 2180 | if len(start_positions.size()) > 1: 2181 | start_positions = start_positions.squeeze(-1) 2182 | if len(end_positions.size()) > 1: 2183 | end_positions = end_positions.squeeze(-1) 2184 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 2185 | ignored_index = start_logits.size(1) 2186 | start_positions.clamp_(0, ignored_index) 2187 | end_positions.clamp_(0, ignored_index) 2188 | 2189 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 2190 | start_loss = loss_fct(start_logits, start_positions) 2191 | end_loss = loss_fct(end_logits, end_positions) 2192 | total_loss = (start_loss + end_loss) / 2 2193 | return total_loss 2194 | else: 2195 | 
return start_logits, end_logits 2196 | -------------------------------------------------------------------------------- /question_generator/text2text/pytorch_pretrained_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import os 24 | import logging 25 | 26 | from .file_utils import cached_path 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'bert-base-uncased': 512, 41 | 'bert-large-uncased': 512, 42 | 'bert-base-cased': 512, 43 | 'bert-large-cased': 512, 44 | 'bert-base-multilingual-uncased': 512, 45 | 'bert-base-multilingual-cased': 512, 46 | 'bert-base-chinese': 512, 47 | } 48 | VOCAB_NAME = 'vocab.txt' 49 | 50 | 51 | def load_vocab(vocab_file): 52 | """Loads a vocabulary file into a dictionary.""" 53 | # mapping unused tokens to special tokens 54 | extra_map = {} 55 | extra_map['[unused1]'] = '[X_SEP]' 56 | for i in range(10): 57 | extra_map['[unused{}]'.format(i+2)] = '[SEP_{}]'.format(i) 58 | extra_map['[unused12]'] = '[S2S_SEP]' 59 | extra_map['[unused13]'] = '[S2S_CLS]' 60 | extra_map['[unused14]'] = '[L2R_SEP]' 61 | extra_map['[unused15]'] = '[L2R_CLS]' 62 | extra_map['[unused16]'] = '[R2L_SEP]' 63 | extra_map['[unused17]'] = '[R2L_CLS]' 64 | extra_map['[unused18]'] = '[S2S_SOS]' 65 | 66 | vocab = collections.OrderedDict() 67 | index = 0 68 | with open(vocab_file, "r", encoding="utf-8") as reader: 69 | while True: 70 | token = reader.readline() 71 | if not token: 72 | break 73 | token = token.strip() 74 | if token in extra_map: 75 | token = extra_map[token] 76 | vocab[token] = index 77 | index += 1 78 | return vocab 79 | 80 | 81 | def whitespace_tokenize(text): 82 | """Runs basic whitespace cleaning and splitting on a peice 
of text.""" 83 | text = text.strip() 84 | if not text: 85 | return [] 86 | tokens = text.split() 87 | return tokens 88 | 89 | 90 | class BertTokenizer(object): 91 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 92 | 93 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, never_split=("[UNK]", "[SEP]", "[X_SEP]", "[PAD]", "[CLS]", "[MASK]")): 94 | if not os.path.isfile(vocab_file): 95 | raise ValueError( 96 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 97 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 98 | self.vocab = load_vocab(vocab_file) 99 | self.ids_to_tokens = collections.OrderedDict( 100 | [(ids, tok) for tok, ids in self.vocab.items()]) 101 | self.basic_tokenizer = BasicTokenizer( 102 | do_lower_case=do_lower_case, never_split=never_split) 103 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 104 | self.max_len = max_len if max_len is not None else int(1e12) 105 | 106 | def tokenize(self, text): 107 | split_tokens = [] 108 | for token in self.basic_tokenizer.tokenize(text): 109 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 110 | split_tokens.append(sub_token) 111 | return split_tokens 112 | 113 | def convert_tokens_to_ids(self, tokens): 114 | """Converts a sequence of tokens into ids using the vocab.""" 115 | ids = [] 116 | for token in tokens: 117 | ids.append(self.vocab[token]) 118 | if len(ids) > self.max_len: 119 | raise ValueError( 120 | "Token indices sequence length is longer than the specified maximum " 121 | " sequence length for this BERT model ({} > {}). Running this" 122 | " sequence through BERT will result in indexing errors".format( 123 | len(ids), self.max_len) 124 | ) 125 | return ids 126 | 127 | def convert_ids_to_tokens(self, ids): 128 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 129 | tokens = [] 130 | for i in ids: 131 | tokens.append(self.ids_to_tokens[i]) 132 | return tokens 133 | 134 | @classmethod 135 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): 136 | """ 137 | Instantiate a PreTrainedBertModel from a pre-trained model file. 138 | Download and cache the pre-trained model file if needed. 139 | """ 140 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 141 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 142 | else: 143 | vocab_file = pretrained_model_name 144 | if os.path.isdir(vocab_file): 145 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 146 | # redirect to the cache, if necessary 147 | try: 148 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 149 | except FileNotFoundError: 150 | logger.error( 151 | "Model name '{}' was not found in model name list ({}). 
" 152 | "We assumed '{}' was a path or url but couldn't find any file " 153 | "associated to this path or url.".format( 154 | pretrained_model_name, 155 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 156 | vocab_file)) 157 | return None 158 | if resolved_vocab_file == vocab_file: 159 | logger.info("loading vocabulary file {}".format(vocab_file)) 160 | else: 161 | logger.info("loading vocabulary file {} from cache at {}".format( 162 | vocab_file, resolved_vocab_file)) 163 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 164 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 165 | # than the number of positional embeddings 166 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] 167 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 168 | # Instantiate tokenizer. 169 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 170 | return tokenizer 171 | 172 | 173 | class WhitespaceTokenizer(object): 174 | def tokenize(self, text): 175 | return whitespace_tokenize(text) 176 | 177 | 178 | class BasicTokenizer(object): 179 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 180 | 181 | def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 182 | """Constructs a BasicTokenizer. 183 | 184 | Args: 185 | do_lower_case: Whether to lower case the input. 186 | """ 187 | self.do_lower_case = do_lower_case 188 | self.never_split = never_split 189 | 190 | def tokenize(self, text): 191 | """Tokenizes a piece of text.""" 192 | text = self._clean_text(text) 193 | # This was added on November 1st, 2018 for the multilingual and Chinese 194 | # models. This is also applied to the English models now, but it doesn't 195 | # matter since the English models were not trained on any Chinese data 196 | # and generally don't have any Chinese data in them (there are Chinese 197 | # characters in the vocabulary because Wikipedia does have some Chinese 198 | # words in the English Wikipedia.). 
199 | text = self._tokenize_chinese_chars(text) 200 | orig_tokens = whitespace_tokenize(text) 201 | split_tokens = [] 202 | for token in orig_tokens: 203 | if self.do_lower_case and token not in self.never_split: 204 | token = token.lower() 205 | token = self._run_strip_accents(token) 206 | split_tokens.extend(self._run_split_on_punc(token)) 207 | 208 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 209 | return output_tokens 210 | 211 | def _run_strip_accents(self, text): 212 | """Strips accents from a piece of text.""" 213 | text = unicodedata.normalize("NFD", text) 214 | output = [] 215 | for char in text: 216 | cat = unicodedata.category(char) 217 | if cat == "Mn": 218 | continue 219 | output.append(char) 220 | return "".join(output) 221 | 222 | def _run_split_on_punc(self, text): 223 | """Splits punctuation on a piece of text.""" 224 | if text in self.never_split: 225 | return [text] 226 | chars = list(text) 227 | i = 0 228 | start_new_word = True 229 | output = [] 230 | while i < len(chars): 231 | char = chars[i] 232 | if _is_punctuation(char): 233 | output.append([char]) 234 | start_new_word = True 235 | else: 236 | if start_new_word: 237 | output.append([]) 238 | start_new_word = False 239 | output[-1].append(char) 240 | i += 1 241 | 242 | return ["".join(x) for x in output] 243 | 244 | def _tokenize_chinese_chars(self, text): 245 | """Adds whitespace around any CJK character.""" 246 | output = [] 247 | for char in text: 248 | cp = ord(char) 249 | if self._is_chinese_char(cp): 250 | output.append(" ") 251 | output.append(char) 252 | output.append(" ") 253 | else: 254 | output.append(char) 255 | return "".join(output) 256 | 257 | def _is_chinese_char(self, cp): 258 | """Checks whether CP is the codepoint of a CJK character.""" 259 | # This defines a "chinese character" as anything in the CJK Unicode block: 260 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 261 | # 262 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 263 | # despite its name. The modern Korean Hangul alphabet is a different block, 264 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 265 | # space-separated words, so they are not treated specially and handled 266 | # like the all of the other languages. 267 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 268 | (cp >= 0x3400 and cp <= 0x4DBF) or # 269 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 270 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 271 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 272 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 273 | (cp >= 0xF900 and cp <= 0xFAFF) or # 274 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 275 | return True 276 | 277 | return False 278 | 279 | def _clean_text(self, text): 280 | """Performs invalid character removal and whitespace cleanup on text.""" 281 | output = [] 282 | for char in text: 283 | cp = ord(char) 284 | if cp == 0 or cp == 0xfffd or _is_control(char): 285 | continue 286 | if _is_whitespace(char): 287 | output.append(" ") 288 | else: 289 | output.append(char) 290 | return "".join(output) 291 | 292 | 293 | class WordpieceTokenizer(object): 294 | """Runs WordPiece tokenization.""" 295 | 296 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 297 | self.vocab = vocab 298 | self.unk_token = unk_token 299 | self.max_input_chars_per_word = max_input_chars_per_word 300 | 301 | def tokenize(self, text): 302 | """Tokenizes a piece of text into its word pieces. 
303 | 304 | This uses a greedy longest-match-first algorithm to perform tokenization 305 | using the given vocabulary. 306 | 307 | For example: 308 | input = "unaffable" 309 | output = ["un", "##aff", "##able"] 310 | 311 | Args: 312 | text: A single token or whitespace separated tokens. This should have 313 | already been passed through `BasicTokenizer`. 314 | 315 | Returns: 316 | A list of wordpiece tokens. 317 | """ 318 | 319 | output_tokens = [] 320 | for token in whitespace_tokenize(text): 321 | chars = list(token) 322 | if len(chars) > self.max_input_chars_per_word: 323 | output_tokens.append(self.unk_token) 324 | continue 325 | 326 | is_bad = False 327 | start = 0 328 | sub_tokens = [] 329 | while start < len(chars): 330 | end = len(chars) 331 | cur_substr = None 332 | while start < end: 333 | substr = "".join(chars[start:end]) 334 | if start > 0: 335 | substr = "##" + substr 336 | if substr in self.vocab: 337 | cur_substr = substr 338 | break 339 | end -= 1 340 | if cur_substr is None: 341 | is_bad = True 342 | break 343 | sub_tokens.append(cur_substr) 344 | start = end 345 | 346 | if is_bad: 347 | output_tokens.append(self.unk_token) 348 | else: 349 | output_tokens.extend(sub_tokens) 350 | return output_tokens 351 | 352 | 353 | def _is_whitespace(char): 354 | """Checks whether `chars` is a whitespace character.""" 355 | # \t, \n, and \r are technically contorl characters but we treat them 356 | # as whitespace since they are generally considered as such. 357 | if char == " " or char == "\t" or char == "\n" or char == "\r": 358 | return True 359 | cat = unicodedata.category(char) 360 | if cat == "Zs": 361 | return True 362 | return False 363 | 364 | 365 | def _is_control(char): 366 | """Checks whether `chars` is a control character.""" 367 | # These are technically control characters but we count them as whitespace 368 | # characters. 369 | if char == "\t" or char == "\n" or char == "\r": 370 | return False 371 | cat = unicodedata.category(char) 372 | if cat.startswith("C"): 373 | return True 374 | return False 375 | 376 | 377 | def _is_punctuation(char): 378 | """Checks whether `chars` is a punctuation character.""" 379 | cp = ord(char) 380 | # We treat all non-letter/number ASCII as punctuation. 381 | # Characters such as "^", "$", and "`" are not in the Unicode 382 | # Punctuation class but we treat them as punctuation anyways, for 383 | # consistency. 
384 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 385 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 386 | return True 387 | cat = unicodedata.category(char) 388 | if cat.startswith("P"): 389 | return True 390 | return False 391 | -------------------------------------------------------------------------------- /question_generator/text2text/text_generator.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import random 7 | import requests, zipfile, io 8 | import os 9 | 10 | from .pytorch_pretrained_bert.tokenization import BertTokenizer, WhitespaceTokenizer 11 | from .pytorch_pretrained_bert.modeling import BertForSeq2SeqDecoder, BertConfig 12 | 13 | from .biunilm import seq2seq_loader 14 | 15 | STOP_WORDS = ["0o", "0s", "3a", "3b", "3d", "6b", "6o", "a", "a1", "a2", "a3", "a4", "ab", "able", "about", "above", "abst", "ac", "accordance", "according", "accordingly", "across", "act", "actually", "ad", "added", "adj", "ae", "af", "affected", "affecting", "affects", "after", "afterwards", "ag", "again", "against", "ah", "ain", "ain't", "aj", "al", "all", "allow", "allows", "almost", "alone", "along", "already", "also", "although", "always", "am", "among", "amongst", "amoungst", "amount", "an", "and", "announce", "another", "any", "anybody", "anyhow", "anymore", "anyone", "anything", "anyway", "anyways", "anywhere", "ao", "ap", "apart", "apparently", "appear", "appreciate", "appropriate", "approximately", "ar", "are", "aren", "arent", "aren't", "arise", "around", "as", "a's", "aside", "ask", "asking", "associated", "at", "au", "auth", "av", "available", "aw", "away", "awfully", "ax", "ay", "az", "b", "b1", "b2", "b3", "ba", "back", "bc", "bd", "be", "became", "because", "become", "becomes", "becoming", "been", "before", "beforehand", "begin", "beginning", "beginnings", "begins", "behind", "being", "believe", "below", "beside", "besides", "best", "better", "between", "beyond", "bi", "bill", "biol", "bj", "bk", "bl", "bn", "both", "bottom", "bp", "br", "brief", "briefly", "bs", "bt", "bu", "but", "bx", "by", "c", "c1", "c2", "c3", "ca", "call", "came", "can", "cannot", "cant", "can't", "cause", "causes", "cc", "cd", "ce", "certain", "certainly", "cf", "cg", "ch", "changes", "ci", "cit", "cj", "cl", "clearly", "cm", "c'mon", "cn", "co", "com", "come", "comes", "con", "concerning", "consequently", "consider", "considering", "contain", "containing", "contains", "corresponding", "could", "couldn", "couldnt", "couldn't", "course", "cp", "cq", "cr", "cry", "cs", "c's", "ct", "cu", "currently", "cv", "cx", "cy", "cz", "d", "d2", "da", "date", "dc", "dd", "de", "definitely", "describe", "described", "despite", "detail", "df", "di", "did", "didn", "didn't", "different", "dj", "dk", "dl", "do", "does", "doesn", "doesn't", "doing", "don", "done", "don't", "down", "downwards", "dp", "dr", "ds", "dt", "du", "due", "during", "dx", "dy", "e", "e2", "e3", "ea", "each", "ec", "ed", "edu", "ee", "ef", "effect", "eg", "ei", "eight", "eighty", "either", "ej", "el", "eleven", "else", "elsewhere", "em", "empty", "en", "end", "ending", "enough", "entirely", "eo", "ep", "eq", "er", "es", "especially", "est", "et", "et-al", "etc", "eu", "ev", "even", "ever", "every", "everybody", "everyone", "everything", "everywhere", "ex", "exactly", "example", "except", "ey", "f", "f2", "fa", "far", "fc", "few", "ff", "fi", "fifteen", "fifth", "fify", "fill", "find", "fire", 
"first", "five", "fix", "fj", "fl", "fn", "fo", "followed", "following", "follows", "for", "former", "formerly", "forth", "forty", "found", "four", "fr", "from", "front", "fs", "ft", "fu", "full", "further", "furthermore", "fy", "g", "ga", "gave", "ge", "get", "gets", "getting", "gi", "give", "given", "gives", "giving", "gj", "gl", "go", "goes", "going", "gone", "got", "gotten", "gr", "greetings", "gs", "gy", "h", "h2", "h3", "had", "hadn", "hadn't", "happens", "hardly", "has", "hasn", "hasnt", "hasn't", "have", "haven", "haven't", "having", "he", "hed", "he'd", "he'll", "hello", "help", "hence", "her", "here", "hereafter", "hereby", "herein", "heres", "here's", "hereupon", "hers", "herself", "hes", "he's", "hh", "hi", "hid", "him", "himself", "his", "hither", "hj", "ho", "home", "hopefully", "how", "howbeit", "however", "how's", "hr", "hs", "http", "hu", "hundred", "hy", "i", "i2", "i3", "i4", "i6", "i7", "i8", "ia", "ib", "ibid", "ic", "id", "i'd", "ie", "if", "ig", "ignored", "ih", "ii", "ij", "il", "i'll", "im", "i'm", "immediate", "immediately", "importance", "important", "in", "inasmuch", "inc", "indeed", "index", "indicate", "indicated", "indicates", "information", "inner", "insofar", "instead", "interest", "into", "invention", "inward", "io", "ip", "iq", "ir", "is", "isn", "isn't", "it", "itd", "it'd", "it'll", "its", "it's", "itself", "iv", "i've", "ix", "iy", "iz", "j", "jj", "jr", "js", "jt", "ju", "just", "k", "ke", "keep", "keeps", "kept", "kg", "kj", "km", "know", "known", "knows", "ko", "l", "l2", "la", "largely", "last", "lately", "later", "latter", "latterly", "lb", "lc", "le", "least", "les", "less", "lest", "let", "lets", "let's", "lf", "like", "liked", "likely", "line", "little", "lj", "ll", "ll", "ln", "lo", "look", "looking", "looks", "los", "lr", "ls", "lt", "ltd", "m", "m2", "ma", "made", "mainly", "make", "makes", "many", "may", "maybe", "me", "mean", "means", "meantime", "meanwhile", "merely", "mg", "might", "mightn", "mightn't", "mill", "million", "mine", "miss", "ml", "mn", "mo", "more", "moreover", "most", "mostly", "move", "mr", "mrs", "ms", "mt", "mu", "much", "mug", "must", "mustn", "mustn't", "my", "myself", "n", "n2", "na", "name", "namely", "nay", "nc", "nd", "ne", "near", "nearly", "necessarily", "necessary", "need", "needn", "needn't", "needs", "neither", "never", "nevertheless", "new", "next", "ng", "ni", "nine", "ninety", "nj", "nl", "nn", "no", "nobody", "non", "none", "nonetheless", "noone", "nor", "normally", "nos", "not", "noted", "nothing", "novel", "now", "nowhere", "nr", "ns", "nt", "ny", "o", "oa", "ob", "obtain", "obtained", "obviously", "oc", "od", "of", "off", "often", "og", "oh", "oi", "oj", "ok", "okay", "ol", "old", "om", "omitted", "on", "once", "one", "ones", "only", "onto", "oo", "op", "oq", "or", "ord", "os", "ot", "other", "others", "otherwise", "ou", "ought", "our", "ours", "ourselves", "out", "outside", "over", "overall", "ow", "owing", "own", "ox", "oz", "p", "p1", "p2", "p3", "page", "pagecount", "pages", "par", "part", "particular", "particularly", "pas", "past", "pc", "pd", "pe", "per", "perhaps", "pf", "ph", "pi", "pj", "pk", "pl", "placed", "please", "plus", "pm", "pn", "po", "poorly", "possible", "possibly", "potentially", "pp", "pq", "pr", "predominantly", "present", "presumably", "previously", "primarily", "probably", "promptly", "proud", "provides", "ps", "pt", "pu", "put", "py", "q", "qj", "qu", "que", "quickly", "quite", "qv", "r", "r2", "ra", "ran", "rather", "rc", "rd", "re", "readily", "really", "reasonably", 
"recent", "recently", "ref", "refs", "regarding", "regardless", "regards", "related", "relatively", "research", "research-articl", "respectively", "resulted", "resulting", "results", "rf", "rh", "ri", "right", "rj", "rl", "rm", "rn", "ro", "rq", "rr", "rs", "rt", "ru", "run", "rv", "ry", "s", "s2", "sa", "said", "same", "saw", "say", "saying", "says", "sc", "sd", "se", "sec", "second", "secondly", "section", "see", "seeing", "seem", "seemed", "seeming", "seems", "seen", "self", "selves", "sensible", "sent", "serious", "seriously", "seven", "several", "sf", "shall", "shan", "shan't", "she", "shed", "she'd", "she'll", "shes", "she's", "should", "shouldn", "shouldn't", "should've", "show", "showed", "shown", "showns", "shows", "si", "side", "significant", "significantly", "similar", "similarly", "since", "sincere", "six", "sixty", "sj", "sl", "slightly", "sm", "sn", "so", "some", "somebody", "somehow", "someone", "somethan", "something", "sometime", "sometimes", "somewhat", "somewhere", "soon", "sorry", "sp", "specifically", "specified", "specify", "specifying", "sq", "sr", "ss", "st", "still", "stop", "strongly", "sub", "substantially", "successfully", "such", "sufficiently", "suggest", "sup", "sure", "sy", "system", "sz", "t", "t1", "t2", "t3", "take", "taken", "taking", "tb", "tc", "td", "te", "tell", "ten", "tends", "tf", "th", "than", "thank", "thanks", "thanx", "that", "that'll", "thats", "that's", "that've", "the", "their", "theirs", "them", "themselves", "then", "thence", "there", "thereafter", "thereby", "thered", "therefore", "therein", "there'll", "thereof", "therere", "theres", "there's", "thereto", "thereupon", "there've", "these", "they", "theyd", "they'd", "they'll", "theyre", "they're", "they've", "thickv", "thin", "think", "third", "this", "thorough", "thoroughly", "those", "thou", "though", "thoughh", "thousand", "three", "throug", "through", "throughout", "thru", "thus", "ti", "til", "tip", "tj", "tl", "tm", "tn", "to", "together", "too", "took", "top", "toward", "towards", "tp", "tq", "tr", "tried", "tries", "truly", "try", "trying", "ts", "t's", "tt", "tv", "twelve", "twenty", "twice", "two", "tx", "u", "u201d", "ue", "ui", "uj", "uk", "um", "un", "under", "unfortunately", "unless", "unlike", "unlikely", "until", "unto", "uo", "up", "upon", "ups", "ur", "us", "use", "used", "useful", "usefully", "usefulness", "uses", "using", "usually", "ut", "v", "va", "value", "various", "vd", "ve", "ve", "very", "via", "viz", "vj", "vo", "vol", "vols", "volumtype", "vq", "vs", "vt", "vu", "w", "wa", "want", "wants", "was", "wasn", "wasnt", "wasn't", "way", "we", "wed", "we'd", "welcome", "well", "we'll", "well-b", "went", "were", "we're", "weren", "werent", "weren't", "we've", "what", "whatever", "what'll", "whats", "what's", "when", "whence", "whenever", "when's", "where", "whereafter", "whereas", "whereby", "wherein", "wheres", "where's", "whereupon", "wherever", "whether", "which", "while", "whim", "whither", "who", "whod", "whoever", "whole", "who'll", "whom", "whomever", "whos", "who's", "whose", "why", "why's", "wi", "widely", "will", "willing", "wish", "with", "within", "without", "wo", "won", "wonder", "wont", "won't", "words", "world", "would", "wouldn", "wouldnt", "wouldn't", "www", "x", "x1", "x2", "x3", "xf", "xi", "xj", "xk", "xl", "xn", "xo", "xs", "xt", "xv", "xx", "y", "y2", "yes", "yet", "yj", "yl", "you", "youd", "you'd", "you'll", "your", "youre", "you're", "yours", "yourself", "yourselves", "you've", "yr", "ys", "yt", "z", "zero", "zi", "zz",] 16 | 17 | def 
detokenize(tk_list):
18 |     r_list = []
19 |     for tk in tk_list:
20 |         if tk.startswith('##') and len(r_list) > 0:
21 |             r_list[-1] = r_list[-1] + tk[2:]
22 |         else:
23 |             r_list.append(tk)
24 |     return r_list
25 | 
26 | def ascii_print(text):
27 |     text = text.encode("ascii", "ignore")
28 |     print(text)
29 | 
30 | PRETRAINED_PARAMS = {
31 |     "question": {
32 |         "file_id": "1JN2wnkSRotwUnJ_Z-AbWwoPdP53Gcfsn",
33 |         "fp16": False,
34 |         "amp": False,
35 |         "model_recover_path": "qg_model.bin",
36 |         "max_seq_length": 512,
37 |         "max_tgt_length": 48,
38 |         "batch_size": 16,
39 |         "beam_size": 1,
40 |         "length_penalty": 0,
41 |         "forbid_duplicate_ngrams": False,
42 |         "forbid_ignore_word": None
43 |     },
44 |     "summary": {
45 |         "file_id": "1RyJxShxC9tDYVAyZwUwqkSoQ3l5DfjuE",
46 |         "fp16": True,
47 |         "amp": True,
48 |         "model_recover_path": "cnndm_model.bin",
49 |         "max_seq_length": 768,
50 |         "max_tgt_length": 128,
51 |         "batch_size": 64,
52 |         "beam_size": 5,
53 |         "length_penalty": 0,
54 |         "forbid_duplicate_ngrams": True,
55 |         "forbid_ignore_word": ".|[X_SEP]"
56 |     }
57 | }
58 | 
59 | class TextGenerator(object):
60 | 
61 |     def __init__(self, output_type="question", **kwargs):
62 |         self.output_type = output_type
63 |         self.bert_model = "bert-large-cased"
64 |         self.ffn_type = 0
65 |         self.num_qkv = 0
66 |         self.seg_emb = False
67 |         self.split = "test"
68 |         self.seed = 123
69 |         self.do_lower_case = False
70 |         self.new_segment_ids = True
71 |         self.new_pos_ids = False
72 |         self.min_len = None
73 |         self.ngram_size = 3
74 |         self.mode = "s2s"
75 |         self.s2s_special_token = False
76 |         self.s2s_add_segment = False
77 |         self.s2s_share_segment = False
78 |         self.pos_shift = False
79 |         self.not_predict_token = None
80 |         if output_type not in ["question","summary"]:  # validate before indexing into PRETRAINED_PARAMS
81 |             raise ValueError(f'{output_type} is not a valid output_type. Choose either "question" or "summary".')
82 | 
83 |         self.__dict__.update(PRETRAINED_PARAMS[output_type])
84 |         self.__dict__.update(kwargs)
Choose either "question" or "summary".') 85 | 86 | if self.max_tgt_length >= self.max_seq_length - 2: 87 | raise ValueError("Maximum tgt length exceeds max seq length - 2.") 88 | 89 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 90 | n_gpu = torch.cuda.device_count() 91 | 92 | random.seed(self.seed) 93 | np.random.seed(self.seed) 94 | torch.manual_seed(self.seed) 95 | if n_gpu > 0: 96 | torch.cuda.manual_seed_all(self.seed) 97 | 98 | self.tokenizer = BertTokenizer.from_pretrained(self.bert_model, do_lower_case=self.do_lower_case) 99 | 100 | self.tokenizer.max_len = self.max_seq_length 101 | 102 | pair_num_relation = 0 103 | self.bi_uni_pipeline = [] 104 | self.bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(list(self.tokenizer.vocab.keys()), self.tokenizer.convert_tokens_to_ids, self.max_seq_length, max_tgt_length=self.max_tgt_length, new_segment_ids=self.new_segment_ids, 105 | mode="s2s", num_qkv=self.num_qkv, s2s_special_token=self.s2s_special_token, s2s_add_segment=self.s2s_add_segment, s2s_share_segment=self.s2s_share_segment, pos_shift=self.pos_shift)) 106 | 107 | # Prepare model 108 | cls_num_labels = 2 109 | type_vocab_size = 6 + \ 110 | (1 if self.s2s_add_segment else 0) if self.new_segment_ids else 2 111 | mask_word_id, eos_word_ids, sos_word_id = self.tokenizer.convert_tokens_to_ids(["[MASK]", "[SEP]", "[S2S_SOS]"]) 112 | 113 | forbid_ignore_set = self._get_token_id_set(self.forbid_ignore_word) 114 | not_predict_set = self._get_token_id_set(self.not_predict_token) 115 | 116 | self.download_pretrained_model() 117 | 118 | for model_recover_path in glob.glob(self.model_recover_path.strip()): 119 | print("***** Recover model: %s *****", model_recover_path) 120 | map_device = None 121 | if not torch.cuda.is_available(): 122 | map_device='cpu' 123 | #model_recover = torch.load(model_recover_path,map_location=map_device) 124 | #config = BertConfig.from_json_file("/data1/DialoAttack/AgentChat/text2text/config.json") 125 | self.model = BertForSeq2SeqDecoder.from_pretrained("/data1/DialoAttack/AgentChat/text2text/model", num_labels=cls_num_labels, num_rel=pair_num_relation, type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id, search_beam_size=self.beam_size, length_penalty=self.length_penalty, eos_id=eos_word_ids, sos_id=sos_word_id, forbid_duplicate_ngrams=self.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set, not_predict_set=not_predict_set, ngram_size=self.ngram_size, min_len=self.min_len, mode=self.mode, max_position_embeddings=self.max_seq_length, ffn_type=self.ffn_type, num_qkv=self.num_qkv, seg_emb=self.seg_emb, pos_shift=self.pos_shift) 126 | #self.model.load_state_dict(model_recover) 127 | #del model_recover 128 | 129 | if self.fp16: 130 | self.model.half() 131 | self.model.to(self.device) 132 | if n_gpu > 1: 133 | self.model = torch.nn.DataParallel(self.model) 134 | 135 | torch.cuda.empty_cache() 136 | self.model.eval() 137 | 138 | def download_pretrained_model(self): 139 | if os.path.isfile(self.model_recover_path): 140 | print(f"{self.model_recover_path} found in current directory.") 141 | return 142 | s = requests.session() 143 | file_id = self.file_id 144 | r = s.get(f'https://docs.google.com/uc?export=download&id={file_id}') 145 | confirm_code = r.text.split("/uc?export=download&confirm=")[1].split("&id=")[0] 146 | r = s.get(f'https://docs.google.com/uc?export=download&confirm={confirm_code}&id={file_id}') 147 | z = zipfile.ZipFile(io.BytesIO(r.content)) 148 | z.extractall() 149 | 150 | def 
151 |         r = None
152 |         if s:
153 |             w_list = []
154 |             for w in s.split('|'):
155 |                 if w.startswith('[') and w.endswith(']'):
156 |                     w_list.append(w.upper())
157 |                 else:
158 |                     w_list.append(w)
159 |             r = set(self.tokenizer.convert_tokens_to_ids(w_list))
160 |         return r
161 | 
162 |     def _get_answer_tokens(self, tkns):
163 |         words = detokenize(tkns)
164 |         answers = []
165 |         for w in words:
166 |             if len(w) > 1:
167 |                 if w.lower() not in STOP_WORDS:
168 |                     answers.append(w)
169 |         return self.tokenizer.tokenize(random.choice(answers) if answers else words[0])
170 | 
171 |     def predict(self, input_lines, tokenized_input=False):
172 |         data_tokenizer = WhitespaceTokenizer() if tokenized_input else self.tokenizer
173 |         max_src_length = self.max_seq_length - 2 - self.max_tgt_length
174 |         input_lines = [data_tokenizer.tokenize(x)[:max_src_length] for x in input_lines]
175 | 
176 |         if self.output_type=="question":
177 |             input_lines = [x + ["[SEP]"] + self._get_answer_tokens(x) if "[SEP]" not in x else x for x in input_lines]
178 | 
179 |         input_lines = sorted(list(enumerate(input_lines)), key=lambda x: -len(x[1]))
180 |         output_lines = [""] * len(input_lines)
181 |         score_trace_list = [None] * len(input_lines)
182 |         total_batch = math.ceil(len(input_lines) / self.batch_size)
183 |         next_i = 0
184 |         with torch.no_grad():
185 |             while next_i < len(input_lines):
186 |                 _chunk = input_lines[next_i:next_i + self.batch_size]
187 |                 buf_id = [x[0] for x in _chunk]
188 |                 buf = [x[1] for x in _chunk]
189 |                 next_i += self.batch_size
190 |                 max_a_len = max([len(x) for x in buf])
191 |                 instances = []
192 |                 for instance in [(x, max_a_len) for x in buf]:
193 |                     for proc in self.bi_uni_pipeline:
194 |                         instances.append(proc(instance))
195 |                 with torch.no_grad():
196 |                     batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
197 |                     batch = [t.to(self.device) if t is not None else None for t in batch]
198 |                     input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
199 |                     traces = self.model(input_ids, token_type_ids,
200 |                                         position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv)
201 |                     if self.beam_size > 1:
202 |                         traces = {k: v.tolist() for k, v in traces.items()}
203 |                         output_ids = traces['pred_seq']
204 |                     else:
205 |                         output_ids = traces.tolist()
206 |                     for i in range(len(buf)):
207 |                         w_ids = output_ids[i]
208 |                         output_buf = self.tokenizer.convert_ids_to_tokens(w_ids)
209 |                         output_tokens = []
210 |                         for t in output_buf:
211 |                             if t in ("[SEP]", "[PAD]"):
212 |                                 break
213 |                             output_tokens.append(t)
214 |                         output_sequence = ' '.join(detokenize(output_tokens))
215 |                         output_sequence = output_sequence.replace(" ' ", "'").replace(" ?", "?")
216 |                         if self.output_type=="question":
217 |                             ans_idx = buf[i].index("[SEP]")
218 |                             corresponding_answer = ' '.join(detokenize(buf[i][ans_idx+1:]))
219 |                             output_lines[buf_id[i]] = (output_sequence, corresponding_answer)
220 |                         else:
221 |                             output_lines[buf_id[i]] = output_sequence
222 | 
223 |         return output_lines
224 | 
--------------------------------------------------------------------------------
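For orientation, a minimal usage sketch of the `TextGenerator` class above. This is illustrative only: it assumes the `question_generator` directory is on the Python path so that `text2text` is importable, and that the pretrained question-generation checkpoint/model directory is available where `__init__` expects it; the example passages are made up.

# Minimal sketch (not part of the repository) of driving TextGenerator for question generation.
from text2text.text_generator import TextGenerator

# Loads the tokenizer, the seq2seq decoding pipeline, and the decoder weights
# (model paths as configured in TextGenerator.__init__).
qg = TextGenerator(output_type="question")

passages = [
    # With an explicit answer: text after [SEP] is treated as the answer span.
    "Stockholm is the capital and largest city of Sweden. [SEP] Stockholm",
    # Without [SEP]: _get_answer_tokens picks a non-stop-word token from the passage.
    "WordPiece tokenization uses a greedy longest-match-first algorithm over a fixed vocabulary.",
]

# For output_type == "question", predict() returns (question, answer) pairs
# in the same order as the input list.
for question, answer in qg.predict(passages):
    print(question, "->", answer)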