├── plms ├── chinese_bert_wwm │ ├── added_tokens.json │ ├── tokenizer_config.json │ ├── special_tokens_map.json │ └── config.json ├── chinese_pert_base │ ├── added_tokens.json │ ├── tokenizer_config.json │ ├── special_tokens_map.json │ └── config.json ├── chinese_macbert_base │ ├── added_tokens.json │ ├── special_tokens_map.json │ └── config.json ├── chinese_bert_wwm_ext │ └── config.json ├── chinese_roberta_wwm_ext │ └── config.json └── PCL_MedBERT │ └── config.json ├── data └── README.md ├── LICENSE ├── README.md ├── data_loader_ir.py ├── predict.py ├── trainer.py ├── modeling_ir.py └── modeling_bert.py /plms/chinese_bert_wwm/added_tokens.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /plms/chinese_pert_base/added_tokens.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /plms/chinese_macbert_base/added_tokens.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /plms/chinese_bert_wwm/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"init_inputs": []} -------------------------------------------------------------------------------- /plms/chinese_pert_base/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"init_inputs": []} -------------------------------------------------------------------------------- /plms/chinese_bert_wwm/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"} -------------------------------------------------------------------------------- /plms/chinese_pert_base/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"} -------------------------------------------------------------------------------- /plms/chinese_macbert_base/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"} -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## 数据集 2 | 3 | 数据申请地址:https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414 4 | 5 | CBLUE挑战榜公开的3,052条数据包括: 6 | 7 | - 1,824条训练数据(IMCS_train.json) 8 | - 616条验证数据(IMCS_dev.json) 9 | - 612条测试数据(IMCS_test.json) 10 | 11 | 请将下载后的数据文件保存在当前路径下 12 | -------------------------------------------------------------------------------- /plms/chinese_bert_wwm/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "directionality": "bidi", 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 768, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 3072, 12 | "layer_norm_eps": 1e-12, 13 
| "max_position_embeddings": 512, 14 | "model_type": "bert", 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "output_past": true, 18 | "pad_token_id": 0, 19 | "pooler_fc_size": 768, 20 | "pooler_num_attention_heads": 12, 21 | "pooler_num_fc_layers": 3, 22 | "pooler_size_per_head": 128, 23 | "pooler_type": "first_token_transform", 24 | "type_vocab_size": 2, 25 | "vocab_size": 21128, 26 | "speaker_type_size": 2 27 | } 28 | -------------------------------------------------------------------------------- /plms/chinese_bert_wwm_ext/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "directionality": "bidi", 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 768, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 3072, 12 | "layer_norm_eps": 1e-12, 13 | "max_position_embeddings": 512, 14 | "model_type": "bert", 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "output_past": true, 18 | "pad_token_id": 0, 19 | "pooler_fc_size": 768, 20 | "pooler_num_attention_heads": 12, 21 | "pooler_num_fc_layers": 3, 22 | "pooler_size_per_head": 128, 23 | "pooler_type": "first_token_transform", 24 | "type_vocab_size": 2, 25 | "vocab_size": 21128, 26 | "speaker_type_size": 2 27 | } 28 | -------------------------------------------------------------------------------- /plms/chinese_macbert_base/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "directionality": "bidi", 7 | "gradient_checkpointing": false, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-12, 14 | "max_position_embeddings": 512, 15 | "model_type": "bert", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 0, 19 | "pooler_fc_size": 768, 20 | "pooler_num_attention_heads": 12, 21 | "pooler_num_fc_layers": 3, 22 | "pooler_size_per_head": 128, 23 | "pooler_type": "first_token_transform", 24 | "type_vocab_size": 2, 25 | "vocab_size": 21128, 26 | "speaker_type_size": 2 27 | } 28 | -------------------------------------------------------------------------------- /plms/chinese_roberta_wwm_ext/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "directionality": "bidi", 8 | "eos_token_id": 2, 9 | "hidden_act": "gelu", 10 | "hidden_dropout_prob": 0.1, 11 | "hidden_size": 768, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 3072, 14 | "layer_norm_eps": 1e-12, 15 | "max_position_embeddings": 512, 16 | "model_type": "bert", 17 | "num_attention_heads": 12, 18 | "num_hidden_layers": 12, 19 | "output_past": true, 20 | "pad_token_id": 1, 21 | "pooler_fc_size": 768, 22 | "pooler_num_attention_heads": 12, 23 | "pooler_num_fc_layers": 3, 24 | "pooler_size_per_head": 128, 25 | "pooler_type": "first_token_transform", 26 | "type_vocab_size": 2, 27 | "speaker_type_size": 2, 28 | "turn_type_size": 80, 29 | "vocab_size": 21128 30 | } 31 | -------------------------------------------------------------------------------- /plms/chinese_pert_base/config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "classifier_dropout": null, 7 | "directionality": "bidi", 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-12, 14 | "max_position_embeddings": 512, 15 | "model_type": "bert", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 0, 19 | "pooler_fc_size": 768, 20 | "pooler_num_attention_heads": 12, 21 | "pooler_num_fc_layers": 3, 22 | "pooler_size_per_head": 128, 23 | "pooler_type": "first_token_transform", 24 | "position_embedding_type": "absolute", 25 | "transformers_version": "4.16.2", 26 | "type_vocab_size": 2, 27 | "use_cache": true, 28 | "vocab_size": 21128, 29 | "speaker_type_size": 2 30 | } 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 winninghealth 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /plms/PCL_MedBERT/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": null, 7 | "directionality": "bidi", 8 | "do_sample": false, 9 | "eos_token_ids": null, 10 | "finetuning_task": null, 11 | "hidden_act": "gelu", 12 | "hidden_dropout_prob": 0.1, 13 | "hidden_size": 768, 14 | "id2label": { 15 | "0": "LABEL_0", 16 | "1": "LABEL_1" 17 | }, 18 | "initializer_range": 0.02, 19 | "intermediate_size": 3072, 20 | "is_decoder": false, 21 | "label2id": { 22 | "LABEL_0": 0, 23 | "LABEL_1": 1 24 | }, 25 | "layer_norm_eps": 1e-12, 26 | "length_penalty": 1.0, 27 | "max_length": 20, 28 | "max_position_embeddings": 512, 29 | "model_type": "bert", 30 | "num_attention_heads": 12, 31 | "num_beams": 1, 32 | "num_hidden_layers": 12, 33 | "num_labels": 2, 34 | "num_return_sequences": 1, 35 | "output_attentions": false, 36 | "output_hidden_states": false, 37 | "output_past": true, 38 | "pad_token_id": null, 39 | "pooler_fc_size": 768, 40 | "pooler_num_attention_heads": 12, 41 | "pooler_num_fc_layers": 3, 42 | "pooler_size_per_head": 128, 43 | "pooler_type": "first_token_transform", 44 | "pruned_heads": {}, 45 | "repetition_penalty": 1.0, 46 | "temperature": 1.0, 47 | "top_k": 50, 48 | "top_p": 1.0, 49 | "torchscript": false, 50 | "type_vocab_size": 2, 51 | "use_bfloat16": false, 52 | "vocab_size": 21128, 53 | "speaker_type_size": 2 54 | } 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## CBLUE 智能对话诊疗意图识别 IMCS-IR 2 | 3 | ### 任务背景 4 | 5 | 任务名称:智能对话诊疗数据集-对话意图识别(Intent Recognization) 6 | 7 | 任务简介:针对互联网医患在线对话问诊的记录,该任务的目标是识别出对话的意图。IMCS21数据集中标注了医患对话行为,共包含16类对话意图,标注方式采用句子级标注。任务采用Macro-F1值作为评价指标, 对于测试集中每份对话段落的每条句子,预测其对应的标签。 8 | 9 | 方案思路:https://zhuanlan.zhihu.com/p/501295857 10 | 11 | 方案结果:79.08%(Macro-F1) 12 | 13 | 相关比赛:第一届智能对话诊疗评测比赛(第二十届中国计算语言学大会 CCL2021) 14 | 15 | 比赛官网:http://www.fudan-disc.com/sharedtask/imcs21/index.html 16 | 17 | ### 数据集 18 | 19 | IMCS21数据集由复旦大学大数据学院在复旦大学医学院专家的指导下构建。本次评测任务使用的IMCS-IR数据集在中文医疗信息处理挑战榜CBLUE持续开放下载,地址:https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414 20 | 21 | CBLUE挑战榜公开的3,052条数据包括1,824条训练数据、616条验证数据和612条测试数据。请将下载后的数据保存在文件夹`data`中。 22 | 23 | ### 环境依赖 24 | 25 | - 主要基于 Python (3.7.3+) & AllenNLP 实现 26 | 27 | - 实验使用 GPU :GeForce GTX 1080Ti 28 | 29 | - Python 版本依赖: 30 | 31 | ``` 32 | torch==1.7.1+cu101 33 | transformers==4.4.2 34 | allennlp==2.4.0 35 | ``` 36 | 37 | ### 快速开始 38 | 39 | #### 预训练模型 40 | 41 | 实验中选择了6种不同的开源预训练模型: 42 | 43 | 1. chinese-bert-wwm,下载地址:https://huggingface.co/hfl/chinese-bert-wwm 44 | 2. chinese-bert-wwm-ext,下载地址:https://huggingface.co/hfl/chinese-bert-wwm-ext 45 | 3. chinese-macbert-base,下载地址:https://huggingface.co/hfl/chinese-macbert-base 46 | 4. chinese-roberta-wwm-ext,下载地址:https://huggingface.co/hfl/chinese-roberta-wwm-ext 47 | 5. chinese-pert-base,下载地址:https://huggingface.co/hfl/chinese-pert-base 48 | 6. 
PCL-MedBERT,下载地址:https://code.ihub.org.cn/projects/1775 49 | 50 | 请将下载后的模型权重`pytorch_model.bin`保存在`plms`路径下相应名称的模型文件夹中。 51 | 52 | #### 模型训练 53 | 54 | ```python 55 | python trainer.py --train_file ./data/IMCS_train.json --dev_file ./data/IMCS_dev.json --pretrained_model_dir ./plms/chinese_bert_wwm --output_model_dir ./save_model/chinese_bert_wwm --cuda_id cuda:0 --batch_size 1 --num_epochs 10 --patience 3 56 | ``` 57 | 58 | - 参数:{train_file}: 训练数据集路径,{dev_file}: 验证数据集路径,{pretrained_model_dir}: 预训练语言模型路径,{output_model_dir}: 模型保存路径 59 | 60 | #### 模型预测 61 | 62 | ```python 63 | python predict.py --test_input_file ./data/IMCS_test.json --test_output_file IMCS-IR_test.json --model_dir ./save_model/chinese_bert_wwm --pretrained_model_dir ./plms/chinese_bert_wwm --cuda_id cuda:0 64 | ``` 65 | 66 | - 参数:{test_input_file}: 测试数据集路径,{test_output_file}: 预测结果输出路径,{model_dir}: 加载已训练模型的路径,{pretrained_model_dir}: 预训练语言模型的路径 67 | 68 | ### 如何引用 69 | 70 | ``` 71 | @Misc{Jiang2022Shared, 72 | author={Yiwen Jiang}, 73 | title={Solutions of Intent Recognization Task within Online Medical Dialogues}, 74 | year={2022}, 75 | howpublished={GitHub}, 76 | url={https://github.com/winninghealth/imcs-ir}, 77 | } 78 | ``` 79 | 80 | ### 版权 81 | 82 | MIT License - 详见 [LICENSE](LICENSE) 83 | 84 | -------------------------------------------------------------------------------- /data_loader_ir.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @author: Yiwen Jiang @Winning Health Group 3 | 4 | import json 5 | import torch 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | from typing import Dict, List 10 | from overrides import overrides 11 | from transformers import BertTokenizer 12 | 13 | from allennlp.data.instance import Instance 14 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 15 | from allennlp.data.fields import Field, TensorField, LabelField, ListField 16 | 17 | SPECIAL_TOKENS = {'患者':'[unused1]', '医生':'[unused2]'} 18 | SPECIAL_LABELS = {'Other', 'Diagnose'} 19 | 20 | WINDOW = 30 # Decided by my GPU Restriction 21 | 22 | class IntentionRecognitionDatasetReader(DatasetReader): 23 | def __init__(self, transformer_load_path: str, **kwargs, 24 | ) -> None: 25 | super().__init__(**kwargs) 26 | self._transformer_indexers = BertTokenizer.from_pretrained(transformer_load_path) 27 | 28 | @overrides 29 | def _read(self, file_path): 30 | with open(file_path, "r", encoding='utf-8') as file: 31 | data_file = json.load(file) 32 | for eid in data_file.keys(): 33 | dialogue, speaker_ids, intentions, actions = [], [], [], [] 34 | for sid in data_file[eid]['dialogue']: 35 | speaker_ids.append(sid['speaker']) 36 | speaker = [SPECIAL_TOKENS[sid['speaker']]] 37 | utterance = list(sid['sentence']) 38 | utterance = speaker + utterance 39 | dialogue.append(utterance) 40 | if sid['dialogue_act'] not in SPECIAL_LABELS: 41 | intention, action = sid['dialogue_act'].split('-') 42 | else: 43 | intention = sid['dialogue_act'] 44 | action = sid['dialogue_act'] 45 | intentions.append(intention) 46 | actions.append(action) 47 | # If you have sufficient GPU Memory, Put Whole Dialogue in will be better. 
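# Worked illustration of the sliding window below (comment added for clarity,
# not part of the original logic): with WINDOW = 30, a dialogue of 75 utterances
# is emitted as three instances covering utterances [0:30], [30:60] and [60:75].
# The trailing chunk is simply shorter, and the break guard stops the loop once
# the end of the dialogue has been reached.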
48 | for i in range(0, len(dialogue), WINDOW): 49 | y = i + WINDOW 50 | yield self.text_to_instance(dialogue[i:y], 51 | speaker_ids[i:y], 52 | intentions[i:y], 53 | actions[i:y]) 54 | if y >= len(dialogue): 55 | break 56 | 57 | def text_to_instance( 58 | self, 59 | dialogue: List[List[str]], 60 | speaker_ids: List[str], 61 | intentions: List[str] = None, 62 | actions: List[str] = None, 63 | ) -> Instance: 64 | fields: Dict[str, Field] = {} 65 | dialogue = [['[CLS]'] + utterance + ['[SEP]'] for utterance in dialogue] 66 | dialogue_field = [self._transformer_indexers.convert_tokens_to_ids(utterance) for utterance in dialogue] 67 | dialogue_field = [TensorField(torch.tensor(u)) for u in dialogue_field] 68 | fields["dialogue"] = ListField(dialogue_field) 69 | speaker_field = [LabelField(speaker, label_namespace='speaker_labels') for speaker in speaker_ids] 70 | fields["speaker"] = ListField(speaker_field) 71 | if intentions != None: 72 | intents_field = [LabelField(intention, label_namespace='intention_labels') for intention in intentions] 73 | fields["intentions"] = ListField(intents_field) 74 | if actions != None: 75 | actions_field = [LabelField(action, label_namespace='action_labels') for action in actions] 76 | fields["actions"] = ListField(actions_field) 77 | return Instance(fields) 78 | 79 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @author: Yiwen Jiang @Winning Health Group 3 | 4 | import os 5 | import json 6 | import torch 7 | import logging 8 | import argparse 9 | 10 | from tqdm import tqdm 11 | from overrides import overrides 12 | from allennlp.common.util import JsonDict 13 | from allennlp.data import DatasetReader, Instance, Vocabulary 14 | from allennlp.models import Model 15 | from allennlp.predictors.predictor import Predictor 16 | 17 | from trainer import build_model 18 | from transformers import BertTokenizer 19 | from data_loader_ir import IntentionRecognitionDatasetReader, SPECIAL_TOKENS, SPECIAL_LABELS 20 | 21 | def init_logger(): 22 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 23 | datefmt='%m/%d/%Y %H:%M:%S', 24 | level=logging.INFO) 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | class IRPredictor(Predictor): 29 | def __init__(self, 30 | model: Model, 31 | dataset_reader: DatasetReader, 32 | transformer_load_path : str 33 | ) -> None: 34 | super().__init__(model, dataset_reader) 35 | self.vocab = model.vocab 36 | self._transformer_indexers = BertTokenizer.from_pretrained(transformer_load_path) 37 | 38 | def predict(self, dialogue, speaker_ids) -> JsonDict: 39 | result = self.predict_json({"dialogue": dialogue, "speaker": speaker_ids}) 40 | instances = dict() 41 | instances['actions'] = [self.vocab.get_token_from_index(i,namespace='action_labels') for i in result['actions']] 42 | instances['intentions'] = [self.vocab.get_token_from_index(i,namespace='intention_labels') for i in result['intentions']] 43 | return instances 44 | 45 | @overrides 46 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 47 | dialogue = json_dict["dialogue"] 48 | speaker_ids = json_dict["speaker"] 49 | return self._dataset_reader.text_to_instance(dialogue, speaker_ids) 50 | 51 | def read_input_file(input_path): 52 | eids, dialogues, speaker_ids = [], [], [] 53 | with open(input_path, 'r', encoding='utf-8') as f: 54 | data = json.load(f) 55 | for k in data.keys(): 56 | 
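# Comment added for clarity: each key k is a dialogue id. The loop rebuilds the
# same inputs used at training time by prepending the speaker special token
# ([unused1] for 患者, [unused2] for 医生) to every character-tokenized sentence,
# while keeping the raw speaker tag in a parallel list.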
dialogue, speaker_id = [], [] 57 | for sent in data[k]['dialogue']: 58 | speaker = [SPECIAL_TOKENS[sent['speaker']]] 59 | utterance = list(sent['sentence']) 60 | utterance = speaker + utterance 61 | dialogue.append(utterance) 62 | speaker_id.append(sent['speaker']) 63 | eids.append(k) 64 | dialogues.append(dialogue) 65 | speaker_ids.append(speaker_id) 66 | return eids, dialogues, speaker_ids 67 | 68 | def predict(pred_config): 69 | serialization_dir = pred_config.model_dir 70 | vocabulary_dir = os.path.join(serialization_dir, "vocabulary") 71 | vocab = Vocabulary.from_files(vocabulary_dir) 72 | 73 | model_dir = os.path.join(serialization_dir, pred_config.model_name) 74 | model = build_model(vocab, pred_config.pretrained_model_dir, pred_config.pretrained_hidden_size) 75 | device = torch.device(pred_config.cuda_id if torch.cuda.is_available() else "cpu") 76 | model.load_state_dict(torch.load(model_dir, map_location=device)) 77 | model = model.to(device) 78 | 79 | dataset_reader = IntentionRecognitionDatasetReader(transformer_load_path=pred_config.pretrained_model_dir) 80 | predictor = IRPredictor(model=model, 81 | dataset_reader=dataset_reader, 82 | transformer_load_path=pred_config.pretrained_model_dir) 83 | 84 | eids, dialogues, speaker_ids = read_input_file(os.path.join(pred_config.test_input_file)) 85 | predict_result = dict() 86 | predict_subres = dict() 87 | for i in tqdm(range(0, len(eids))): 88 | predict_result[eids[i]] = dict() 89 | predict_subres[eids[i]] = dict() 90 | result = predictor.predict(dialogues[i], speaker_ids[i]) 91 | assert len(result['actions']) == len(result['intentions']) 92 | for idx, j in enumerate(zip(result['actions'], result['intentions'])): 93 | _act, _int = j 94 | if _act not in SPECIAL_LABELS and _int not in SPECIAL_LABELS: 95 | res = _int + '-' + _act 96 | else: 97 | res = _act if _act in SPECIAL_LABELS else _int 98 | # Restore BAD index from original test file 99 | if eids[i] == '10708561' and int(idx+1) >= 43: 100 | predict_result[eids[i]][str(idx+2)] = res 101 | predict_subres[eids[i]][str(idx+2)] = dict() 102 | predict_subres[eids[i]][str(idx+2)]['act'] = _act 103 | predict_subres[eids[i]][str(idx+2)]['int'] = _int 104 | else: 105 | predict_result[eids[i]][str(idx+1)] = res 106 | predict_subres[eids[i]][str(idx+1)] = dict() 107 | predict_subres[eids[i]][str(idx+1)]['act'] = _act 108 | predict_subres[eids[i]][str(idx+1)]['int'] = _int 109 | pred_path = os.path.join(pred_config.test_output_file) 110 | with open(pred_path, 'w', encoding='utf-8') as json_file: 111 | json.dump(predict_result, json_file, ensure_ascii=False, indent=4) 112 | pred_path_sub = os.path.join(pred_config.test_output_file + '.sub') 113 | with open(pred_path_sub, 'w', encoding='utf-8') as json_file: 114 | json.dump(predict_subres, json_file, ensure_ascii=False, indent=4) 115 | logger.info("Prediction Done!") 116 | 117 | if __name__ == "__main__": 118 | init_logger() 119 | parser = argparse.ArgumentParser() 120 | 121 | parser.add_argument("--test_input_file", default="./data/IMCS_test.json", type=str) 122 | parser.add_argument("--test_output_file", default="IMCS-IR_test.json", type=str) 123 | parser.add_argument("--model_dir", default="./save_model", type=str) 124 | parser.add_argument("--model_name", default="best.th", type=str) 125 | parser.add_argument("--pretrained_model_dir", default="./plms", type=str) 126 | parser.add_argument("--pretrained_hidden_size", default=768, type=int) 127 | parser.add_argument("--cuda_id", default='cuda:0', type=str) 128 | 129 | pred_config = 
parser.parse_args() 130 | predict(pred_config) 131 | 132 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @author: Yiwen Jiang @Winning Health Group 3 | 4 | from allennlp.common.params import Params 5 | from allennlp.common.util import prepare_environment 6 | prepare_environment(Params({"random_seed":1000, "numpy_seed":2000, "pytorch_seed":3000})) 7 | 8 | import os 9 | import torch 10 | import logging 11 | import argparse 12 | 13 | from allennlp.models.model import Model 14 | from allennlp.data import DataLoader, Vocabulary 15 | from allennlp.training.checkpointer import Checkpointer 16 | from allennlp.training.trainer import GradientDescentTrainer, Trainer 17 | from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader 18 | from allennlp.modules.seq2seq_encoders.pytorch_seq2seq_wrapper import LstmSeq2SeqEncoder 19 | from allennlp.training.learning_rate_schedulers.linear_with_warmup import LinearWithWarmup 20 | 21 | from transformers.optimization import AdamW 22 | 23 | from modeling_ir import IntentionLabelTagger 24 | from data_loader_ir import IntentionRecognitionDatasetReader 25 | 26 | 27 | def init_logger(): 28 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 29 | datefmt='%m/%d/%Y %H:%M:%S', 30 | level=logging.INFO) 31 | 32 | def build_vocab(instances): 33 | return Vocabulary.from_instances(instances) 34 | 35 | def build_model(vocab: Vocabulary, 36 | transformer_load_path: str, 37 | pretrained_hidden_size: int) -> Model: 38 | lstmencoder = LstmSeq2SeqEncoder(input_size=pretrained_hidden_size, 39 | hidden_size=128, 40 | num_layers=1, 41 | bidirectional=True) 42 | return IntentionLabelTagger(vocab=vocab, 43 | dialogue_encoder=lstmencoder, 44 | transformer_load_path=transformer_load_path, 45 | dropout=0.1) 46 | 47 | def build_trainer(model: Model, 48 | train_loader: DataLoader, 49 | dev_loader: DataLoader, 50 | serialization_dir: str, 51 | cuda_device: torch.device, 52 | num_epochs: int, 53 | patience: int) -> Trainer: 54 | 55 | no_bigger = ["dialogue_encoder", "crf_act", "crf_int", 56 | "act_decoder", "intent_decoder"] 57 | 58 | parameter_groups = [ 59 | { 60 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_bigger)], 61 | "weight_decay": 0.0, 62 | }, 63 | { 64 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_bigger)], 65 | "lr": 0.0001 66 | } 67 | ] 68 | optimizer = AdamW(parameter_groups, lr=1e-5, eps=1e-8) 69 | lrschedule = LinearWithWarmup(optimizer=optimizer, 70 | num_epochs=num_epochs, 71 | num_steps_per_epoch=len(train_loader), 72 | warmup_steps=150) 73 | 74 | ckp = Checkpointer(serialization_dir=serialization_dir, 75 | num_serialized_models_to_keep=-1) 76 | 77 | trainer = GradientDescentTrainer(model=model, 78 | optimizer=optimizer, 79 | data_loader=train_loader, 80 | patience=patience, 81 | validation_data_loader=dev_loader, 82 | validation_metric='+Macro-F1', 83 | num_epochs=num_epochs, 84 | serialization_dir=serialization_dir, 85 | cuda_device=cuda_device if str(cuda_device) != 'cpu' else -1, 86 | learning_rate_scheduler=lrschedule, 87 | num_gradient_accumulation_steps=1, 88 | checkpointer=ckp) 89 | 90 | return trainer 91 | 92 | def run_training_loop(config): 93 | 94 | serialization_dir = config.output_model_dir 95 | vocabulary_dir = os.path.join(serialization_dir, "vocabulary") 96 | 
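# Comment added for clarity: the vocabulary built from the train and dev
# instances is saved into this sub-directory so that predict.py can restore the
# identical label namespaces via Vocabulary.from_files before loading best.th.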
os.makedirs(serialization_dir, exist_ok=True) 97 | 98 | dataset_reader = IntentionRecognitionDatasetReader(transformer_load_path=config.pretrained_model_dir) 99 | train_path = config.train_file 100 | dev_path = config.dev_file 101 | train_data = list(dataset_reader.read(train_path)) 102 | dev_data = list(dataset_reader.read(dev_path)) 103 | vocab = build_vocab(train_data + dev_data) 104 | vocab.save_to_files(vocabulary_dir) 105 | 106 | train_loader = MultiProcessDataLoader(dataset_reader, train_path, batch_size=config.batch_size, shuffle=True) 107 | dev_loader = MultiProcessDataLoader(dataset_reader, dev_path, batch_size=config.batch_size, shuffle=False) 108 | train_loader.index_with(vocab) 109 | dev_loader.index_with(vocab) 110 | 111 | device = torch.device(config.cuda_id if torch.cuda.is_available() else "cpu") 112 | model = build_model(vocab, config.pretrained_model_dir, config.pretrained_hidden_size) 113 | model = model.to(device) 114 | 115 | trainer = build_trainer(model, 116 | train_loader, 117 | dev_loader, 118 | serialization_dir, 119 | device, 120 | config.num_epochs, 121 | config.patience) 122 | trainer.train() 123 | return trainer 124 | 125 | if __name__ == '__main__': 126 | init_logger() 127 | parser = argparse.ArgumentParser() 128 | 129 | parser.add_argument("--train_file", default='./data/IMCS_train.json', type=str) 130 | parser.add_argument("--dev_file", default='./data/IMCS_dev.json', type=str) 131 | parser.add_argument("--output_model_dir", default='./save_model', type=str) 132 | parser.add_argument("--pretrained_model_dir", default='./plms', type=str) 133 | parser.add_argument("--pretrained_hidden_size", default=768, type=int) 134 | parser.add_argument("--cuda_id", default='cuda:0', type=str) 135 | 136 | parser.add_argument("--batch_size", default=1, type=int) 137 | parser.add_argument("--num_epochs", default=10, type=int) 138 | parser.add_argument("--patience", default=3, type=int) 139 | 140 | config = parser.parse_args() 141 | run_training_loop(config) 142 | 143 | -------------------------------------------------------------------------------- /modeling_ir.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @author: Yiwen Jiang @Winning Health Group 3 | 4 | import torch 5 | import torch.nn as nn 6 | from overrides import overrides 7 | from modeling_bert import BertModel 8 | from typing import Dict, Optional, cast, List 9 | 10 | from allennlp.data import Vocabulary 11 | from allennlp.models.model import Model 12 | from allennlp.nn import InitializerApplicator 13 | from allennlp.training.metrics import FBetaMeasure 14 | from allennlp.modules import Seq2SeqEncoder, ConditionalRandomField 15 | 16 | class ClassificationHead(nn.Module): 17 | def __init__( 18 | self, 19 | input_dim: int, 20 | inner_dim: int, 21 | num_classes: int, 22 | pooler_dropout: float, 23 | ): 24 | super().__init__() 25 | self.dense = nn.Linear(input_dim, inner_dim) 26 | self.dropout = nn.Dropout(p=pooler_dropout) 27 | self.out_proj = nn.Linear(inner_dim, num_classes) 28 | 29 | def forward(self, hidden_states: torch.Tensor): 30 | hidden_states = self.dropout(hidden_states) 31 | hidden_states = self.dense(hidden_states) 32 | hidden_states = torch.tanh(hidden_states) 33 | hidden_states = self.dropout(hidden_states) 34 | hidden_states = self.out_proj(hidden_states) 35 | return hidden_states 36 | 37 | class IntentionLabelTagger(Model): 38 | def __init__( 39 | self, 40 | vocab: Vocabulary, 41 | transformer_load_path: str, 42 | 
dialogue_encoder: Seq2SeqEncoder, 43 | dropout: Optional[float] = None, 44 | initializer: InitializerApplicator = InitializerApplicator(), 45 | **kwargs, 46 | ) -> None: 47 | super().__init__(vocab, **kwargs) 48 | self.utterance_encoder = BertModel.from_pretrained(transformer_load_path) 49 | self.dialogue_encoder = dialogue_encoder 50 | self.act_decoder = ClassificationHead(input_dim=self.dialogue_encoder.get_output_dim(), 51 | inner_dim=self.dialogue_encoder.get_output_dim(), 52 | num_classes=self.vocab.get_vocab_size('action_labels'), 53 | pooler_dropout=0.3) 54 | self.intent_decoder = ClassificationHead(input_dim=self.dialogue_encoder.get_output_dim(), 55 | inner_dim=self.dialogue_encoder.get_output_dim(), 56 | num_classes=self.vocab.get_vocab_size('intention_labels'), 57 | pooler_dropout=0.3) 58 | self.speaker_embeds = self.utterance_encoder.embeddings.speaker_embeddings 59 | self.dropout = torch.nn.Dropout(dropout) if dropout else None 60 | self.calculate_f1_act = { 61 | "F1-macro-act": FBetaMeasure(average='macro'), 62 | "F1-class-act": FBetaMeasure(average=None) 63 | } 64 | self.calculate_f1_int = { 65 | "F1-macro-int": FBetaMeasure(average='macro'), 66 | "F1-class-int": FBetaMeasure(average=None) 67 | } 68 | self.crf_act = ConditionalRandomField(self.vocab.get_vocab_size('action_labels')) 69 | self.crf_int = ConditionalRandomField(self.vocab.get_vocab_size('intention_labels')) 70 | initializer(self) 71 | 72 | @overrides 73 | def forward(self, dialogue, speaker, intentions = None, actions = None, **kwargs,): 74 | batch_size, utter_len, seq_len = dialogue.shape 75 | dialogue = dialogue.reshape(batch_size * utter_len, seq_len) 76 | ''' 77 | utterance feature 78 | ''' 79 | speaker_ids = torch.repeat_interleave(speaker, seq_len, -1) 80 | speaker_ids = speaker_ids.reshape(batch_size * utter_len, seq_len) 81 | encoded_utterance = self.utterance_encoder(input_ids=dialogue, 82 | attention_mask=dialogue != 0, 83 | speaker_ids=torch.clamp(speaker_ids,min=0), 84 | use_cache=True, 85 | return_dict=True)['last_hidden_state'] 86 | encoded_utterance = encoded_utterance.reshape(batch_size, utter_len, seq_len, -1) 87 | encoded_utterance = encoded_utterance[:,:,0,:] 88 | encoded_utterance = self.dropout(encoded_utterance) if self.dropout else encoded_utterance 89 | ''' 90 | dialogue feature 91 | ''' 92 | speaker_embeds = self.speaker_embeds(speaker) 93 | encoded_utterance = encoded_utterance + speaker_embeds 94 | encoded_dialogue = self.dialogue_encoder(encoded_utterance, None) 95 | encoded_dialogue = self.dropout(encoded_dialogue) if self.dropout else encoded_dialogue 96 | # [batch_size, utterance_number, utterance_embedding_size] 97 | ''' 98 | decoder 99 | ''' 100 | encoded_dialogue_act = self.act_decoder(encoded_dialogue) 101 | encoded_dialogue_int = self.intent_decoder(encoded_dialogue) 102 | ''' 103 | metric 104 | ''' 105 | output = dict() 106 | ''' 107 | actions metric 108 | ''' 109 | labels_mask = speaker != -1 110 | best_paths_act = self.crf_act.viterbi_tags(encoded_dialogue_act,labels_mask,top_k=1) 111 | predicted_acts = cast(List[List[int]], [x[0][0] for x in best_paths_act]) 112 | output['actions'] = predicted_acts 113 | if actions != None: 114 | class_probabilities = encoded_dialogue_act * 0.0 115 | for i, instance_tags in enumerate(predicted_acts): 116 | for j, tag_id in enumerate(instance_tags): 117 | class_probabilities[i, j, tag_id] = 1 118 | self.calculate_f1_act['F1-macro-act'](class_probabilities, actions, labels_mask) 119 | 
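# Comment added for clarity: the per-class measure is accumulated alongside the
# macro average so that get_metrics() can report one F1 score per action label.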
self.calculate_f1_act['F1-class-act'](class_probabilities, actions, labels_mask) 120 | ''' 121 | intentions metric 122 | ''' 123 | best_paths_int = self.crf_int.viterbi_tags(encoded_dialogue_int,labels_mask,top_k=1) 124 | predicted_ints = cast(List[List[int]], [x[0][0] for x in best_paths_int]) 125 | output['intentions'] = predicted_ints 126 | if intentions != None: 127 | class_probabilities = encoded_dialogue_int * 0.0 128 | for i, instance_tags in enumerate(predicted_ints): 129 | for j, tag_id in enumerate(instance_tags): 130 | class_probabilities[i, j, tag_id] = 1 131 | self.calculate_f1_int['F1-macro-int'](class_probabilities, intentions, labels_mask) 132 | self.calculate_f1_int['F1-class-int'](class_probabilities, intentions, labels_mask) 133 | ''' 134 | loss 135 | ''' 136 | if actions != None and intentions != None: 137 | log_likelihood_act = self.crf_act(encoded_dialogue_act, actions, labels_mask) 138 | log_likelihood_int = self.crf_int(encoded_dialogue_int, intentions, labels_mask) 139 | output["loss"] = (-log_likelihood_act) + (-log_likelihood_int) 140 | return output 141 | 142 | @overrides 143 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 144 | metrics_to_return = dict() 145 | ''' 146 | actions metric 147 | ''' 148 | act_macro = self.calculate_f1_act['F1-macro-act'].get_metric(reset) 149 | act_class = self.calculate_f1_act['F1-class-act'].get_metric(reset) 150 | metrics_to_return['Macro-act-f1'] = act_macro['fscore'] 151 | idx2label = self.vocab.get_index_to_token_vocabulary(namespace='action_labels') 152 | for idx in range(len(act_class['fscore'])): 153 | lc= idx2label[idx] 154 | metrics_to_return[lc+'-act-f1'] = act_class['fscore'][idx] 155 | ''' 156 | intentions metric 157 | ''' 158 | int_macro = self.calculate_f1_int['F1-macro-int'].get_metric(reset) 159 | int_class = self.calculate_f1_int['F1-class-int'].get_metric(reset) 160 | metrics_to_return['Macro-int-f1'] = int_macro['fscore'] 161 | idx2label = self.vocab.get_index_to_token_vocabulary(namespace='intention_labels') 162 | for idx in range(len(int_class['fscore'])): 163 | lc= idx2label[idx] 164 | metrics_to_return[lc+'-int-f1'] = int_class['fscore'][idx] 165 | ''' 166 | average 167 | ''' 168 | metrics_to_return['Macro-F1'] = (act_macro['fscore'] + int_macro['fscore']) / 2 169 | return metrics_to_return 170 | 171 | -------------------------------------------------------------------------------- /modeling_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model. 
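This copy extends the reference HuggingFace implementation with speaker-type
embeddings: BertEmbeddings gains a `speaker_embeddings` table sized by the
`speaker_type_size` entry added to the configs under `plms/`, and its `forward`
accepts the corresponding `speaker_ids` argument passed in by modeling_ir.py.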
""" 17 | 18 | 19 | import math 20 | import os 21 | import warnings 22 | from dataclasses import dataclass 23 | from typing import Optional, Tuple 24 | 25 | import torch 26 | import torch.utils.checkpoint 27 | from torch import nn 28 | from torch.nn import CrossEntropyLoss, MSELoss 29 | 30 | from transformers.activations import ACT2FN 31 | from transformers.file_utils import ( 32 | ModelOutput, 33 | add_code_sample_docstrings, 34 | add_start_docstrings, 35 | add_start_docstrings_to_model_forward, 36 | replace_return_docstrings, 37 | ) 38 | from transformers.modeling_outputs import ( 39 | BaseModelOutputWithPastAndCrossAttentions, 40 | BaseModelOutputWithPoolingAndCrossAttentions, 41 | CausalLMOutputWithCrossAttentions, 42 | MaskedLMOutput, 43 | MultipleChoiceModelOutput, 44 | NextSentencePredictorOutput, 45 | QuestionAnsweringModelOutput, 46 | SequenceClassifierOutput, 47 | TokenClassifierOutput, 48 | ) 49 | from transformers.modeling_utils import ( 50 | PreTrainedModel, 51 | apply_chunking_to_forward, 52 | find_pruneable_heads_and_indices, 53 | prune_linear_layer, 54 | ) 55 | from transformers.utils import logging 56 | from transformers.models.bert.configuration_bert import BertConfig 57 | 58 | 59 | logger = logging.get_logger(__name__) 60 | 61 | _CHECKPOINT_FOR_DOC = "bert-base-uncased" 62 | _CONFIG_FOR_DOC = "BertConfig" 63 | _TOKENIZER_FOR_DOC = "BertTokenizer" 64 | 65 | BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ 66 | "bert-base-uncased", 67 | "bert-large-uncased", 68 | "bert-base-cased", 69 | "bert-large-cased", 70 | "bert-base-multilingual-uncased", 71 | "bert-base-multilingual-cased", 72 | "bert-base-chinese", 73 | "bert-base-german-cased", 74 | "bert-large-uncased-whole-word-masking", 75 | "bert-large-cased-whole-word-masking", 76 | "bert-large-uncased-whole-word-masking-finetuned-squad", 77 | "bert-large-cased-whole-word-masking-finetuned-squad", 78 | "bert-base-cased-finetuned-mrpc", 79 | "bert-base-german-dbmdz-cased", 80 | "bert-base-german-dbmdz-uncased", 81 | "cl-tohoku/bert-base-japanese", 82 | "cl-tohoku/bert-base-japanese-whole-word-masking", 83 | "cl-tohoku/bert-base-japanese-char", 84 | "cl-tohoku/bert-base-japanese-char-whole-word-masking", 85 | "TurkuNLP/bert-base-finnish-cased-v1", 86 | "TurkuNLP/bert-base-finnish-uncased-v1", 87 | "wietsedv/bert-base-dutch-cased", 88 | # See all BERT models at https://huggingface.co/models?filter=bert 89 | ] 90 | 91 | 92 | def load_tf_weights_in_bert(model, config, tf_checkpoint_path): 93 | """Load tf checkpoints in a pytorch model.""" 94 | try: 95 | import re 96 | 97 | import numpy as np 98 | import tensorflow as tf 99 | except ImportError: 100 | logger.error( 101 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 102 | "https://www.tensorflow.org/install/ for installation instructions." 
103 | ) 104 | raise 105 | tf_path = os.path.abspath(tf_checkpoint_path) 106 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 107 | # Load weights from TF model 108 | init_vars = tf.train.list_variables(tf_path) 109 | names = [] 110 | arrays = [] 111 | for name, shape in init_vars: 112 | logger.info(f"Loading TF weight {name} with shape {shape}") 113 | array = tf.train.load_variable(tf_path, name) 114 | names.append(name) 115 | arrays.append(array) 116 | 117 | for name, array in zip(names, arrays): 118 | name = name.split("/") 119 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 120 | # which are not required for using pretrained model 121 | if any( 122 | n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] 123 | for n in name 124 | ): 125 | logger.info(f"Skipping {'/'.join(name)}") 126 | continue 127 | pointer = model 128 | for m_name in name: 129 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name): 130 | scope_names = re.split(r"_(\d+)", m_name) 131 | else: 132 | scope_names = [m_name] 133 | if scope_names[0] == "kernel" or scope_names[0] == "gamma": 134 | pointer = getattr(pointer, "weight") 135 | elif scope_names[0] == "output_bias" or scope_names[0] == "beta": 136 | pointer = getattr(pointer, "bias") 137 | elif scope_names[0] == "output_weights": 138 | pointer = getattr(pointer, "weight") 139 | elif scope_names[0] == "squad": 140 | pointer = getattr(pointer, "classifier") 141 | else: 142 | try: 143 | pointer = getattr(pointer, scope_names[0]) 144 | except AttributeError: 145 | logger.info(f"Skipping {'/'.join(name)}") 146 | continue 147 | if len(scope_names) >= 2: 148 | num = int(scope_names[1]) 149 | pointer = pointer[num] 150 | if m_name[-11:] == "_embeddings": 151 | pointer = getattr(pointer, "weight") 152 | elif m_name == "kernel": 153 | array = np.transpose(array) 154 | try: 155 | assert ( 156 | pointer.shape == array.shape 157 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 158 | except AssertionError as e: 159 | e.args += (pointer.shape, array.shape) 160 | raise 161 | logger.info(f"Initialize PyTorch weight {name}") 162 | pointer.data = torch.from_numpy(array) 163 | return model 164 | 165 | 166 | class BertEmbeddings(nn.Module): 167 | """Construct the embeddings from word, position and token_type embeddings.""" 168 | 169 | def __init__(self, config): 170 | super().__init__() 171 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 172 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 173 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 174 | self.speaker_embeddings = nn.Embedding(config.speaker_type_size, config.hidden_size) 175 | 176 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 177 | # any TensorFlow checkpoint file 178 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 179 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 180 | 181 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 182 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 183 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 184 | 185 | def forward( 186 | self, input_ids=None, token_type_ids=None, position_ids=None, speaker_ids=None, 
inputs_embeds=None, past_key_values_length=0 187 | ): 188 | if input_ids is not None: 189 | input_shape = input_ids.size() 190 | else: 191 | input_shape = inputs_embeds.size()[:-1] 192 | 193 | seq_length = input_shape[1] 194 | 195 | if position_ids is None: 196 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] 197 | 198 | if token_type_ids is None: 199 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 200 | 201 | if inputs_embeds is None: 202 | inputs_embeds = self.word_embeddings(input_ids) 203 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 204 | 205 | speaker_type_embeddings = self.speaker_embeddings(speaker_ids) 206 | 207 | embeddings = inputs_embeds + token_type_embeddings 208 | if self.position_embedding_type == "absolute": 209 | position_embeddings = self.position_embeddings(position_ids) 210 | embeddings += position_embeddings 211 | embeddings += speaker_type_embeddings 212 | embeddings = self.LayerNorm(embeddings) 213 | embeddings = self.dropout(embeddings) 214 | return embeddings 215 | 216 | 217 | class BertSelfAttention(nn.Module): 218 | def __init__(self, config): 219 | super().__init__() 220 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 221 | raise ValueError( 222 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 223 | f"heads ({config.num_attention_heads})" 224 | ) 225 | 226 | self.num_attention_heads = config.num_attention_heads 227 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 228 | self.all_head_size = self.num_attention_heads * self.attention_head_size 229 | 230 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 231 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 232 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 233 | 234 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 235 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 236 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 237 | self.max_position_embeddings = config.max_position_embeddings 238 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 239 | 240 | self.is_decoder = config.is_decoder 241 | 242 | def transpose_for_scores(self, x): 243 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 244 | x = x.view(*new_x_shape) 245 | return x.permute(0, 2, 1, 3) 246 | 247 | def forward( 248 | self, 249 | hidden_states, 250 | attention_mask=None, 251 | head_mask=None, 252 | encoder_hidden_states=None, 253 | encoder_attention_mask=None, 254 | past_key_value=None, 255 | output_attentions=False, 256 | ): 257 | mixed_query_layer = self.query(hidden_states) 258 | 259 | # If this is instantiated as a cross-attention module, the keys 260 | # and values come from an encoder; the attention mask needs to be 261 | # such that the encoder's padding tokens are not attended to. 
262 | is_cross_attention = encoder_hidden_states is not None 263 | 264 | if is_cross_attention and past_key_value is not None: 265 | # reuse k,v, cross_attentions 266 | key_layer = past_key_value[0] 267 | value_layer = past_key_value[1] 268 | attention_mask = encoder_attention_mask 269 | elif is_cross_attention: 270 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 271 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 272 | attention_mask = encoder_attention_mask 273 | elif past_key_value is not None: 274 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 275 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 276 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 277 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 278 | else: 279 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 280 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 281 | 282 | query_layer = self.transpose_for_scores(mixed_query_layer) 283 | 284 | if self.is_decoder: 285 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 286 | # Further calls to cross_attention layer can then reuse all cross-attention 287 | # key/value_states (first "if" case) 288 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 289 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 290 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 291 | # if encoder bi-directional self-attention `past_key_value` is always `None` 292 | past_key_value = (key_layer, value_layer) 293 | 294 | # Take the dot product between "query" and "key" to get the raw attention scores. 295 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 296 | 297 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 298 | seq_length = hidden_states.size()[1] 299 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 300 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 301 | distance = position_ids_l - position_ids_r 302 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 303 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 304 | 305 | if self.position_embedding_type == "relative_key": 306 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 307 | attention_scores = attention_scores + relative_position_scores 308 | elif self.position_embedding_type == "relative_key_query": 309 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 310 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 311 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 312 | 313 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 314 | if attention_mask is not None: 315 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 316 | attention_scores = attention_scores + attention_mask 317 | 318 | # Normalize the attention scores to probabilities. 
319 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 320 | 321 | # This is actually dropping out entire tokens to attend to, which might 322 | # seem a bit unusual, but is taken from the original Transformer paper. 323 | attention_probs = self.dropout(attention_probs) 324 | 325 | # Mask heads if we want to 326 | if head_mask is not None: 327 | attention_probs = attention_probs * head_mask 328 | 329 | context_layer = torch.matmul(attention_probs.float(), value_layer) 330 | 331 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 332 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 333 | context_layer = context_layer.view(*new_context_layer_shape) 334 | 335 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 336 | 337 | if self.is_decoder: 338 | outputs = outputs + (past_key_value,) 339 | return outputs 340 | 341 | 342 | class BertSelfOutput(nn.Module): 343 | def __init__(self, config): 344 | super().__init__() 345 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 346 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 347 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 348 | 349 | def forward(self, hidden_states, input_tensor): 350 | hidden_states = self.dense(hidden_states) 351 | hidden_states = self.dropout(hidden_states) 352 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 353 | return hidden_states 354 | 355 | 356 | class BertAttention(nn.Module): 357 | def __init__(self, config): 358 | super().__init__() 359 | self.self = BertSelfAttention(config) 360 | self.output = BertSelfOutput(config) 361 | self.pruned_heads = set() 362 | 363 | def prune_heads(self, heads): 364 | if len(heads) == 0: 365 | return 366 | heads, index = find_pruneable_heads_and_indices( 367 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 368 | ) 369 | 370 | # Prune linear layers 371 | self.self.query = prune_linear_layer(self.self.query, index) 372 | self.self.key = prune_linear_layer(self.self.key, index) 373 | self.self.value = prune_linear_layer(self.self.value, index) 374 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 375 | 376 | # Update hyper params and store pruned heads 377 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 378 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 379 | self.pruned_heads = self.pruned_heads.union(heads) 380 | 381 | def forward( 382 | self, 383 | hidden_states, 384 | attention_mask=None, 385 | head_mask=None, 386 | encoder_hidden_states=None, 387 | encoder_attention_mask=None, 388 | past_key_value=None, 389 | output_attentions=False, 390 | ): 391 | self_outputs = self.self( 392 | hidden_states, 393 | attention_mask, 394 | head_mask, 395 | encoder_hidden_states, 396 | encoder_attention_mask, 397 | past_key_value, 398 | output_attentions, 399 | ) 400 | attention_output = self.output(self_outputs[0], hidden_states) 401 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 402 | return outputs 403 | 404 | 405 | class BertIntermediate(nn.Module): 406 | def __init__(self, config): 407 | super().__init__() 408 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 409 | if isinstance(config.hidden_act, str): 410 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 411 | else: 412 | self.intermediate_act_fn = config.hidden_act 413 | 414 | def forward(self, 
hidden_states): 415 | hidden_states = self.dense(hidden_states) 416 | hidden_states = self.intermediate_act_fn(hidden_states) 417 | return hidden_states 418 | 419 | 420 | class BertOutput(nn.Module): 421 | def __init__(self, config): 422 | super().__init__() 423 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 424 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 425 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 426 | 427 | def forward(self, hidden_states, input_tensor): 428 | hidden_states = self.dense(hidden_states) 429 | hidden_states = self.dropout(hidden_states) 430 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 431 | return hidden_states 432 | 433 | 434 | class BertLayer(nn.Module): 435 | def __init__(self, config): 436 | super().__init__() 437 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 438 | self.seq_len_dim = 1 439 | self.attention = BertAttention(config) 440 | self.is_decoder = config.is_decoder 441 | self.add_cross_attention = config.add_cross_attention 442 | if self.add_cross_attention: 443 | assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added" 444 | self.crossattention = BertAttention(config) 445 | self.intermediate = BertIntermediate(config) 446 | self.output = BertOutput(config) 447 | 448 | def forward( 449 | self, 450 | hidden_states, 451 | attention_mask=None, 452 | head_mask=None, 453 | encoder_hidden_states=None, 454 | encoder_attention_mask=None, 455 | past_key_value=None, 456 | output_attentions=False, 457 | ): 458 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 459 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 460 | self_attention_outputs = self.attention( 461 | hidden_states, 462 | attention_mask, 463 | head_mask, 464 | output_attentions=output_attentions, 465 | past_key_value=self_attn_past_key_value, 466 | ) 467 | attention_output = self_attention_outputs[0] 468 | 469 | # if decoder, the last output is tuple of self-attn cache 470 | if self.is_decoder: 471 | outputs = self_attention_outputs[1:-1] 472 | present_key_value = self_attention_outputs[-1] 473 | else: 474 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 475 | 476 | cross_attn_present_key_value = None 477 | if self.is_decoder and encoder_hidden_states is not None: 478 | assert hasattr( 479 | self, "crossattention" 480 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 481 | 482 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple 483 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 484 | cross_attention_outputs = self.crossattention( 485 | attention_output, 486 | attention_mask, 487 | head_mask, 488 | encoder_hidden_states, 489 | encoder_attention_mask, 490 | cross_attn_past_key_value, 491 | output_attentions, 492 | ) 493 | attention_output = cross_attention_outputs[0] 494 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights 495 | 496 | # add cross-attn cache to positions 3,4 of present_key_value tuple 497 | cross_attn_present_key_value = cross_attention_outputs[-1] 498 | present_key_value = present_key_value + cross_attn_present_key_value 499 | 500 | layer_output = apply_chunking_to_forward( 501 | self.feed_forward_chunk, 
self.chunk_size_feed_forward, self.seq_len_dim, attention_output 502 | ) 503 | outputs = (layer_output,) + outputs 504 | 505 | # if decoder, return the attn key/values as the last output 506 | if self.is_decoder: 507 | outputs = outputs + (present_key_value,) 508 | 509 | return outputs 510 | 511 | def feed_forward_chunk(self, attention_output): 512 | intermediate_output = self.intermediate(attention_output) 513 | layer_output = self.output(intermediate_output, attention_output) 514 | return layer_output 515 | 516 | 517 | class BertEncoder(nn.Module): 518 | def __init__(self, config): 519 | super().__init__() 520 | self.config = config 521 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) 522 | 523 | def forward( 524 | self, 525 | hidden_states, 526 | attention_mask=None, 527 | head_mask=None, 528 | encoder_hidden_states=None, 529 | encoder_attention_mask=None, 530 | past_key_values=None, 531 | use_cache=None, 532 | output_attentions=False, 533 | output_hidden_states=False, 534 | return_dict=True, 535 | ): 536 | all_hidden_states = () if output_hidden_states else None 537 | all_self_attentions = () if output_attentions else None 538 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 539 | 540 | next_decoder_cache = () if use_cache else None 541 | for i, layer_module in enumerate(self.layer): 542 | if output_hidden_states: 543 | all_hidden_states = all_hidden_states + (hidden_states,) 544 | 545 | layer_head_mask = head_mask[i] if head_mask is not None else None 546 | past_key_value = past_key_values[i] if past_key_values is not None else None 547 | 548 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 549 | 550 | if use_cache: 551 | logger.warn( 552 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " 553 | "`use_cache=False`..." 
554 | ) 555 | use_cache = False 556 | 557 | def create_custom_forward(module): 558 | def custom_forward(*inputs): 559 | return module(*inputs, past_key_value, output_attentions) 560 | 561 | return custom_forward 562 | 563 | layer_outputs = torch.utils.checkpoint.checkpoint( 564 | create_custom_forward(layer_module), 565 | hidden_states, 566 | attention_mask, 567 | layer_head_mask, 568 | encoder_hidden_states, 569 | encoder_attention_mask, 570 | ) 571 | else: 572 | layer_outputs = layer_module( 573 | hidden_states, 574 | attention_mask, 575 | layer_head_mask, 576 | encoder_hidden_states, 577 | encoder_attention_mask, 578 | past_key_value, 579 | output_attentions, 580 | ) 581 | 582 | hidden_states = layer_outputs[0] 583 | if use_cache: 584 | next_decoder_cache += (layer_outputs[-1],) 585 | if output_attentions: 586 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 587 | if self.config.add_cross_attention: 588 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 589 | 590 | if output_hidden_states: 591 | all_hidden_states = all_hidden_states + (hidden_states,) 592 | 593 | if not return_dict: 594 | return tuple( 595 | v 596 | for v in [ 597 | hidden_states, 598 | next_decoder_cache, 599 | all_hidden_states, 600 | all_self_attentions, 601 | all_cross_attentions, 602 | ] 603 | if v is not None 604 | ) 605 | return BaseModelOutputWithPastAndCrossAttentions( 606 | last_hidden_state=hidden_states, 607 | past_key_values=next_decoder_cache, 608 | hidden_states=all_hidden_states, 609 | attentions=all_self_attentions, 610 | cross_attentions=all_cross_attentions, 611 | ) 612 | 613 | 614 | class BertPooler(nn.Module): 615 | def __init__(self, config): 616 | super().__init__() 617 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 618 | self.activation = nn.Tanh() 619 | 620 | def forward(self, hidden_states): 621 | # We "pool" the model by simply taking the hidden state corresponding 622 | # to the first token. 623 | first_token_tensor = hidden_states[:, 0] 624 | pooled_output = self.dense(first_token_tensor) 625 | pooled_output = self.activation(pooled_output) 626 | return pooled_output 627 | 628 | 629 | class BertPredictionHeadTransform(nn.Module): 630 | def __init__(self, config): 631 | super().__init__() 632 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 633 | if isinstance(config.hidden_act, str): 634 | self.transform_act_fn = ACT2FN[config.hidden_act] 635 | else: 636 | self.transform_act_fn = config.hidden_act 637 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 638 | 639 | def forward(self, hidden_states): 640 | hidden_states = self.dense(hidden_states) 641 | hidden_states = self.transform_act_fn(hidden_states) 642 | hidden_states = self.LayerNorm(hidden_states) 643 | return hidden_states 644 | 645 | 646 | class BertLMPredictionHead(nn.Module): 647 | def __init__(self, config): 648 | super().__init__() 649 | self.transform = BertPredictionHeadTransform(config) 650 | 651 | # The output weights are the same as the input embeddings, but there is 652 | # an output-only bias for each token. 
653 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 654 | 655 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 656 | 657 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 658 | self.decoder.bias = self.bias 659 | 660 | def forward(self, hidden_states): 661 | hidden_states = self.transform(hidden_states) 662 | hidden_states = self.decoder(hidden_states) 663 | return hidden_states 664 | 665 | 666 | class BertOnlyMLMHead(nn.Module): 667 | def __init__(self, config): 668 | super().__init__() 669 | self.predictions = BertLMPredictionHead(config) 670 | 671 | def forward(self, sequence_output): 672 | prediction_scores = self.predictions(sequence_output) 673 | return prediction_scores 674 | 675 | 676 | class BertOnlyNSPHead(nn.Module): 677 | def __init__(self, config): 678 | super().__init__() 679 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 680 | 681 | def forward(self, pooled_output): 682 | seq_relationship_score = self.seq_relationship(pooled_output) 683 | return seq_relationship_score 684 | 685 | 686 | class BertPreTrainingHeads(nn.Module): 687 | def __init__(self, config): 688 | super().__init__() 689 | self.predictions = BertLMPredictionHead(config) 690 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 691 | 692 | def forward(self, sequence_output, pooled_output): 693 | prediction_scores = self.predictions(sequence_output) 694 | seq_relationship_score = self.seq_relationship(pooled_output) 695 | return prediction_scores, seq_relationship_score 696 | 697 | 698 | class BertPreTrainedModel(PreTrainedModel): 699 | """ 700 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 701 | models. 702 | """ 703 | 704 | config_class = BertConfig 705 | load_tf_weights = load_tf_weights_in_bert 706 | base_model_prefix = "bert" 707 | _keys_to_ignore_on_load_missing = [r"position_ids"] 708 | 709 | def _init_weights(self, module): 710 | """ Initialize the weights """ 711 | if isinstance(module, nn.Linear): 712 | # Slightly different from the TF version which uses truncated_normal for initialization 713 | # cf https://github.com/pytorch/pytorch/pull/5617 714 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 715 | if module.bias is not None: 716 | module.bias.data.zero_() 717 | elif isinstance(module, nn.Embedding): 718 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 719 | if module.padding_idx is not None: 720 | module.weight.data[module.padding_idx].zero_() 721 | elif isinstance(module, nn.LayerNorm): 722 | module.bias.data.zero_() 723 | module.weight.data.fill_(1.0) 724 | 725 | 726 | @dataclass 727 | class BertForPreTrainingOutput(ModelOutput): 728 | """ 729 | Output type of :class:`~transformers.BertForPreTraining`. 730 | 731 | Args: 732 | loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): 733 | Total loss as the sum of the masked language modeling loss and the next sequence prediction 734 | (classification) loss. 735 | prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): 736 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
737 | seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): 738 | Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation 739 | before SoftMax). 740 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): 741 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 742 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 743 | 744 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 745 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): 746 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, 747 | sequence_length, sequence_length)`. 748 | 749 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 750 | heads. 751 | """ 752 | 753 | loss: Optional[torch.FloatTensor] = None 754 | prediction_logits: torch.FloatTensor = None 755 | seq_relationship_logits: torch.FloatTensor = None 756 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 757 | attentions: Optional[Tuple[torch.FloatTensor]] = None 758 | 759 | 760 | BERT_START_DOCSTRING = r""" 761 | 762 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 763 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 764 | pruning heads etc.) 765 | 766 | This model is also a PyTorch `torch.nn.Module `__ 767 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 768 | general usage and behavior. 769 | 770 | Parameters: 771 | config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. 772 | Initializing with a config file does not load the weights associated with the model, only the 773 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model 774 | weights. 775 | """ 776 | 777 | BERT_INPUTS_DOCSTRING = r""" 778 | Args: 779 | input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): 780 | Indices of input sequence tokens in the vocabulary. 781 | 782 | Indices can be obtained using :class:`~transformers.BertTokenizer`. See 783 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 784 | details. 785 | 786 | `What are input IDs? <../glossary.html#input-ids>`__ 787 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): 788 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 789 | 790 | - 1 for tokens that are **not masked**, 791 | - 0 for tokens that are **masked**. 792 | 793 | `What are attention masks? <../glossary.html#attention-mask>`__ 794 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): 795 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 796 | 1]``: 797 | 798 | - 0 corresponds to a `sentence A` token, 799 | - 1 corresponds to a `sentence B` token. 800 | 801 | `What are token type IDs? 
<../glossary.html#token-type-ids>`_ 802 | position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): 803 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, 804 | config.max_position_embeddings - 1]``. 805 | 806 | `What are position IDs? <../glossary.html#position-ids>`_ 807 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): 808 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: 809 | 810 | - 1 indicates the head is **not masked**, 811 | - 0 indicates the head is **masked**. 812 | 813 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): 814 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 815 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 816 | vectors than the model's internal embedding lookup matrix. 817 | output_attentions (:obj:`bool`, `optional`): 818 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 819 | tensors for more detail. 820 | output_hidden_states (:obj:`bool`, `optional`): 821 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 822 | more detail. 823 | return_dict (:obj:`bool`, `optional`): 824 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 825 | """ 826 | 827 | 828 | @add_start_docstrings( 829 | "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", 830 | BERT_START_DOCSTRING, 831 | ) 832 | class BertModel(BertPreTrainedModel): 833 | """ 834 | 835 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of 836 | cross-attention is added between the self-attention layers, following the architecture described in `Attention is 837 | all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, 838 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 839 | 840 | To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration 841 | set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` 842 | argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an 843 | input to the forward pass. 844 | """ 845 | 846 | def __init__(self, config, add_pooling_layer=True): 847 | super().__init__(config) 848 | self.config = config 849 | 850 | self.embeddings = BertEmbeddings(config) 851 | self.encoder = BertEncoder(config) 852 | 853 | self.pooler = BertPooler(config) if add_pooling_layer else None 854 | 855 | self.init_weights() 856 | 857 | def get_input_embeddings(self): 858 | return self.embeddings.word_embeddings 859 | 860 | def set_input_embeddings(self, value): 861 | self.embeddings.word_embeddings = value 862 | 863 | def _prune_heads(self, heads_to_prune): 864 | """ 865 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 866 | class PreTrainedModel 867 | """ 868 | for layer, heads in heads_to_prune.items(): 869 | self.encoder.layer[layer].attention.prune_heads(heads) 870 | 871 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 872 | @add_code_sample_docstrings( 873 | tokenizer_class=_TOKENIZER_FOR_DOC, 874 | checkpoint=_CHECKPOINT_FOR_DOC, 875 | output_type=BaseModelOutputWithPoolingAndCrossAttentions, 876 | config_class=_CONFIG_FOR_DOC, 877 | ) 878 | def forward( 879 | self, 880 | input_ids=None, 881 | attention_mask=None, 882 | token_type_ids=None, 883 | position_ids=None, 884 | speaker_ids=None, 885 | head_mask=None, 886 | inputs_embeds=None, 887 | encoder_hidden_states=None, 888 | encoder_attention_mask=None, 889 | past_key_values=None, 890 | use_cache=None, 891 | output_attentions=None, 892 | output_hidden_states=None, 893 | return_dict=None, 894 | ): 895 | r""" 896 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 897 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 898 | the model is configured as a decoder. 899 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 900 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 901 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 902 | 903 | - 1 for tokens that are **not masked**, 904 | - 0 for tokens that are **masked**. 905 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 906 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 907 | 908 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 909 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 910 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 911 | use_cache (:obj:`bool`, `optional`): 912 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 913 | decoding (see :obj:`past_key_values`). 
914 | """ 915 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 916 | output_hidden_states = ( 917 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 918 | ) 919 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 920 | 921 | if self.config.is_decoder: 922 | use_cache = use_cache if use_cache is not None else self.config.use_cache 923 | else: 924 | use_cache = False 925 | 926 | if input_ids is not None and inputs_embeds is not None: 927 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 928 | elif input_ids is not None: 929 | input_shape = input_ids.size() 930 | batch_size, seq_length = input_shape 931 | elif inputs_embeds is not None: 932 | input_shape = inputs_embeds.size()[:-1] 933 | batch_size, seq_length = input_shape 934 | else: 935 | raise ValueError("You have to specify either input_ids or inputs_embeds") 936 | 937 | device = input_ids.device if input_ids is not None else inputs_embeds.device 938 | 939 | # past_key_values_length 940 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 941 | 942 | if attention_mask is None: 943 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 944 | if token_type_ids is None: 945 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 946 | 947 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 948 | # ourselves in which case we just need to make it broadcastable to all heads. 949 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) 950 | 951 | # If a 2D or 3D attention mask is provided for the cross-attention 952 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 953 | if self.config.is_decoder and encoder_hidden_states is not None: 954 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 955 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 956 | if encoder_attention_mask is None: 957 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 958 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 959 | else: 960 | encoder_extended_attention_mask = None 961 | 962 | # Prepare head mask if needed 963 | # 1.0 in head_mask indicate we keep the head 964 | # attention_probs has shape bsz x n_heads x N x N 965 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 966 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 967 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 968 | 969 | embedding_output = self.embeddings( 970 | input_ids=input_ids, 971 | position_ids=position_ids, 972 | token_type_ids=token_type_ids, 973 | speaker_ids=speaker_ids, 974 | inputs_embeds=inputs_embeds, 975 | past_key_values_length=past_key_values_length, 976 | ) 977 | encoder_outputs = self.encoder( 978 | embedding_output, 979 | attention_mask=extended_attention_mask, 980 | head_mask=head_mask, 981 | encoder_hidden_states=encoder_hidden_states, 982 | encoder_attention_mask=encoder_extended_attention_mask, 983 | past_key_values=past_key_values, 984 | use_cache=use_cache, 985 | output_attentions=output_attentions, 986 | 
output_hidden_states=output_hidden_states, 987 | return_dict=return_dict, 988 | ) 989 | sequence_output = encoder_outputs[0] 990 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 991 | 992 | if not return_dict: 993 | return (sequence_output, pooled_output) + encoder_outputs[1:] 994 | 995 | return BaseModelOutputWithPoolingAndCrossAttentions( 996 | last_hidden_state=sequence_output, 997 | pooler_output=pooled_output, 998 | past_key_values=encoder_outputs.past_key_values, 999 | hidden_states=encoder_outputs.hidden_states, 1000 | attentions=encoder_outputs.attentions, 1001 | cross_attentions=encoder_outputs.cross_attentions, 1002 | ) 1003 | 1004 | 1005 | @add_start_docstrings( 1006 | """ 1007 | Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next 1008 | sentence prediction (classification)` head. 1009 | """, 1010 | BERT_START_DOCSTRING, 1011 | ) 1012 | class BertForPreTraining(BertPreTrainedModel): 1013 | def __init__(self, config): 1014 | super().__init__(config) 1015 | 1016 | self.bert = BertModel(config) 1017 | self.cls = BertPreTrainingHeads(config) 1018 | 1019 | self.init_weights() 1020 | 1021 | def get_output_embeddings(self): 1022 | return self.cls.predictions.decoder 1023 | 1024 | def set_output_embeddings(self, new_embeddings): 1025 | self.cls.predictions.decoder = new_embeddings 1026 | 1027 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1028 | @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) 1029 | def forward( 1030 | self, 1031 | input_ids=None, 1032 | attention_mask=None, 1033 | token_type_ids=None, 1034 | position_ids=None, 1035 | head_mask=None, 1036 | inputs_embeds=None, 1037 | labels=None, 1038 | next_sentence_label=None, 1039 | output_attentions=None, 1040 | output_hidden_states=None, 1041 | return_dict=None, 1042 | ): 1043 | r""" 1044 | labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): 1045 | Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., 1046 | config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored 1047 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` 1048 | next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): 1049 | Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair 1050 | (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: 1051 | 1052 | - 0 indicates sequence B is a continuation of sequence A, 1053 | - 1 indicates sequence B is a random sequence. 1054 | kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): 1055 | Used to hide legacy arguments that have been deprecated. 
1056 | 1057 | Returns: 1058 | 1059 | Example:: 1060 | 1061 | >>> from transformers import BertTokenizer, BertForPreTraining 1062 | >>> import torch 1063 | 1064 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 1065 | >>> model = BertForPreTraining.from_pretrained('bert-base-uncased') 1066 | 1067 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 1068 | >>> outputs = model(**inputs) 1069 | 1070 | >>> prediction_logits = outputs.prediction_logits 1071 | >>> seq_relationship_logits = outputs.seq_relationship_logits 1072 | """ 1073 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1074 | 1075 | outputs = self.bert( 1076 | input_ids, 1077 | attention_mask=attention_mask, 1078 | token_type_ids=token_type_ids, 1079 | position_ids=position_ids, 1080 | head_mask=head_mask, 1081 | inputs_embeds=inputs_embeds, 1082 | output_attentions=output_attentions, 1083 | output_hidden_states=output_hidden_states, 1084 | return_dict=return_dict, 1085 | ) 1086 | 1087 | sequence_output, pooled_output = outputs[:2] 1088 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 1089 | 1090 | total_loss = None 1091 | if labels is not None and next_sentence_label is not None: 1092 | loss_fct = CrossEntropyLoss() 1093 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1094 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 1095 | total_loss = masked_lm_loss + next_sentence_loss 1096 | 1097 | if not return_dict: 1098 | output = (prediction_scores, seq_relationship_score) + outputs[2:] 1099 | return ((total_loss,) + output) if total_loss is not None else output 1100 | 1101 | return BertForPreTrainingOutput( 1102 | loss=total_loss, 1103 | prediction_logits=prediction_scores, 1104 | seq_relationship_logits=seq_relationship_score, 1105 | hidden_states=outputs.hidden_states, 1106 | attentions=outputs.attentions, 1107 | ) 1108 | 1109 | 1110 | @add_start_docstrings( 1111 | """Bert Model with a `language modeling` head on top for CLM fine-tuning. 
""", BERT_START_DOCSTRING 1112 | ) 1113 | class BertLMHeadModel(BertPreTrainedModel): 1114 | 1115 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1116 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] 1117 | 1118 | def __init__(self, config): 1119 | super().__init__(config) 1120 | 1121 | if not config.is_decoder: 1122 | logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`") 1123 | 1124 | self.bert = BertModel(config, add_pooling_layer=False) 1125 | self.cls = BertOnlyMLMHead(config) 1126 | 1127 | self.init_weights() 1128 | 1129 | def get_output_embeddings(self): 1130 | return self.cls.predictions.decoder 1131 | 1132 | def set_output_embeddings(self, new_embeddings): 1133 | self.cls.predictions.decoder = new_embeddings 1134 | 1135 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1136 | @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) 1137 | def forward( 1138 | self, 1139 | input_ids=None, 1140 | attention_mask=None, 1141 | token_type_ids=None, 1142 | position_ids=None, 1143 | head_mask=None, 1144 | inputs_embeds=None, 1145 | encoder_hidden_states=None, 1146 | encoder_attention_mask=None, 1147 | labels=None, 1148 | past_key_values=None, 1149 | use_cache=None, 1150 | output_attentions=None, 1151 | output_hidden_states=None, 1152 | return_dict=None, 1153 | ): 1154 | r""" 1155 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 1156 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 1157 | the model is configured as a decoder. 1158 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1159 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 1160 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 1161 | 1162 | - 1 for tokens that are **not masked**, 1163 | - 0 for tokens that are **masked**. 1164 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1165 | Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in 1166 | ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are 1167 | ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` 1168 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 1169 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 1170 | 1171 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 1172 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 1173 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 1174 | use_cache (:obj:`bool`, `optional`): 1175 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 1176 | decoding (see :obj:`past_key_values`). 
1177 | 1178 | Returns: 1179 | 1180 | Example:: 1181 | 1182 | >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig 1183 | >>> import torch 1184 | 1185 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') 1186 | >>> config = BertConfig.from_pretrained("bert-base-cased") 1187 | >>> config.is_decoder = True 1188 | >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) 1189 | 1190 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 1191 | >>> outputs = model(**inputs) 1192 | 1193 | >>> prediction_logits = outputs.logits 1194 | """ 1195 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1196 | if labels is not None: 1197 | use_cache = False 1198 | 1199 | outputs = self.bert( 1200 | input_ids, 1201 | attention_mask=attention_mask, 1202 | token_type_ids=token_type_ids, 1203 | position_ids=position_ids, 1204 | head_mask=head_mask, 1205 | inputs_embeds=inputs_embeds, 1206 | encoder_hidden_states=encoder_hidden_states, 1207 | encoder_attention_mask=encoder_attention_mask, 1208 | past_key_values=past_key_values, 1209 | use_cache=use_cache, 1210 | output_attentions=output_attentions, 1211 | output_hidden_states=output_hidden_states, 1212 | return_dict=return_dict, 1213 | ) 1214 | 1215 | sequence_output = outputs[0] 1216 | prediction_scores = self.cls(sequence_output) 1217 | 1218 | lm_loss = None 1219 | if labels is not None: 1220 | # we are doing next-token prediction; shift prediction scores and input ids by one 1221 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() 1222 | labels = labels[:, 1:].contiguous() 1223 | loss_fct = CrossEntropyLoss() 1224 | lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1225 | 1226 | if not return_dict: 1227 | output = (prediction_scores,) + outputs[2:] 1228 | return ((lm_loss,) + output) if lm_loss is not None else output 1229 | 1230 | return CausalLMOutputWithCrossAttentions( 1231 | loss=lm_loss, 1232 | logits=prediction_scores, 1233 | past_key_values=outputs.past_key_values, 1234 | hidden_states=outputs.hidden_states, 1235 | attentions=outputs.attentions, 1236 | cross_attentions=outputs.cross_attentions, 1237 | ) 1238 | 1239 | def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): 1240 | input_shape = input_ids.shape 1241 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly 1242 | if attention_mask is None: 1243 | attention_mask = input_ids.new_ones(input_shape) 1244 | 1245 | # cut decoder_input_ids if past is used 1246 | if past is not None: 1247 | input_ids = input_ids[:, -1:] 1248 | 1249 | return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} 1250 | 1251 | def _reorder_cache(self, past, beam_idx): 1252 | reordered_past = () 1253 | for layer_past in past: 1254 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 1255 | return reordered_past 1256 | 1257 | 1258 | @add_start_docstrings("""Bert Model with a `language modeling` head on top. 
""", BERT_START_DOCSTRING) 1259 | class BertForMaskedLM(BertPreTrainedModel): 1260 | 1261 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1262 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] 1263 | 1264 | def __init__(self, config): 1265 | super().__init__(config) 1266 | 1267 | if config.is_decoder: 1268 | logger.warning( 1269 | "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " 1270 | "bi-directional self-attention." 1271 | ) 1272 | 1273 | self.bert = BertModel(config, add_pooling_layer=False) 1274 | self.cls = BertOnlyMLMHead(config) 1275 | 1276 | self.init_weights() 1277 | 1278 | def get_output_embeddings(self): 1279 | return self.cls.predictions.decoder 1280 | 1281 | def set_output_embeddings(self, new_embeddings): 1282 | self.cls.predictions.decoder = new_embeddings 1283 | 1284 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1285 | @add_code_sample_docstrings( 1286 | tokenizer_class=_TOKENIZER_FOR_DOC, 1287 | checkpoint=_CHECKPOINT_FOR_DOC, 1288 | output_type=MaskedLMOutput, 1289 | config_class=_CONFIG_FOR_DOC, 1290 | ) 1291 | def forward( 1292 | self, 1293 | input_ids=None, 1294 | attention_mask=None, 1295 | token_type_ids=None, 1296 | position_ids=None, 1297 | head_mask=None, 1298 | inputs_embeds=None, 1299 | encoder_hidden_states=None, 1300 | encoder_attention_mask=None, 1301 | labels=None, 1302 | output_attentions=None, 1303 | output_hidden_states=None, 1304 | return_dict=None, 1305 | ): 1306 | r""" 1307 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1308 | Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., 1309 | config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored 1310 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` 1311 | """ 1312 | 1313 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1314 | 1315 | outputs = self.bert( 1316 | input_ids, 1317 | attention_mask=attention_mask, 1318 | token_type_ids=token_type_ids, 1319 | position_ids=position_ids, 1320 | head_mask=head_mask, 1321 | inputs_embeds=inputs_embeds, 1322 | encoder_hidden_states=encoder_hidden_states, 1323 | encoder_attention_mask=encoder_attention_mask, 1324 | output_attentions=output_attentions, 1325 | output_hidden_states=output_hidden_states, 1326 | return_dict=return_dict, 1327 | ) 1328 | 1329 | sequence_output = outputs[0] 1330 | prediction_scores = self.cls(sequence_output) 1331 | 1332 | masked_lm_loss = None 1333 | if labels is not None: 1334 | loss_fct = CrossEntropyLoss() # -100 index = padding token 1335 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1336 | 1337 | if not return_dict: 1338 | output = (prediction_scores,) + outputs[2:] 1339 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1340 | 1341 | return MaskedLMOutput( 1342 | loss=masked_lm_loss, 1343 | logits=prediction_scores, 1344 | hidden_states=outputs.hidden_states, 1345 | attentions=outputs.attentions, 1346 | ) 1347 | 1348 | def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): 1349 | input_shape = input_ids.shape 1350 | effective_batch_size = input_shape[0] 1351 | 1352 | # add a dummy token 1353 | assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" 
1354 | attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) 1355 | dummy_token = torch.full( 1356 | (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device 1357 | ) 1358 | input_ids = torch.cat([input_ids, dummy_token], dim=1) 1359 | 1360 | return {"input_ids": input_ids, "attention_mask": attention_mask} 1361 | 1362 | 1363 | @add_start_docstrings( 1364 | """Bert Model with a `next sentence prediction (classification)` head on top. """, 1365 | BERT_START_DOCSTRING, 1366 | ) 1367 | class BertForNextSentencePrediction(BertPreTrainedModel): 1368 | def __init__(self, config): 1369 | super().__init__(config) 1370 | 1371 | self.bert = BertModel(config) 1372 | self.cls = BertOnlyNSPHead(config) 1373 | 1374 | self.init_weights() 1375 | 1376 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1377 | @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) 1378 | def forward( 1379 | self, 1380 | input_ids=None, 1381 | attention_mask=None, 1382 | token_type_ids=None, 1383 | position_ids=None, 1384 | head_mask=None, 1385 | inputs_embeds=None, 1386 | labels=None, 1387 | output_attentions=None, 1388 | output_hidden_states=None, 1389 | return_dict=None, 1390 | **kwargs 1391 | ): 1392 | r""" 1393 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1394 | Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair 1395 | (see ``input_ids`` docstring). Indices should be in ``[0, 1]``: 1396 | 1397 | - 0 indicates sequence B is a continuation of sequence A, 1398 | - 1 indicates sequence B is a random sequence. 1399 | 1400 | Returns: 1401 | 1402 | Example:: 1403 | 1404 | >>> from transformers import BertTokenizer, BertForNextSentencePrediction 1405 | >>> import torch 1406 | 1407 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 1408 | >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased') 1409 | 1410 | >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." 1411 | >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." 
1412 | >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') 1413 | 1414 | >>> outputs = model(**encoding, labels=torch.LongTensor([1])) 1415 | >>> logits = outputs.logits 1416 | >>> assert logits[0, 0] < logits[0, 1] # next sentence was random 1417 | """ 1418 | 1419 | if "next_sentence_label" in kwargs: 1420 | warnings.warn( 1421 | "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", 1422 | FutureWarning, 1423 | ) 1424 | labels = kwargs.pop("next_sentence_label") 1425 | 1426 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1427 | 1428 | outputs = self.bert( 1429 | input_ids, 1430 | attention_mask=attention_mask, 1431 | token_type_ids=token_type_ids, 1432 | position_ids=position_ids, 1433 | head_mask=head_mask, 1434 | inputs_embeds=inputs_embeds, 1435 | output_attentions=output_attentions, 1436 | output_hidden_states=output_hidden_states, 1437 | return_dict=return_dict, 1438 | ) 1439 | 1440 | pooled_output = outputs[1] 1441 | 1442 | seq_relationship_scores = self.cls(pooled_output) 1443 | 1444 | next_sentence_loss = None 1445 | if labels is not None: 1446 | loss_fct = CrossEntropyLoss() 1447 | next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) 1448 | 1449 | if not return_dict: 1450 | output = (seq_relationship_scores,) + outputs[2:] 1451 | return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output 1452 | 1453 | return NextSentencePredictorOutput( 1454 | loss=next_sentence_loss, 1455 | logits=seq_relationship_scores, 1456 | hidden_states=outputs.hidden_states, 1457 | attentions=outputs.attentions, 1458 | ) 1459 | 1460 | 1461 | @add_start_docstrings( 1462 | """ 1463 | Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled 1464 | output) e.g. for GLUE tasks. 1465 | """, 1466 | BERT_START_DOCSTRING, 1467 | ) 1468 | class BertForSequenceClassification(BertPreTrainedModel): 1469 | def __init__(self, config): 1470 | super().__init__(config) 1471 | self.num_labels = config.num_labels 1472 | 1473 | self.bert = BertModel(config) 1474 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1475 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1476 | 1477 | self.init_weights() 1478 | 1479 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1480 | @add_code_sample_docstrings( 1481 | tokenizer_class=_TOKENIZER_FOR_DOC, 1482 | checkpoint=_CHECKPOINT_FOR_DOC, 1483 | output_type=SequenceClassifierOutput, 1484 | config_class=_CONFIG_FOR_DOC, 1485 | ) 1486 | def forward( 1487 | self, 1488 | input_ids=None, 1489 | attention_mask=None, 1490 | token_type_ids=None, 1491 | position_ids=None, 1492 | head_mask=None, 1493 | inputs_embeds=None, 1494 | labels=None, 1495 | output_attentions=None, 1496 | output_hidden_states=None, 1497 | return_dict=None, 1498 | ): 1499 | r""" 1500 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1501 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 1502 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 1503 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1504 | """ 1505 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1506 | 1507 | outputs = self.bert( 1508 | input_ids, 1509 | attention_mask=attention_mask, 1510 | token_type_ids=token_type_ids, 1511 | position_ids=position_ids, 1512 | head_mask=head_mask, 1513 | inputs_embeds=inputs_embeds, 1514 | output_attentions=output_attentions, 1515 | output_hidden_states=output_hidden_states, 1516 | return_dict=return_dict, 1517 | ) 1518 | 1519 | pooled_output = outputs[1] 1520 | 1521 | pooled_output = self.dropout(pooled_output) 1522 | logits = self.classifier(pooled_output) 1523 | 1524 | loss = None 1525 | if labels is not None: 1526 | if self.num_labels == 1: 1527 | # We are doing regression 1528 | loss_fct = MSELoss() 1529 | loss = loss_fct(logits.view(-1), labels.view(-1)) 1530 | else: 1531 | loss_fct = CrossEntropyLoss() 1532 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1533 | 1534 | if not return_dict: 1535 | output = (logits,) + outputs[2:] 1536 | return ((loss,) + output) if loss is not None else output 1537 | 1538 | return SequenceClassifierOutput( 1539 | loss=loss, 1540 | logits=logits, 1541 | hidden_states=outputs.hidden_states, 1542 | attentions=outputs.attentions, 1543 | ) 1544 | 1545 | 1546 | @add_start_docstrings( 1547 | """ 1548 | Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a 1549 | softmax) e.g. for RocStories/SWAG tasks. 1550 | """, 1551 | BERT_START_DOCSTRING, 1552 | ) 1553 | class BertForMultipleChoice(BertPreTrainedModel): 1554 | def __init__(self, config): 1555 | super().__init__(config) 1556 | 1557 | self.bert = BertModel(config) 1558 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1559 | self.classifier = nn.Linear(config.hidden_size, 1) 1560 | 1561 | self.init_weights() 1562 | 1563 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) 1564 | @add_code_sample_docstrings( 1565 | tokenizer_class=_TOKENIZER_FOR_DOC, 1566 | checkpoint=_CHECKPOINT_FOR_DOC, 1567 | output_type=MultipleChoiceModelOutput, 1568 | config_class=_CONFIG_FOR_DOC, 1569 | ) 1570 | def forward( 1571 | self, 1572 | input_ids=None, 1573 | attention_mask=None, 1574 | token_type_ids=None, 1575 | position_ids=None, 1576 | head_mask=None, 1577 | inputs_embeds=None, 1578 | labels=None, 1579 | output_attentions=None, 1580 | output_hidden_states=None, 1581 | return_dict=None, 1582 | ): 1583 | r""" 1584 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1585 | Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., 1586 | num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. 
(See 1587 | :obj:`input_ids` above) 1588 | """ 1589 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1590 | num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] 1591 | 1592 | input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None 1593 | attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 1594 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 1595 | position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 1596 | inputs_embeds = ( 1597 | inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) 1598 | if inputs_embeds is not None 1599 | else None 1600 | ) 1601 | 1602 | outputs = self.bert( 1603 | input_ids, 1604 | attention_mask=attention_mask, 1605 | token_type_ids=token_type_ids, 1606 | position_ids=position_ids, 1607 | head_mask=head_mask, 1608 | inputs_embeds=inputs_embeds, 1609 | output_attentions=output_attentions, 1610 | output_hidden_states=output_hidden_states, 1611 | return_dict=return_dict, 1612 | ) 1613 | 1614 | pooled_output = outputs[1] 1615 | 1616 | pooled_output = self.dropout(pooled_output) 1617 | logits = self.classifier(pooled_output) 1618 | reshaped_logits = logits.view(-1, num_choices) 1619 | 1620 | loss = None 1621 | if labels is not None: 1622 | loss_fct = CrossEntropyLoss() 1623 | loss = loss_fct(reshaped_logits, labels) 1624 | 1625 | if not return_dict: 1626 | output = (reshaped_logits,) + outputs[2:] 1627 | return ((loss,) + output) if loss is not None else output 1628 | 1629 | return MultipleChoiceModelOutput( 1630 | loss=loss, 1631 | logits=reshaped_logits, 1632 | hidden_states=outputs.hidden_states, 1633 | attentions=outputs.attentions, 1634 | ) 1635 | 1636 | 1637 | @add_start_docstrings( 1638 | """ 1639 | Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for 1640 | Named-Entity-Recognition (NER) tasks. 1641 | """, 1642 | BERT_START_DOCSTRING, 1643 | ) 1644 | class BertForTokenClassification(BertPreTrainedModel): 1645 | 1646 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1647 | 1648 | def __init__(self, config): 1649 | super().__init__(config) 1650 | self.num_labels = config.num_labels 1651 | 1652 | self.bert = BertModel(config, add_pooling_layer=False) 1653 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1654 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1655 | 1656 | self.init_weights() 1657 | 1658 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1659 | @add_code_sample_docstrings( 1660 | tokenizer_class=_TOKENIZER_FOR_DOC, 1661 | checkpoint=_CHECKPOINT_FOR_DOC, 1662 | output_type=TokenClassifierOutput, 1663 | config_class=_CONFIG_FOR_DOC, 1664 | ) 1665 | def forward( 1666 | self, 1667 | input_ids=None, 1668 | attention_mask=None, 1669 | token_type_ids=None, 1670 | position_ids=None, 1671 | head_mask=None, 1672 | inputs_embeds=None, 1673 | labels=None, 1674 | output_attentions=None, 1675 | output_hidden_states=None, 1676 | return_dict=None, 1677 | ): 1678 | r""" 1679 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1680 | Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - 1681 | 1]``. 
1682 | """ 1683 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1684 | 1685 | outputs = self.bert( 1686 | input_ids, 1687 | attention_mask=attention_mask, 1688 | token_type_ids=token_type_ids, 1689 | position_ids=position_ids, 1690 | head_mask=head_mask, 1691 | inputs_embeds=inputs_embeds, 1692 | output_attentions=output_attentions, 1693 | output_hidden_states=output_hidden_states, 1694 | return_dict=return_dict, 1695 | ) 1696 | 1697 | sequence_output = outputs[0] 1698 | 1699 | sequence_output = self.dropout(sequence_output) 1700 | logits = self.classifier(sequence_output) 1701 | 1702 | loss = None 1703 | if labels is not None: 1704 | loss_fct = CrossEntropyLoss() 1705 | # Only keep active parts of the loss 1706 | if attention_mask is not None: 1707 | active_loss = attention_mask.view(-1) == 1 1708 | active_logits = logits.view(-1, self.num_labels) 1709 | active_labels = torch.where( 1710 | active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) 1711 | ) 1712 | loss = loss_fct(active_logits, active_labels) 1713 | else: 1714 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1715 | 1716 | if not return_dict: 1717 | output = (logits,) + outputs[2:] 1718 | return ((loss,) + output) if loss is not None else output 1719 | 1720 | return TokenClassifierOutput( 1721 | loss=loss, 1722 | logits=logits, 1723 | hidden_states=outputs.hidden_states, 1724 | attentions=outputs.attentions, 1725 | ) 1726 | 1727 | 1728 | @add_start_docstrings( 1729 | """ 1730 | Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear 1731 | layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 1732 | """, 1733 | BERT_START_DOCSTRING, 1734 | ) 1735 | class BertForQuestionAnswering(BertPreTrainedModel): 1736 | 1737 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1738 | 1739 | def __init__(self, config): 1740 | super().__init__(config) 1741 | self.num_labels = config.num_labels 1742 | 1743 | self.bert = BertModel(config, add_pooling_layer=False) 1744 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 1745 | 1746 | self.init_weights() 1747 | 1748 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1749 | @add_code_sample_docstrings( 1750 | tokenizer_class=_TOKENIZER_FOR_DOC, 1751 | checkpoint=_CHECKPOINT_FOR_DOC, 1752 | output_type=QuestionAnsweringModelOutput, 1753 | config_class=_CONFIG_FOR_DOC, 1754 | ) 1755 | def forward( 1756 | self, 1757 | input_ids=None, 1758 | attention_mask=None, 1759 | token_type_ids=None, 1760 | position_ids=None, 1761 | head_mask=None, 1762 | inputs_embeds=None, 1763 | start_positions=None, 1764 | end_positions=None, 1765 | output_attentions=None, 1766 | output_hidden_states=None, 1767 | return_dict=None, 1768 | ): 1769 | r""" 1770 | start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1771 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1772 | Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the 1773 | sequence are not taken into account for computing the loss. 1774 | end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1775 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 
1776 | Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the 1777 | sequence are not taken into account for computing the loss. 1778 | """ 1779 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1780 | 1781 | outputs = self.bert( 1782 | input_ids, 1783 | attention_mask=attention_mask, 1784 | token_type_ids=token_type_ids, 1785 | position_ids=position_ids, 1786 | head_mask=head_mask, 1787 | inputs_embeds=inputs_embeds, 1788 | output_attentions=output_attentions, 1789 | output_hidden_states=output_hidden_states, 1790 | return_dict=return_dict, 1791 | ) 1792 | 1793 | sequence_output = outputs[0] 1794 | 1795 | logits = self.qa_outputs(sequence_output) 1796 | start_logits, end_logits = logits.split(1, dim=-1) 1797 | start_logits = start_logits.squeeze(-1) 1798 | end_logits = end_logits.squeeze(-1) 1799 | 1800 | total_loss = None 1801 | if start_positions is not None and end_positions is not None: 1802 | # If we are on multi-GPU, split add a dimension 1803 | if len(start_positions.size()) > 1: 1804 | start_positions = start_positions.squeeze(-1) 1805 | if len(end_positions.size()) > 1: 1806 | end_positions = end_positions.squeeze(-1) 1807 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1808 | ignored_index = start_logits.size(1) 1809 | start_positions.clamp_(0, ignored_index) 1810 | end_positions.clamp_(0, ignored_index) 1811 | 1812 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1813 | start_loss = loss_fct(start_logits, start_positions) 1814 | end_loss = loss_fct(end_logits, end_positions) 1815 | total_loss = (start_loss + end_loss) / 2 1816 | 1817 | if not return_dict: 1818 | output = (start_logits, end_logits) + outputs[2:] 1819 | return ((total_loss,) + output) if total_loss is not None else output 1820 | 1821 | return QuestionAnsweringModelOutput( 1822 | loss=total_loss, 1823 | start_logits=start_logits, 1824 | end_logits=end_logits, 1825 | hidden_states=outputs.hidden_states, 1826 | attentions=outputs.attentions, 1827 | ) 1828 | --------------------------------------------------------------------------------
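
Note on the repository-specific change visible above: unlike the upstream HuggingFace implementation, `BertModel.forward` here accepts a per-token `speaker_ids` tensor and passes it on to `self.embeddings(...)`. The speaker-embedding logic itself sits in `BertEmbeddings`, which is defined earlier in this file and not shown in this excerpt, so the sketch below is a minimal, hedged illustration only; the import path, config values, and the 0/1 speaker coding are assumptions, not values taken from the repository.

```python
# Minimal sketch (assumptions noted above) of calling the modified BertModel,
# whose forward() takes a per-token `speaker_ids` tensor alongside the usual inputs.
# The modified BertEmbeddings may require extra config fields not shown in this excerpt.
import torch
from transformers import BertConfig

from modeling_bert import BertModel  # assumed import path for the class defined above

config = BertConfig()          # stand-in config; a real run would load a pretrained one
model = BertModel(config)      # randomly initialised here; from_pretrained() loads weights
model.eval()

batch_size, seq_len = 2, 16
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
# One speaker id per token; e.g. 0 for one dialogue role and 1 for the other (illustrative).
speaker_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)

with torch.no_grad():
    outputs = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    speaker_ids=speaker_ids)

print(outputs.last_hidden_state.shape)   # torch.Size([2, 16, 768])
print(outputs.pooler_output.shape)       # torch.Size([2, 768])
```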
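
The pruning interface above works in two steps: `BertModel._prune_heads` takes a dict of `{layer_num: [head indices to prune in this layer]}`, and `BertAttention.prune_heads` then shrinks the query/key/value and output projections and updates `num_attention_heads` and `all_head_size`. A small sketch, with illustrative config values and the same assumed import path as above:

```python
# Sketch of the head-pruning interface exposed by BertModel._prune_heads above.
from transformers import BertConfig

from modeling_bert import BertModel  # assumed import path for the class defined above

config = BertConfig(num_hidden_layers=4, num_attention_heads=12, hidden_size=768)
model = BertModel(config)

# Drop heads 0 and 1 in layer 0, and head 11 in layer 2.
model._prune_heads({0: [0, 1], 2: [11]})

layer0_attn = model.encoder.layer[0].attention
print(layer0_attn.self.num_attention_heads)  # 10
print(layer0_attn.pruned_heads)              # {0, 1}
print(layer0_attn.self.all_head_size)        # 640, i.e. 10 heads * 64 dims per head
```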
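
As the `BertModel` class docstring notes, the encoder can also be configured as a decoder: `is_decoder=True` switches on causal self-attention with key/value caching, and `add_cross_attention=True` additionally inserts a cross-attention block fed by `encoder_hidden_states`. The sketch below shows that configuration; the shapes, random inputs, and import path are illustrative assumptions, not repository code.

```python
# Sketch of the decoder configuration described in BertModel's class docstring.
import torch
from transformers import BertConfig

from modeling_bert import BertModel  # assumed import path for the class defined above

config = BertConfig(is_decoder=True, add_cross_attention=True,
                    num_hidden_layers=2, num_attention_heads=12, hidden_size=768)
decoder = BertModel(config)
decoder.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
encoder_hidden_states = torch.randn(1, 20, config.hidden_size)  # stand-in encoder output

with torch.no_grad():
    outputs = decoder(input_ids=input_ids,
                      encoder_hidden_states=encoder_hidden_states,
                      use_cache=True)

print(len(outputs.past_key_values))          # 2: one cache tuple per layer
print(outputs.past_key_values[0][0].shape)   # (1, 12, 8, 64): cached self-attention keys
print(outputs.cross_attentions is None)      # True unless output_attentions=True is passed
```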