├── image ├── bl.png ├── fz.png ├── main.png ├── pj.png ├── sc.png ├── tl.png ├── zj.png ├── Triage.png ├── Dagnosis.png └── Summary.png ├── Dagnosis ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── modeling_cpt.cpython-37.pyc └── __init__.py ├── Summary ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── modeling_cpt.cpython-37.pyc ├── __init__.py └── modeling_cpt.py ├── Triage ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── modeling_cpt.cpython-37.pyc └── __init__.py ├── main.py ├── README_ZH.md ├── readme.md └── modeling_cpt.py /image/bl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/bl.png -------------------------------------------------------------------------------- /image/fz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/fz.png -------------------------------------------------------------------------------- /image/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/main.png -------------------------------------------------------------------------------- /image/pj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/pj.png -------------------------------------------------------------------------------- /image/sc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/sc.png -------------------------------------------------------------------------------- /image/tl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/tl.png -------------------------------------------------------------------------------- /image/zj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/zj.png -------------------------------------------------------------------------------- /image/Triage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/Triage.png -------------------------------------------------------------------------------- /image/Dagnosis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/Dagnosis.png -------------------------------------------------------------------------------- /image/Summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/Summary.png -------------------------------------------------------------------------------- /Dagnosis/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Dagnosis/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /Summary/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Summary/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /Triage/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Triage/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /Triage/__pycache__/modeling_cpt.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Triage/__pycache__/modeling_cpt.cpython-37.pyc -------------------------------------------------------------------------------- /Dagnosis/__pycache__/modeling_cpt.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Dagnosis/__pycache__/modeling_cpt.cpython-37.pyc -------------------------------------------------------------------------------- /Summary/__pycache__/modeling_cpt.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Summary/__pycache__/modeling_cpt.cpython-37.pyc -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForSequenceClassification,BertTokenizer 3 | from modeling_cpt import CPTForConditionalGeneration 4 | from Triage import * 5 | from Summary import * 6 | from Dagnosis import * 7 | 8 | 9 | 10 | if __name__ == '__main__': 11 | import argparse 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--type", default='Dagnosis',type=str) 14 | parser.add_argument("--mode", default='interactive', type=str) 15 | parser.add_argument("--file_name",default=None,type=str) 16 | parser.add_argument("--result_file_name", default='result.csv', type=str) 17 | parser.add_argument("--message", default=None, type=str) 18 | parser = parser.parse_args() 19 | 20 | exec('RSA = {}(parser)'.format(parser.type)) 21 | print(RSA) 22 | -------------------------------------------------------------------------------- /Triage/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForSequenceClassification,BertTokenizer 3 | from .modeling_cpt import CPTForConditionalGeneration 4 | from Triage import * 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | FENZHEN_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_BERT' 7 | CMDD_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT' 8 | BL_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT' 9 | fenzhen_model = AutoModelForSequenceClassification.from_pretrained(FENZHEN_MODEL_NAME, num_labels=6).to(device) # 模型 10 | fenzhen_tokenizer = BertTokenizer.from_pretrained(FENZHEN_MODEL_NAME) 11 | 12 | def Triage(parser): 13 | 14 | if parser.mode == 'interactive': 15 | print('我们将为您分配科室') 16 | while True: 17 | message = input('请输入您想询问的症状(退出请输入\033[0;35m退出\033[0m)') 18 | if message == '退出': 19 | return '' 20 | text = fenzhen_tokenizer(message, max_length=512, return_tensors='pt') 21 | input_ids = text['input_ids'].to(device) 22 | fenzhen_prodict = fenzhen_model(input_ids)[0] 23 | ks = {0:'男科',1:'内科',2:'妇科',3:'肿瘤科',4:'儿科',5:'外科'}[int(fenzhen_prodict.argmax())] 24 | print('您可能需要前往: \033[0;32m{}\033[0m'.format(ks)) 25 | elif parser.mode == 'batch': 26 
| with open(parser.file_name,'r',encoding='utf-8') as f: 27 | data = f.readlines() 28 | result = [] 29 | for message in data: 30 | text = fenzhen_tokenizer(message.replace('\n',''), max_length=512, return_tensors='pt') 31 | input_ids = text['input_ids'].to(device) 32 | fenzhen_prodict = fenzhen_model(input_ids)[0] 33 | ks = {0:'男科',1:'内科',2:'妇科',3:'肿瘤科',4:'儿科',5:'外科'}[int(fenzhen_prodict.argmax())] 34 | result.append(ks) 35 | with open(parser.result_file_name,'w',encoding='utf-8') as f: 36 | for i in result: 37 | f.write(i+'\n') 38 | 39 | else: 40 | assert parser.message != None,print('请传入文本') 41 | text = fenzhen_tokenizer(parser.message, max_length=512, return_tensors='pt') 42 | input_ids = text['input_ids'].to(device) 43 | fenzhen_prodict = fenzhen_model(input_ids)[0] 44 | ks = {0: '男科', 1: '内科', 2: '妇科', 3: '肿瘤科', 4: '儿科', 5: '外科'}[int(fenzhen_prodict.argmax())] 45 | return ks 46 | -------------------------------------------------------------------------------- /README_ZH.md: -------------------------------------------------------------------------------- 1 | # LingYi 技术文档 2 | 3 | 4 | 5 | ### 您可前往[灵医小智](http://kg.wengsyx.com)预览我们的系统。 6 | 7 | ### 8 | 9 | #### 如果可以通过准确有效的智能问诊系统,解决信任危机与医疗准确性问题,实时与患者进行沟通,收集患者信息,从而提升临床诊断的效率、减轻医生的负担、提高复诊积极性等,将可以有效缓解医生资源匮乏、医疗水平不平衡不充分、公众对互联网医疗信息存有疑虑等问题。 10 | 11 | #### 我们以中文医学领域为例,制作了面向个人的自动化诊疗系统,命名为“灵医小智”。 12 | 13 | #### “灵医小智”同时支持医院场景下智能预问诊台的搭建、药店场景下的咨询开药机器人搭建以及个人场景下的私人医生应用的搭建,深入医疗服务场景,凭借海量的医学知识和大规模预训练语言模型,服务于患者线上、线下医疗需求的全场景,提升医生服务效率,改善患者就医体验。 14 | 15 |
分诊 | 推理 | 生成 | 数据 | 病历

整体流程(流程图)
32 | 33 | #### 我们分别制作了八大模块,全方位以对话形式轻松、快捷、准确地为用户解决医疗问题。相比于市面上的“灵医智慧”、"左手医生"等智慧医疗产品,我们的个性化私人医生“灵医小智”项目致力于打造主动式友好医疗,覆盖“诊前”“诊中”“诊后”:基于悟道模型卓越的生成性能和图像识别能力,多模态感知患者病情,以灵活多变的拟人化语言与患者自由交流,极大改善用户体验,缓解医院的线下医疗服务压力,助力患者居家康复管理。 34 | 35 |
img
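下面给出一个串联调用三个开源模块(分诊 → 问诊 → 病历)的最小示意,调用方式与 main.py 的参数约定一致;示例文本为假设内容,需在项目根目录下运行,首次运行会从 HuggingFace 下载模型(若 Triage 目录下缺少 modeling_cpt.py,可先从仓库根目录复制一份)。

```python
# 示意:按 main.py 的参数约定,以 api 方式依次调用分诊、问诊与病历模块。
# 对话内容为假设示例,仅演示三个模块如何串联。
from argparse import Namespace

from Triage import Triage
from Dagnosis import Dagnosis
from Summary import Summary

symptom = '我肚子好疼'
department = Triage(Namespace(mode='api', message=symptom))    # 分诊:返回建议科室
reply = Dagnosis(Namespace(mode='api', message=symptom))       # 问诊:生成医生回复
record = Summary(Namespace(mode='api',                         # 病历:根据对话生成就诊报告
                           message='患者:' + symptom + ' 医生:' + reply))
print(department, reply, record, sep='\n')
```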
36 | 37 | 38 | 39 | 40 | 41 | ### 分诊 42 | 43 | ##### 根据实际场景及调研需求,引入分诊模块,缩小患病范围、提升问诊精度。 44 | 45 |
img
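分诊模块的核心是一个 6 分类的 BERT 序列分类模型(与 Triage/__init__.py 中的实现一致)。下面是其推理逻辑的最小示意,其中的患者主诉为假设示例:

```python
# 分诊推理示意:用 BERT 序列分类模型把患者主诉映射到六个科室之一。
import torch
from transformers import AutoModelForSequenceClassification, BertTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_BERT'
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=6).to(device)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

message = '我最近总是胃疼,饭后更明显'                      # 假设的患者主诉
inputs = tokenizer(message, max_length=512, return_tensors='pt')
logits = model(inputs['input_ids'].to(device))[0]
departments = {0: '男科', 1: '内科', 2: '妇科', 3: '肿瘤科', 4: '儿科', 5: '外科'}
print('建议科室:', departments[int(logits.argmax())])
```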
46 | 47 | 48 | 49 | ### 推理 50 | 51 | ##### 实体消歧技术有利于对用户话语进行归一化处理。我们使用了[ADBCMM技术](https://github.com/WENGSYX/ADBCMM),使模型能够准确判断用户所说实体的全称含义。 52 | 53 | ##### 对于消歧得到的实体,我们使用知识图谱推理技术,结合学界前沿的知识问答技术,通过实体链接、多跳推理、路径排序等方式,并引入流程点机制,更为可控地搜寻下一步的目标实体,例如可能伴随的其他症状、进一步确诊所需的检查项目,以及治疗该疾病所需的药物等。 54 | 55 |
img
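知识图谱(CMKG)与实体消歧模型并未包含在本仓库中,下面仅给出“实体链接 + 按流程点取下一跳”这一思路的极简示意,其中的图谱片段、字段名与匹配方式均为假设:

```python
# 极简示意:在一个假设的图谱片段上,按"流程点"(症状→检查→药物)做一跳推理。
# 真实系统使用 CMKG 与实体消歧(ADBCMM)等技术,这里仅演示思路。
toy_kg = {
    '胃炎': {'症状': ['胃疼', '反酸'], '检查': ['胃镜'], '药物': ['奥美拉唑']},
    '阑尾炎': {'症状': ['右下腹痛', '发热'], '检查': ['腹部B超'], '药物': ['抗生素']},
}

def link_entity(mention: str) -> str:
    """把用户话语粗略链接到图谱中的疾病实体(假设实现:按症状关键词匹配)。"""
    for disease, attrs in toy_kg.items():
        if any(symptom in mention for symptom in attrs['症状']):
            return disease
    return ''

def next_hop(disease: str, step: str) -> list:
    """按当前流程点(症状/检查/药物)取出下一步要询问或推荐的实体。"""
    return toy_kg.get(disease, {}).get(step, [])

disease = link_entity('我胃疼还有点反酸')
print(disease, next_hop(disease, '检查'), next_hop(disease, '药物'))
```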
56 | 57 | 58 | 59 | ### 生成 60 | 61 | ##### 这一步中将用到预训练生成模型。预训练生成模型在通用文本上预训练,专业领域的推断能力不足,因此我们使用了ENTITY-PROMPT-LEARNING方法,在训练过程中将下句实体与流程点一并作为输入进行训练。 62 | 63 | ##### 通过PROMPT的方式将对话上下文信息与实体信息进行融合,使得最后的回复具有预测的实体信息。 64 | 65 |
img
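Entity-Prompt 的具体拼接模板在本仓库中未给出,下面按这一思路给出最小示意:把图谱推理得到的实体与流程点以文本形式拼接到对话上下文之后,再交给 CPT 生成回复;其中的拼接模板、实体与对话内容均为假设。

```python
# Entity-Prompt 生成示意:把预测实体与流程点拼进输入文本,引导 CPT 生成包含该实体的回复。
import torch
from transformers import BertTokenizer
from modeling_cpt import CPTForConditionalGeneration  # 仓库根目录下的 modeling_cpt.py

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT'
model = CPTForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

context = '患者:我胃疼还有点反酸'
entities = ['胃镜', '奥美拉唑']                                 # 假设由知识图谱推理得到
prompt = context + '[流程点]检查[实体]' + ','.join(entities)     # 拼接模板为假设

inputs = tokenizer(prompt, truncation=True, max_length=512, return_tensors='pt').to(device)
output_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], max_length=256)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].replace(' ', ''))
```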
66 | 67 | 68 | 69 | 70 | 71 | ### 病历 72 | 73 | ##### 医生与病人完成自动化问诊之后,需要就诊疗过程撰写就诊报告,对病人的整体情况进行描述。我们基于悟道CPM模型,采用Causal Language Model等技术,在CCL数据集上平均得分领先最先进方法(SOTA)2.05分。 74 | 75 |
img
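病历撰写可直接复用 Summary 模块,下面是按其非交互(api)调用方式的最小示意;需在项目根目录下运行,对话内容为假设示例:

```python
# 病历生成示意:把一段医患对话交给 Summary 模块,输出就诊报告文本。
from argparse import Namespace
from Summary import Summary

dialogue = '患者:我这两天胃疼,吃完饭更疼。医生:有没有反酸、恶心?患者:有点反酸。'  # 假设的对话
record = Summary(Namespace(mode='api', message=dialogue))
print(record)
```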
76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | # 项目评价指标 84 | 85 |
img
86 | 87 | 88 | 89 | 90 | 91 | # 项目总结 92 | 93 | **易用性**:通过医院机器人语音进行预问诊,并支持图片问诊,具备友好性。 94 | 95 | **功能数**:丰富的系统功能,具备八大医疗模块。 96 | 97 | **合作性**:产学研合作,分工明确,准确把握相关需求。 98 | 99 | **商业性**:多模模块中图文推荐药品一键购买。 100 | 101 | **覆盖面**:三十五大重点科室全覆盖。 102 | 103 | **专业性**:使用Prompt算法融入知识图谱推理的实体,相比直接生成更具专业性。 104 | 105 | **领先性**:采用本年度自然语言处理竞赛SOTA方案,进一步提升准确性。 106 | 107 | **可控性**:Entity-Prompt与流程点推理,增加可控性。 108 | 109 | **公益性**:有效缓解医生资源匮乏、医疗水平不平衡不充分。 110 | 111 | 112 | 113 |
img
114 | 115 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## MedConQA: Medical Conversational Question Answering System based on Knowledge Graphs 2 | 3 | ###### 4 | 5 | ### How to use 6 | 7 | ##### We offer three experience types: 8 | 9 | *1. Triage* You can take your symptoms as input, and we will triage them to the appropriate department. 10 | 11 | *2. Dagnosis* You can ask any medical question. 12 | 13 | *3. Summary* Take a doctor-patient dialogue as input, and a medical record will be output. 14 | 15 | ##### We offer three input modes: 16 | 17 | *1.interactive* Supports an interactive human-computer demo 18 | 19 | *2.api* Provides an API interface for direct calls 20 | 21 | *3.batch* Processes an input file in batch and writes the results to a file 22 | 23 | 24 | 25 | ### Example 26 | 27 | ``` 28 | python main.py --mode api --type Dagnosis --message 我肚子好疼 29 | ``` 30 | 31 |
img
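The `api` mode can also be driven directly from Python instead of through the CLI. The snippet below is a minimal sketch, assuming it is run from the repository root (so the `Dagnosis` package and `modeling_cpt.py` are importable) and that the pretrained checkpoints can be downloaded; it simply builds the same argument object that `main.py` would parse.

```python
# Minimal sketch of the "api" mode without the CLI: build the arguments main.py would
# parse and call the Dagnosis module directly. The example message is illustrative.
from argparse import Namespace
from Dagnosis import Dagnosis

reply = Dagnosis(Namespace(mode='api', message='我肚子好疼'))
print(reply)
```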
32 | 33 | ``` 34 | python main.py --mode interactive --type Dagnosis 35 | ``` 36 | 37 |
img
38 | 39 | ``` 40 | python main.py --mode batch --type Summary --file_name input.csv --result_file_name result.csv 41 | ``` 42 | 43 |
img
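In `--mode batch`, the input file is read line by line (one message, or one full dialogue for `Summary`, per line), and one generated result is written per line to `--result_file_name`. A minimal sketch of preparing such an input file (the sample dialogues are made up):

```python
# Minimal sketch: write one doctor-patient dialogue per line; main.py --mode batch will
# then produce one summary per line in result.csv. The dialogues below are illustrative.
samples = [
    '患者:我这两天胃疼,吃完饭更明显。医生:有没有反酸?患者:有一点。',
    '患者:孩子发烧38度。医生:咳嗽吗?患者:有点咳嗽。',
]
with open('input.csv', 'w', encoding='utf-8') as f:
    f.write('\n'.join(samples) + '\n')
```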
44 | 45 | 46 | ## Cite 47 | ``` 48 | @inproceedings{xia-etal-2022-medconqa, 49 | title = "{M}ed{C}on{QA}: Medical Conversational Question Answering System based on Knowledge Graphs", 50 | author = "Xia, Fei and 51 | Li, Bin and 52 | Weng, Yixuan and 53 | He, Shizhu and 54 | Liu, Kang and 55 | Sun, Bin and 56 | Li, Shutao and 57 | Zhao, Jun", 58 | booktitle = "Proceedings of the The 2022 Conference on Empirical Methods in Natural Language Processing: System Demonstrations", 59 | month = dec, 60 | year = "2022", 61 | address = "Abu Dhabi, UAE", 62 | publisher = "Association for Computational Linguistics", 63 | url = "https://aclanthology.org/2022.emnlp-demos.15", 64 | pages = "148--158", 65 | abstract = "The medical conversational system can relieve doctors{'} burden and improve healthcare efficiency, especially during the COVID-19 pandemic. However, the existing medical dialogue systems have the problems of weak scalability, insufficient knowledge, and poor controllability. Thus, we propose a medical conversational question-answering (CQA) system based on the knowledge graph, namely MedConQA, which is designed as a pipeline framework to maintain high flexibility. Our system utilizes automated medical procedures, including medical triage, consultation, image-text drug recommendation, and record. Each module has been open-sourced as a tool, which can be used alone or in combination, with robust scalability. Besides, to conduct knowledge-grounded dialogues with users, we first construct a Chinese Medical Knowledge Graph (CMKG) and collect a large-scale Chinese Medical CQA (CMCQA) dataset, and we design a series of methods for reasoning more intellectually. Finally, we use several state-of-the-art (SOTA) techniques to keep the final generated response more controllable, which is further assured by hospital and professional evaluations. 
We have open-sourced related code, datasets, web pages, and tools, hoping to advance future research.", 66 | } 67 | ``` 68 | -------------------------------------------------------------------------------- /Dagnosis/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForSequenceClassification,BertTokenizer 3 | from modeling_cpt import CPTForConditionalGeneration 4 | from Triage import * 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | FENZHEN_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_BERT' 7 | CMDD_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT' 8 | BL_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT' 9 | 10 | cmdd_model = CPTForConditionalGeneration.from_pretrained(CMDD_MODEL_NAME) 11 | cmdd_model = cmdd_model.to(device) 12 | berttokenizer = BertTokenizer.from_pretrained(CMDD_MODEL_NAME) 13 | 14 | 15 | def Dagnosis(parser): 16 | 17 | if parser.mode == 'interactive': 18 | while True: 19 | message = input('医生:\033[0;34m您好,有什么我能帮您?\033[0m(退出请输入\033[0;35m退出\033[0m)') 20 | if message == '退出': 21 | return '' 22 | text = berttokenizer('患者:'+message,padding='max_length', truncation=True, max_length=512,return_tensors='pt') 23 | input_ids = text['input_ids'].to(device) 24 | attention_mask = text['attention_mask'].to(device) 25 | token_type_ids = text['token_type_ids'].to(device) 26 | 27 | out = berttokenizer.batch_decode( 28 | cmdd_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0] 29 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:', 30 | '') 31 | print('医生: \033[0;34m{}\033[0m'.format(out)) 32 | elif parser.mode == 'batch': 33 | with open(parser.file_name,'r',encoding='utf-8') as f: 34 | data = f.readlines() 35 | result = [] 36 | for message in data: 37 | text = berttokenizer('患者:'+message,padding='max_length', truncation=True, max_length=512,return_tensors='pt') 38 | input_ids = text['input_ids'].to(device) 39 | attention_mask = text['attention_mask'].to(device) 40 | token_type_ids = text['token_type_ids'].to(device) 41 | 42 | out = berttokenizer.batch_decode( 43 | cmdd_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0] 44 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:', 45 | '') 46 | result.append(out) 47 | with open(parser.result_file_name,'w',encoding='utf-8') as f: 48 | for i in result: 49 | f.write(i+'\n') 50 | 51 | else: 52 | assert parser.message != None,print('请传入文本') 53 | text = berttokenizer('患者:' + parser.message, padding='max_length', truncation=True, max_length=512, return_tensors='pt') 54 | input_ids = text['input_ids'].to(device) 55 | attention_mask = text['attention_mask'].to(device) 56 | token_type_ids = text['token_type_ids'].to(device) 57 | 58 | out = berttokenizer.batch_decode( 59 | cmdd_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0] 60 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:', 61 | '') 62 | return out 63 | -------------------------------------------------------------------------------- /Summary/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForSequenceClassification,BertTokenizer 3 | from .modeling_cpt import 
CPTForConditionalGeneration 4 | from Triage import * 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | FENZHEN_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_BERT' 7 | CMDD_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT' 8 | BL_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT' 9 | fenzhen_model = AutoModelForSequenceClassification.from_pretrained(FENZHEN_MODEL_NAME, num_labels=6).to(device) # 模型 10 | fenzhen_tokenizer = BertTokenizer.from_pretrained(FENZHEN_MODEL_NAME) 11 | 12 | bl_model = CPTForConditionalGeneration.from_pretrained(BL_MODEL_NAME) 13 | bl_model = bl_model.to(device) 14 | berttokenizer = BertTokenizer.from_pretrained(CMDD_MODEL_NAME) 15 | 16 | def Summary(parser): 17 | 18 | if parser.mode == 'interactive': 19 | while True: 20 | message = input('请输入对话历史,让我来帮您记录病例信息:(退出请输入\033[0;35m退出\033[0m)') 21 | if message == '退出': 22 | return '' 23 | text = berttokenizer(message,padding='max_length', truncation=True, max_length=512,return_tensors='pt') 24 | input_ids = text['input_ids'].to(device) 25 | attention_mask = text['attention_mask'].to(device) 26 | token_type_ids = text['token_type_ids'].to(device) 27 | 28 | out = berttokenizer.batch_decode( 29 | bl_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0] 30 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:', 31 | '') 32 | print('医生: \033[0;34m{}\033[0m'.format(out)) 33 | elif parser.mode == 'batch': 34 | with open(parser.file_name,'r',encoding='utf-8') as f: 35 | data = f.readlines() 36 | result = [] 37 | for message in data: 38 | text = berttokenizer(message,padding='max_length', truncation=True, max_length=512,return_tensors='pt') 39 | input_ids = text['input_ids'].to(device) 40 | attention_mask = text['attention_mask'].to(device) 41 | token_type_ids = text['token_type_ids'].to(device) 42 | 43 | out = berttokenizer.batch_decode( 44 | bl_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0] 45 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:', 46 | '') 47 | result.append(out) 48 | with open(parser.result_file_name,'w',encoding='utf-8') as f: 49 | for i in result: 50 | f.write(i+'\n') 51 | 52 | else: 53 | assert parser.message != None,print('请传入文本') 54 | text = berttokenizer(parser.message, padding='max_length', truncation=True, max_length=512, return_tensors='pt') 55 | input_ids = text['input_ids'].to(device) 56 | attention_mask = text['attention_mask'].to(device) 57 | token_type_ids = text['token_type_ids'].to(device) 58 | 59 | out = berttokenizer.batch_decode( 60 | bl_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0] 61 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:','') 62 | return out 63 | -------------------------------------------------------------------------------- /modeling_cpt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch CPT model. modified from transformers==4.4.1""" 16 | import copy 17 | import math 18 | import random 19 | import warnings 20 | from typing import Optional, Tuple 21 | 22 | import torch 23 | import torch.nn.functional as F 24 | from torch.nn.modules.loss import KLDivLoss, NLLLoss 25 | import torch.utils.checkpoint 26 | from torch import nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | from transformers.activations import ACT2FN 30 | from transformers.file_utils import ( 31 | add_code_sample_docstrings, 32 | add_end_docstrings, 33 | add_start_docstrings, 34 | add_start_docstrings_to_model_forward, 35 | replace_return_docstrings, 36 | ) 37 | from transformers.modeling_outputs import ( 38 | BaseModelOutput, 39 | BaseModelOutputWithPastAndCrossAttentions, 40 | CausalLMOutputWithCrossAttentions, 41 | Seq2SeqLMOutput, 42 | Seq2SeqModelOutput, 43 | Seq2SeqQuestionAnsweringModelOutput, 44 | Seq2SeqSequenceClassifierOutput, 45 | ) 46 | from transformers.modeling_utils import PreTrainedModel 47 | from transformers.utils import logging 48 | from transformers import BartConfig as CPTConfig 49 | from transformers import BertModel, BertConfig 50 | 51 | from torch.nn import LayerNorm 52 | 53 | logger = logging.get_logger(__name__) 54 | 55 | _CHECKPOINT_FOR_DOC = "fnlp/cpt-large" 56 | _CONFIG_FOR_DOC = "CPTConfig" 57 | _TOKENIZER_FOR_DOC = "CPTTokenizer" 58 | 59 | 60 | CPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ 61 | "fnlp/cpt-large", 62 | ] 63 | 64 | 65 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 66 | """ 67 | Shift input ids one token to the right. 68 | """ 69 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 70 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 71 | shifted_input_ids[:, 0] = decoder_start_token_id 72 | 73 | assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 74 | # replace possible -100 values in labels by `pad_token_id` 75 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 76 | 77 | return shifted_input_ids 78 | 79 | 80 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): 81 | """ 82 | Make causal mask used for bi-directional self-attention. 83 | """ 84 | bsz, tgt_len = input_ids_shape 85 | mask = torch.full((tgt_len, tgt_len), float("-inf")) 86 | mask_cond = torch.arange(mask.size(-1)) 87 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 88 | mask = mask.to(dtype) 89 | 90 | if past_key_values_length > 0: 91 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) 92 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 93 | 94 | 95 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 96 | """ 97 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
98 | """ 99 | bsz, src_len = mask.size() 100 | tgt_len = tgt_len if tgt_len is not None else src_len 101 | 102 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 103 | 104 | inverted_mask = 1.0 - expanded_mask 105 | 106 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) 107 | 108 | def attention_mask_func(attention_scores, attention_mask): 109 | return attention_scores + attention_mask 110 | 111 | def init_method(std): 112 | def init_(tensor): 113 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 114 | 115 | return init_ 116 | 117 | class CPTLearnedPositionalEmbedding(nn.Embedding): 118 | """ 119 | This module learns positional embeddings up to a fixed maximum size. 120 | """ 121 | 122 | def __init__(self, num_embeddings: int, embedding_dim: int): 123 | # CPT is set up so that if padding_idx is specified then offset the embedding ids by 2 124 | # and adjust num_embeddings appropriately. Other models dont have this hack 125 | self.offset = 2 126 | super().__init__(num_embeddings + self.offset, embedding_dim) 127 | 128 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): 129 | """`input_ids_shape` is expected to be [bsz x seqlen].""" 130 | bsz, seq_len = input_ids_shape[:2] 131 | positions = torch.arange( 132 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device 133 | ) 134 | return super().forward(positions + self.offset) 135 | 136 | 137 | class CPTAttention(nn.Module): 138 | """Multi-headed attention from 'Attention Is All You Need' paper""" 139 | 140 | def __init__( 141 | self, 142 | embed_dim: int, 143 | num_heads: int, 144 | dropout: float = 0.0, 145 | is_decoder: bool = False, 146 | bias: bool = True, 147 | ): 148 | super().__init__() 149 | self.embed_dim = embed_dim 150 | self.num_heads = num_heads 151 | self.dropout = dropout 152 | self.head_dim = embed_dim // num_heads 153 | assert ( 154 | self.head_dim * num_heads == self.embed_dim 155 | ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." 
156 | self.scaling = self.head_dim ** -0.5 157 | self.is_decoder = is_decoder 158 | 159 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 160 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 161 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 162 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 163 | 164 | 165 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 166 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 167 | 168 | def forward( 169 | self, 170 | hidden_states: torch.Tensor, 171 | key_value_states: Optional[torch.Tensor] = None, 172 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 173 | attention_mask: Optional[torch.Tensor] = None, 174 | layer_head_mask: Optional[torch.Tensor] = None, 175 | output_attentions: bool = False, 176 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 177 | """Input shape: Batch x Time x Channel""" 178 | 179 | # if key_value_states are provided this layer is used as a cross-attention layer 180 | # for the decoder 181 | is_cross_attention = key_value_states is not None 182 | bsz, tgt_len, embed_dim = hidden_states.size() 183 | 184 | # get query proj 185 | query_states = self.q_proj(hidden_states) * self.scaling 186 | # get key, value proj 187 | if is_cross_attention and past_key_value is not None: 188 | # reuse k,v, cross_attentions 189 | key_states = past_key_value[0] 190 | value_states = past_key_value[1] 191 | elif is_cross_attention: 192 | # cross_attentions 193 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 194 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 195 | elif past_key_value is not None: 196 | # reuse k, v, self_attention 197 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 198 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 199 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 200 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 201 | else: 202 | # self_attention 203 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 204 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 205 | 206 | if self.is_decoder: 207 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 208 | # Further calls to cross_attention layer can then reuse all cross-attention 209 | # key/value_states (first "if" case) 210 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 211 | # all previous decoder key/value_states. 
Further calls to uni-directional self-attention 212 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 213 | # if encoder bi-directional self-attention `past_key_value` is always `None` 214 | past_key_value = (key_states, value_states) 215 | 216 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 217 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 218 | key_states = key_states.view(*proj_shape) 219 | value_states = value_states.view(*proj_shape) 220 | 221 | src_len = key_states.size(1) 222 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 223 | 224 | assert attn_weights.size() == ( 225 | bsz * self.num_heads, 226 | tgt_len, 227 | src_len, 228 | ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 229 | 230 | if attention_mask is not None: 231 | assert attention_mask.size() == ( 232 | bsz, 233 | 1, 234 | tgt_len, 235 | src_len, 236 | ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 237 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 238 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 239 | 240 | attn_weights = F.softmax(attn_weights, dim=-1) 241 | 242 | if layer_head_mask is not None: 243 | assert layer_head_mask.size() == ( 244 | self.num_heads, 245 | ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" 246 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 247 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 248 | 249 | if output_attentions: 250 | # this operation is a bit akward, but it's required to 251 | # make sure that attn_weights keeps its gradient. 
252 | # In order to do so, attn_weights have to reshaped 253 | # twice and have to be reused in the following 254 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 255 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 256 | else: 257 | attn_weights_reshaped = None 258 | 259 | # with mpu.get_cuda_rng_tracker().fork(): 260 | attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) 261 | 262 | attn_output = torch.bmm(attn_probs, value_states) 263 | 264 | assert attn_output.size() == ( 265 | bsz * self.num_heads, 266 | tgt_len, 267 | self.head_dim, 268 | ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 269 | 270 | attn_output = ( 271 | attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 272 | .transpose(1, 2) 273 | .reshape(bsz, tgt_len, embed_dim) 274 | ) 275 | 276 | attn_output = self.out_proj(attn_output) 277 | 278 | return attn_output, attn_weights_reshaped, past_key_value 279 | 280 | class CPTDecoderLayer(nn.Module): 281 | def __init__(self, config: CPTConfig): 282 | super().__init__() 283 | self.embed_dim = config.d_model 284 | 285 | self.self_attn = CPTAttention( 286 | embed_dim=self.embed_dim, 287 | num_heads=config.decoder_attention_heads, 288 | dropout=config.attention_dropout, 289 | is_decoder=True, 290 | ) 291 | self.dropout = config.dropout 292 | self.activation_fn = ACT2FN[config.activation_function] 293 | self.activation_dropout = config.activation_dropout 294 | 295 | self.self_attn_layer_norm = LayerNorm(self.embed_dim) 296 | self.encoder_attn = CPTAttention( 297 | self.embed_dim, 298 | config.decoder_attention_heads, 299 | dropout=config.attention_dropout, 300 | is_decoder=True, 301 | ) 302 | self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) 303 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 304 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 305 | self.final_layer_norm = LayerNorm(self.embed_dim) 306 | 307 | def forward( 308 | self, 309 | hidden_states: torch.Tensor, 310 | attention_mask: Optional[torch.Tensor] = None, 311 | encoder_hidden_states: Optional[torch.Tensor] = None, 312 | encoder_attention_mask: Optional[torch.Tensor] = None, 313 | layer_head_mask: Optional[torch.Tensor] = None, 314 | encoder_layer_head_mask: Optional[torch.Tensor] = None, 315 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 316 | output_attentions: Optional[bool] = False, 317 | use_cache: Optional[bool] = True, 318 | ): 319 | """ 320 | Args: 321 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` 322 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size 323 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 324 | encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` 325 | encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size 326 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 327 | layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size 328 | `(config.encoder_attention_heads,)`. 329 | encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of 330 | size `(config.encoder_attention_heads,)`. 
331 | past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states 332 | output_attentions (:obj:`bool`, `optional`): 333 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 334 | returned tensors for more detail. 335 | """ 336 | residual = hidden_states 337 | 338 | # Self Attention 339 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 340 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 341 | # add present self-attn cache to positions 1,2 of present_key_value tuple 342 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 343 | hidden_states=hidden_states, 344 | past_key_value=self_attn_past_key_value, 345 | attention_mask=attention_mask, 346 | layer_head_mask=layer_head_mask, 347 | output_attentions=output_attentions, 348 | ) 349 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 350 | hidden_states = residual + hidden_states 351 | hidden_states = self.self_attn_layer_norm(hidden_states) 352 | 353 | # Cross-Attention Block 354 | cross_attn_present_key_value = None 355 | cross_attn_weights = None 356 | if encoder_hidden_states is not None: 357 | residual = hidden_states 358 | 359 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple 360 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 361 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( 362 | hidden_states=hidden_states, 363 | key_value_states=encoder_hidden_states, 364 | attention_mask=encoder_attention_mask, 365 | layer_head_mask=encoder_layer_head_mask, 366 | past_key_value=cross_attn_past_key_value, 367 | output_attentions=output_attentions, 368 | ) 369 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 370 | hidden_states = residual + hidden_states 371 | hidden_states = self.encoder_attn_layer_norm(hidden_states) 372 | 373 | # add cross-attn to positions 3,4 of present_key_value tuple 374 | present_key_value = present_key_value + cross_attn_present_key_value 375 | 376 | # Fully Connected 377 | residual = hidden_states 378 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 379 | hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) 380 | hidden_states = self.fc2(hidden_states) 381 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 382 | hidden_states = residual + hidden_states 383 | hidden_states = self.final_layer_norm(hidden_states) 384 | 385 | outputs = (hidden_states,) 386 | 387 | if output_attentions: 388 | outputs += (self_attn_weights, cross_attn_weights) 389 | 390 | if use_cache: 391 | outputs += (present_key_value,) 392 | 393 | return outputs 394 | 395 | 396 | class CPTClassificationHead(nn.Module): 397 | """Head for sentence-level classification tasks.""" 398 | 399 | def __init__( 400 | self, 401 | input_dim: int, 402 | inner_dim: int, 403 | num_classes: int, 404 | pooler_dropout: float, 405 | ): 406 | super().__init__() 407 | self.dense = nn.Linear(input_dim, inner_dim) 408 | self.dropout = nn.Dropout(p=pooler_dropout) 409 | self.out_proj = nn.Linear(inner_dim, num_classes) 410 | 411 | def forward(self, hidden_states: torch.Tensor): 412 | hidden_states = self.dropout(hidden_states) 413 | hidden_states = self.dense(hidden_states) 414 | hidden_states = torch.tanh(hidden_states) 415 | hidden_states = 
self.dropout(hidden_states) 416 | hidden_states = self.out_proj(hidden_states) 417 | return hidden_states 418 | 419 | 420 | class CPTPretrainedModel(PreTrainedModel): 421 | config_class = CPTConfig 422 | base_model_prefix = "model" 423 | 424 | def _init_weights(self, module): 425 | std = self.config.init_std 426 | if isinstance(module, nn.Linear): 427 | module.weight.data.normal_(mean=0.0, std=std) 428 | if module.bias is not None: 429 | module.bias.data.zero_() 430 | elif isinstance(module, nn.Embedding): 431 | module.weight.data.normal_(mean=0.0, std=std) 432 | if module.padding_idx is not None: 433 | module.weight.data[module.padding_idx].zero_() 434 | 435 | @property 436 | def dummy_inputs(self): 437 | pad_token = self.config.pad_token_id 438 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 439 | dummy_inputs = { 440 | "attention_mask": input_ids.ne(pad_token), 441 | "input_ids": input_ids, 442 | } 443 | return dummy_inputs 444 | 445 | CPT_START_DOCSTRING = r""" 446 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 447 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 448 | pruning heads etc.) 449 | 450 | This model is also a PyTorch `torch.nn.Module `__ 451 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 452 | general usage and behavior. 453 | 454 | Parameters: 455 | config (:class:`~transformers.CPTConfig`): 456 | Model configuration class with all the parameters of the model. Initializing with a config file does not 457 | load the weights associated with the model, only the configuration. Check out the 458 | :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. 459 | """ 460 | 461 | CPT_INPUTS_DOCSTRING = r""" 462 | Args: 463 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 464 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 465 | it. 466 | 467 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See 468 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 469 | details. 470 | 471 | `What are input IDs? <../glossary.html#input-ids>`__ 472 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 473 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 474 | 475 | - 1 for tokens that are **not masked**, 476 | - 0 for tokens that are **masked**. 477 | 478 | `What are attention masks? <../glossary.html#attention-mask>`__ 479 | decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): 480 | Indices of decoder input sequence tokens in the vocabulary. 481 | 482 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See 483 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 484 | details. 485 | 486 | `What are input IDs? <../glossary.html#input-ids>`__ 487 | 488 | CPT uses the :obj:`eos_token_id` as the starting token for :obj:`decoder_input_ids` generation. If 489 | :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see 490 | :obj:`past_key_values`). 
491 | 492 | For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no 493 | :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to 494 | the right for denoising pre-training following the paper. 495 | decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): 496 | Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will 497 | also be used by default. 498 | 499 | If you want to change padding behavior, you should read :func:`modeling_cpt._prepare_decoder_inputs` and 500 | modify to your needs. See diagram 1 in `the paper `__ for more 501 | information on the default strategy. 502 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 503 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: 504 | 505 | - 1 indicates the head is **not masked**, 506 | - 0 indicates the heas is **masked**. 507 | 508 | decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 509 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: 510 | 511 | - 1 indicates the head is **not masked**, 512 | - 0 indicates the head is **masked**. 513 | 514 | encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): 515 | Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: 516 | :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, 517 | `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the 518 | cross-attention of the decoder. 519 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 520 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. 521 | 522 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 523 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 524 | instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. 525 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 526 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 527 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 528 | vectors than the model's internal embedding lookup matrix. 529 | decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): 530 | Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded 531 | representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` 532 | have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert 533 | :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 
534 | 535 | If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` 536 | takes the value of :obj:`inputs_embeds`. 537 | use_cache (:obj:`bool`, `optional`): 538 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 539 | decoding (see :obj:`past_key_values`). 540 | output_attentions (:obj:`bool`, `optional`): 541 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 542 | tensors for more detail. 543 | output_hidden_states (:obj:`bool`, `optional`): 544 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 545 | more detail. 546 | return_dict (:obj:`bool`, `optional`): 547 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 548 | """ 549 | 550 | class CPTDecoder(CPTPretrainedModel): 551 | """ 552 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`CPTDecoderLayer` 553 | 554 | Args: 555 | config: CPTConfig 556 | embed_tokens (torch.nn.Embedding): output embedding 557 | """ 558 | 559 | def __init__(self, config: CPTConfig, embed_tokens: Optional[nn.Embedding] = None): 560 | super().__init__(config) 561 | self.dropout = config.dropout 562 | self.layerdrop = config.decoder_layerdrop 563 | self.padding_idx = config.pad_token_id 564 | self.max_target_positions = config.max_position_embeddings 565 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 566 | 567 | if embed_tokens is not None: 568 | self.embed_tokens = embed_tokens 569 | else: 570 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) 571 | 572 | self.embed_positions = CPTLearnedPositionalEmbedding( 573 | config.max_position_embeddings, 574 | config.d_model, 575 | ) 576 | self.layers = nn.ModuleList([CPTDecoderLayer(config) for _ in range(config.decoder_layers)]) 577 | self.layernorm_embedding = LayerNorm(config.d_model) 578 | 579 | self.init_weights() 580 | 581 | def get_input_embeddings(self): 582 | return self.embed_tokens 583 | 584 | def set_input_embeddings(self, value): 585 | self.embed_tokens = value 586 | 587 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 588 | # create causal mask 589 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 590 | combined_attention_mask = None 591 | if input_shape[-1] > 1: 592 | combined_attention_mask = _make_causal_mask( 593 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length 594 | ).to(self.device) 595 | 596 | if attention_mask is not None: 597 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 598 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 599 | combined_attention_mask = ( 600 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 601 | ) 602 | 603 | return combined_attention_mask 604 | 605 | def forward( 606 | self, 607 | input_ids=None, 608 | attention_mask=None, 609 | encoder_hidden_states=None, 610 | encoder_attention_mask=None, 611 | head_mask=None, 612 | encoder_head_mask=None, 613 | past_key_values=None, 614 | inputs_embeds=None, 615 | use_cache=None, 616 | output_attentions=None, 617 | output_hidden_states=None, 618 | return_dict=None, 619 | ): 620 | r""" 621 | Args: 622 | input_ids (:obj:`torch.LongTensor` of shape 
:obj:`(batch_size, sequence_length)`): 623 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 624 | provide it. 625 | 626 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See 627 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` 628 | for details. 629 | 630 | `What are input IDs? <../glossary.html#input-ids>`__ 631 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 632 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 633 | 634 | - 1 for tokens that are **not masked**, 635 | - 0 for tokens that are **masked**. 636 | 637 | `What are attention masks? <../glossary.html#attention-mask>`__ 638 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): 639 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 640 | of the decoder. 641 | encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): 642 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 643 | selected in ``[0, 1]``: 644 | 645 | - 1 for tokens that are **not masked**, 646 | - 0 for tokens that are **masked**. 647 | 648 | `What are attention masks? <../glossary.html#attention-mask>`__ 649 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 650 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: 651 | 652 | - 1 indicates the head is **not masked**, 653 | - 0 indicates the heas is **masked**. 654 | 655 | encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 656 | Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention 657 | on hidden heads. Mask values selected in ``[0, 1]``: 658 | 659 | - 1 indicates the head is **not masked**, 660 | - 0 indicates the heas is **masked**. 661 | 662 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 663 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up 664 | decoding. 665 | 666 | If :obj:`past_key_values` are used, the user can optionally input only the last 667 | :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of 668 | shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, 669 | sequence_length)`. 670 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 671 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded 672 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices 673 | into associated vectors than the model's internal embedding lookup matrix. 674 | output_attentions (:obj:`bool`, `optional`): 675 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 676 | returned tensors for more detail. 
677 | output_hidden_states (:obj:`bool`, `optional`): 678 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors 679 | for more detail. 680 | return_dict (:obj:`bool`, `optional`): 681 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 682 | """ 683 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 684 | output_hidden_states = ( 685 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 686 | ) 687 | use_cache = use_cache if use_cache is not None else self.config.use_cache 688 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 689 | 690 | # retrieve input_ids and inputs_embeds 691 | if input_ids is not None and inputs_embeds is not None: 692 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 693 | elif input_ids is not None: 694 | input_shape = input_ids.size() 695 | input_ids = input_ids.view(-1, input_shape[-1]) 696 | elif inputs_embeds is not None: 697 | input_shape = inputs_embeds.size()[:-1] 698 | else: 699 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 700 | 701 | # past_key_values_length 702 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 703 | 704 | if inputs_embeds is None: 705 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 706 | 707 | attention_mask = self._prepare_decoder_attention_mask( 708 | attention_mask, input_shape, inputs_embeds, past_key_values_length 709 | ) 710 | 711 | # expand encoder attention mask 712 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 713 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 714 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 715 | 716 | # embed positions 717 | positions = self.embed_positions(input_shape, past_key_values_length) 718 | 719 | hidden_states = inputs_embeds + positions 720 | hidden_states = self.layernorm_embedding(hidden_states) 721 | 722 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 723 | 724 | # decoder layers 725 | all_hidden_states = () if output_hidden_states else None 726 | all_self_attns = () if output_attentions else None 727 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 728 | next_decoder_cache = () if use_cache else None 729 | 730 | # check if head_mask has a correct number of layers specified if desired 731 | if head_mask is not None: 732 | assert head_mask.size()[0] == ( 733 | len(self.layers) 734 | ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
735 | for idx, decoder_layer in enumerate(self.layers): 736 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 737 | if output_hidden_states: 738 | all_hidden_states += (hidden_states,) 739 | dropout_probability = random.uniform(0, 1) 740 | if self.training and (dropout_probability < self.layerdrop): 741 | continue 742 | 743 | past_key_value = past_key_values[idx] if past_key_values is not None else None 744 | 745 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 746 | 747 | if use_cache: 748 | logger.warn( 749 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " 750 | "`use_cache=False`..." 751 | ) 752 | use_cache = False 753 | 754 | def create_custom_forward(module): 755 | def custom_forward(*inputs): 756 | # None for past_key_value 757 | return module(*inputs, output_attentions, use_cache) 758 | 759 | return custom_forward 760 | 761 | # layer_outputs = mpu.checkpoint( 762 | layer_outputs = torch.utils.checkpoint( 763 | create_custom_forward(decoder_layer), 764 | hidden_states, 765 | attention_mask, 766 | encoder_hidden_states, 767 | encoder_attention_mask, 768 | head_mask[idx] if head_mask is not None else None, 769 | encoder_head_mask[idx] if encoder_head_mask is not None else None, 770 | None, 771 | ) 772 | else: 773 | 774 | layer_outputs = decoder_layer( 775 | hidden_states, 776 | attention_mask=attention_mask, 777 | encoder_hidden_states=encoder_hidden_states, 778 | encoder_attention_mask=encoder_attention_mask, 779 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 780 | encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), 781 | past_key_value=past_key_value, 782 | output_attentions=output_attentions, 783 | use_cache=use_cache, 784 | ) 785 | hidden_states = layer_outputs[0] 786 | 787 | if use_cache: 788 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 789 | 790 | if output_attentions: 791 | all_self_attns += (layer_outputs[1],) 792 | 793 | if encoder_hidden_states is not None: 794 | all_cross_attentions += (layer_outputs[2],) 795 | 796 | # add hidden states from the last decoder layer 797 | if output_hidden_states: 798 | all_hidden_states += (hidden_states,) 799 | 800 | next_cache = next_decoder_cache if use_cache else None 801 | if not return_dict: 802 | return tuple( 803 | v 804 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] 805 | if v is not None 806 | ) 807 | return BaseModelOutputWithPastAndCrossAttentions( 808 | last_hidden_state=hidden_states, 809 | past_key_values=next_cache, 810 | hidden_states=all_hidden_states, 811 | attentions=all_self_attns, 812 | cross_attentions=all_cross_attentions, 813 | ) 814 | 815 | 816 | @add_start_docstrings( 817 | "The bare CPT Model outputting raw hidden-states without any specific head on top.", 818 | CPT_START_DOCSTRING, 819 | ) 820 | class CPTModel(CPTPretrainedModel): 821 | def __init__(self, config: CPTConfig): 822 | super().__init__(config) 823 | encoder_config = BertConfig( 824 | vocab_size=config.vocab_size, 825 | hidden_size=config.d_model, 826 | num_hidden_layers=config.encoder_layers, 827 | num_attention_heads=config.encoder_attention_heads, 828 | intermediate_size=config.encoder_ffn_dim, 829 | hidden_dropout_prob=config.activation_dropout, 830 | attention_probs_dropout_prob=config.attention_dropout, 831 | ) 832 | config.vocab_size = encoder_config.vocab_size 833 | self.encoder = BertModel(encoder_config, 
add_pooling_layer=False) 834 | self.shared = self.encoder.get_input_embeddings() 835 | self.decoder = CPTDecoder(config, self.shared) 836 | self.num_decoder_layers = config.decoder_layers 837 | self.init_weights() 838 | 839 | def get_input_embeddings(self): 840 | return self.shared 841 | 842 | def set_input_embeddings(self, value): 843 | self.shared = value 844 | self.encoder.embed_tokens = self.shared 845 | self.decoder.embed_tokens = self.shared 846 | 847 | def get_encoder(self): 848 | class _Encoder(torch.nn.Module): 849 | def __init__(self, encoder): 850 | super().__init__() 851 | self.encoder = encoder 852 | 853 | def forward(self, *args, **kwargs): 854 | kwargs['output_hidden_states'] = True 855 | return self.encoder(*args, **kwargs) 856 | return _Encoder(self.encoder) 857 | 858 | def get_decoder(self): 859 | return self.decoder 860 | 861 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING) 862 | @add_code_sample_docstrings( 863 | checkpoint=_CHECKPOINT_FOR_DOC, 864 | output_type=Seq2SeqModelOutput, 865 | ) 866 | def forward( 867 | self, 868 | input_ids=None, 869 | token_type_ids=None, 870 | attention_mask=None, 871 | decoder_input_ids=None, 872 | decoder_attention_mask=None, 873 | head_mask=None, 874 | decoder_head_mask=None, 875 | encoder_outputs=None, 876 | past_key_values=None, 877 | inputs_embeds=None, 878 | decoder_inputs_embeds=None, 879 | use_cache=None, 880 | output_attentions=None, 881 | output_hidden_states=None, 882 | return_dict=None, 883 | ): 884 | 885 | # different to other models, CPT automatically creates decoder_input_ids from 886 | # input_ids if no decoder_input_ids are provided 887 | if decoder_input_ids is None and decoder_inputs_embeds is None: 888 | decoder_input_ids = shift_tokens_right( 889 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id 890 | ) 891 | 892 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 893 | output_hidden_states = ( 894 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 895 | ) 896 | use_cache = use_cache if use_cache is not None else self.config.use_cache 897 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 898 | 899 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 900 | # mpu.reset_checkpointed_activations_memory_buffer() 901 | use_cache = False 902 | 903 | if encoder_outputs is None: 904 | encoder_outputs = self.encoder( 905 | input_ids=input_ids, 906 | attention_mask=attention_mask, 907 | token_type_ids=token_type_ids, 908 | head_mask=head_mask, 909 | inputs_embeds=inputs_embeds, 910 | output_attentions=output_attentions, 911 | output_hidden_states=True, 912 | return_dict=return_dict, 913 | ) 914 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 915 | elif return_dict and isinstance(encoder_outputs, (tuple, list)): 916 | encoder_outputs = BaseModelOutput( 917 | last_hidden_state=encoder_outputs[0], 918 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 919 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 920 | ) 921 | 922 | if isinstance(encoder_outputs, (torch.Tensor)): 923 | encoder_hidden_states = encoder_outputs 924 | else: 925 | encoder_hidden_states = encoder_outputs[1][-self.num_decoder_layers - 1] 926 | 927 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 928 | decoder_outputs = 
self.decoder( 929 | input_ids=decoder_input_ids, 930 | attention_mask=decoder_attention_mask, 931 | encoder_hidden_states=encoder_hidden_states, 932 | encoder_attention_mask=attention_mask, 933 | head_mask=decoder_head_mask, 934 | encoder_head_mask=head_mask, 935 | past_key_values=past_key_values, 936 | inputs_embeds=decoder_inputs_embeds, 937 | use_cache=use_cache, 938 | output_attentions=output_attentions, 939 | output_hidden_states=output_hidden_states, 940 | return_dict=return_dict, 941 | ) 942 | 943 | if not return_dict: 944 | return decoder_outputs + encoder_outputs 945 | 946 | return Seq2SeqModelOutput( 947 | last_hidden_state=decoder_outputs.last_hidden_state, 948 | past_key_values=decoder_outputs.past_key_values, 949 | decoder_hidden_states=decoder_outputs.hidden_states, 950 | decoder_attentions=decoder_outputs.attentions, 951 | cross_attentions=decoder_outputs.cross_attentions, 952 | encoder_last_hidden_state=encoder_outputs.last_hidden_state if isinstance(encoder_outputs, dict) else None, 953 | encoder_hidden_states=encoder_outputs.hidden_states if isinstance(encoder_outputs, dict) else None, 954 | encoder_attentions=encoder_outputs.attentions if isinstance(encoder_outputs, dict) else None, 955 | ) 956 | 957 | 958 | @add_start_docstrings( 959 | "The CPT Model with a language modeling head. Can be used for summarization.", CPT_START_DOCSTRING 960 | ) 961 | class CPTForConditionalGeneration(CPTPretrainedModel): 962 | base_model_prefix = "model" 963 | _keys_to_ignore_on_load_missing = [ 964 | r"final_logits_bias", 965 | r"encoder\.version", 966 | r"decoder\.version", 967 | r"lm_head\.weight", 968 | ] 969 | 970 | def __init__(self, config): 971 | super().__init__(config) 972 | self.model = CPTModel(config) 973 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 974 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 975 | 976 | self.init_weights() 977 | 978 | def get_encoder(self): 979 | return self.model.get_encoder() 980 | 981 | def get_decoder(self): 982 | return self.model.get_decoder() 983 | 984 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: 985 | new_embeddings = super().resize_token_embeddings(new_num_tokens) 986 | self._resize_final_logits_bias(new_num_tokens) 987 | return new_embeddings 988 | 989 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None: 990 | old_num_tokens = self.final_logits_bias.shape[-1] 991 | if new_num_tokens <= old_num_tokens: 992 | new_bias = self.final_logits_bias[:, :new_num_tokens] 993 | else: 994 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 995 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 996 | self.register_buffer("final_logits_bias", new_bias) 997 | 998 | def get_output_embeddings(self): 999 | return self.lm_head 1000 | 1001 | def set_output_embeddings(self, new_embeddings): 1002 | self.lm_head = new_embeddings 1003 | 1004 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING) 1005 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 1006 | def forward( 1007 | self, 1008 | input_ids=None, 1009 | attention_mask=None, 1010 | token_type_ids=None, 1011 | decoder_input_ids=None, 1012 | decoder_attention_mask=None, 1013 | head_mask=None, 1014 | decoder_head_mask=None, 1015 | encoder_outputs=None, 1016 | past_key_values=None, 1017 | inputs_embeds=None, 1018 | decoder_inputs_embeds=None, 1019 | labels=None, 1020 | 
use_cache=None, 1021 | output_attentions=None, 1022 | output_hidden_states=None, 1023 | return_dict=None, 1024 | ): 1025 | r""" 1026 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1027 | Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., 1028 | config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored 1029 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. 1030 | 1031 | Returns: 1032 | """ 1033 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1034 | 1035 | if labels is not None: 1036 | if decoder_input_ids is None: 1037 | decoder_input_ids = shift_tokens_right( 1038 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 1039 | ) 1040 | 1041 | outputs = self.model( 1042 | input_ids, 1043 | attention_mask=attention_mask, 1044 | decoder_input_ids=decoder_input_ids, 1045 | encoder_outputs=encoder_outputs, 1046 | token_type_ids=token_type_ids, 1047 | decoder_attention_mask=decoder_attention_mask, 1048 | head_mask=head_mask, 1049 | decoder_head_mask=decoder_head_mask, 1050 | past_key_values=past_key_values, 1051 | inputs_embeds=inputs_embeds, 1052 | decoder_inputs_embeds=decoder_inputs_embeds, 1053 | use_cache=use_cache, 1054 | output_attentions=output_attentions, 1055 | output_hidden_states=output_hidden_states, 1056 | return_dict=return_dict, 1057 | ) 1058 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 1059 | 1060 | masked_lm_loss = None 1061 | if labels is not None: 1062 | loss_fct = CrossEntropyLoss() 1063 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 1064 | 1065 | if not return_dict: 1066 | output = (lm_logits,) + outputs[1:] 1067 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1068 | 1069 | return Seq2SeqLMOutput( 1070 | loss=masked_lm_loss, 1071 | logits=lm_logits, 1072 | past_key_values=outputs.past_key_values, 1073 | decoder_hidden_states=outputs.decoder_hidden_states, 1074 | decoder_attentions=outputs.decoder_attentions, 1075 | cross_attentions=outputs.cross_attentions, 1076 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1077 | encoder_hidden_states=outputs.encoder_hidden_states, 1078 | encoder_attentions=outputs.encoder_attentions, 1079 | ) 1080 | 1081 | def prepare_inputs_for_generation( 1082 | self, 1083 | decoder_input_ids, 1084 | past=None, 1085 | attention_mask=None, 1086 | head_mask=None, 1087 | use_cache=None, 1088 | encoder_outputs=None, 1089 | **kwargs 1090 | ): 1091 | # cut decoder_input_ids if past is used 1092 | if past is not None: 1093 | decoder_input_ids = decoder_input_ids[:, -1:] 1094 | 1095 | return { 1096 | "input_ids": None, # encoder_outputs is defined. 
input_ids not needed 1097 | "encoder_outputs": encoder_outputs, 1098 | "past_key_values": past, 1099 | "decoder_input_ids": decoder_input_ids, 1100 | "attention_mask": attention_mask, 1101 | "head_mask": head_mask, 1102 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 1103 | } 1104 | 1105 | @staticmethod 1106 | def _expand_inputs_for_generation( 1107 | input_ids: torch.LongTensor, 1108 | expand_size: int = 1, 1109 | is_encoder_decoder: bool = False, 1110 | attention_mask: torch.LongTensor = None, 1111 | encoder_outputs = None, 1112 | **model_kwargs, 1113 | ): 1114 | expanded_return_idx = ( 1115 | torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) 1116 | ) 1117 | input_ids = input_ids.index_select(0, expanded_return_idx) 1118 | 1119 | if "token_type_ids" in model_kwargs: 1120 | token_type_ids = model_kwargs["token_type_ids"] 1121 | model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) 1122 | 1123 | if attention_mask is not None: 1124 | model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) 1125 | 1126 | if is_encoder_decoder: 1127 | assert encoder_outputs is not None 1128 | device = encoder_outputs.last_hidden_state.device 1129 | encoder_outputs["hidden_states"] = tuple(h.index_select(0, expanded_return_idx.to(device)) \ 1130 | for h in encoder_outputs["hidden_states"]) 1131 | model_kwargs["encoder_outputs"] = encoder_outputs 1132 | return input_ids, model_kwargs 1133 | 1134 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 1135 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 1136 | 1137 | @staticmethod 1138 | def _reorder_cache(past, beam_idx): 1139 | reordered_past = () 1140 | for layer_past in past: 1141 | # cached cross_attention states don't have to be reordered -> they are always the same 1142 | reordered_past += ( 1143 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], 1144 | ) 1145 | return reordered_past 1146 | 1147 | 1148 | @add_start_docstrings( 1149 | """ 1150 | CPT model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE 1151 | tasks. 
1152 | """, 1153 | CPT_START_DOCSTRING, 1154 | ) 1155 | class CPTForSequenceClassification(CPTPretrainedModel): 1156 | def __init__(self, config: CPTConfig, cls_mode=1, **kwargs): 1157 | super().__init__(config, **kwargs) 1158 | self.model = CPTModel(config) 1159 | cls_mode = getattr(config, 'cls_mode', cls_mode) 1160 | if cls_mode == 1: 1161 | logger.info('Encoder for classification.') 1162 | cls_dim = config.d_model 1163 | elif cls_mode == 2: 1164 | logger.info('Decoder for classification.') 1165 | cls_dim = config.d_model 1166 | elif cls_mode == 3: 1167 | logger.info('Both encoder & decoder for classification.') 1168 | cls_dim = config.d_model * 2 1169 | else: 1170 | raise NotImplementedError 1171 | 1172 | self.cls_head = CPTClassificationHead( 1173 | cls_dim, 1174 | cls_dim, 1175 | config.num_labels, 1176 | config.classifier_dropout, 1177 | ) 1178 | self.model._init_weights(self.cls_head.dense) 1179 | self.model._init_weights(self.cls_head.out_proj) 1180 | self.cls_mode = cls_mode 1181 | config.cls_mode = cls_mode 1182 | 1183 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING) 1184 | @add_code_sample_docstrings( 1185 | checkpoint=_CHECKPOINT_FOR_DOC, 1186 | output_type=Seq2SeqSequenceClassifierOutput, 1187 | config_class=_CONFIG_FOR_DOC, 1188 | ) 1189 | def forward( 1190 | self, 1191 | input_ids=None, 1192 | attention_mask=None, 1193 | decoder_input_ids=None, 1194 | decoder_attention_mask=None, 1195 | head_mask=None, 1196 | decoder_head_mask=None, 1197 | encoder_outputs=None, 1198 | inputs_embeds=None, 1199 | decoder_inputs_embeds=None, 1200 | labels=None, 1201 | use_cache=None, 1202 | output_attentions=None, 1203 | output_hidden_states=None, 1204 | return_dict=None, 1205 | ): 1206 | r""" 1207 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1208 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 1209 | config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1210 | """ 1211 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1212 | if labels is not None: 1213 | use_cache = False 1214 | 1215 | if input_ids is None and inputs_embeds is not None: 1216 | raise NotImplementedError( 1217 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}" 1218 | ) 1219 | 1220 | outputs = self.model( 1221 | input_ids, 1222 | attention_mask=attention_mask, 1223 | decoder_input_ids=decoder_input_ids, 1224 | decoder_attention_mask=decoder_attention_mask, 1225 | head_mask=head_mask, 1226 | decoder_head_mask=decoder_head_mask, 1227 | encoder_outputs=encoder_outputs, 1228 | inputs_embeds=inputs_embeds, 1229 | decoder_inputs_embeds=decoder_inputs_embeds, 1230 | use_cache=use_cache, 1231 | output_attentions=output_attentions, 1232 | output_hidden_states=output_hidden_states, 1233 | return_dict=True, 1234 | ) 1235 | 1236 | hidden_states = outputs.last_hidden_state 1237 | enc_hidden_states = outputs.encoder_last_hidden_state 1238 | enc_rep = enc_hidden_states[:, 0] 1239 | 1240 | eos_mask = input_ids.eq(self.config.eos_token_id) 1241 | 1242 | if len(torch.unique(eos_mask.sum(1))) > 1: 1243 | raise ValueError("All examples must have the same number of tokens.") 1244 | dec_rep = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ 1245 | :, -1, : 1246 | ] 1247 | 1248 | if self.cls_mode == 1: 1249 | logits = self.cls_head(enc_rep) 1250 | elif self.cls_mode == 2: 1251 | logits = self.cls_head(dec_rep) 1252 | elif self.cls_mode == 3: 1253 | rep = torch.cat([enc_rep, dec_rep], dim=-1) 1254 | logits = self.cls_head(rep) 1255 | else: 1256 | raise NotImplementedError 1257 | 1258 | loss = None 1259 | if labels is not None: 1260 | loss_fct = CrossEntropyLoss() 1261 | loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) 1262 | 1263 | if not return_dict: 1264 | output = (logits,) + outputs[1:] 1265 | return ((loss,) + output) if loss is not None else output 1266 | 1267 | return Seq2SeqSequenceClassifierOutput( 1268 | loss=loss, 1269 | logits=logits, 1270 | past_key_values=outputs.past_key_values, 1271 | decoder_hidden_states=outputs.decoder_hidden_states, 1272 | decoder_attentions=outputs.decoder_attentions, 1273 | cross_attentions=outputs.cross_attentions, 1274 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1275 | encoder_hidden_states=outputs.encoder_hidden_states, 1276 | encoder_attentions=outputs.encoder_attentions, 1277 | ) 1278 | 1279 | 1280 | @add_start_docstrings( 1281 | """ 1282 | CPT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear 1283 | layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
1284 | """, 1285 | CPT_START_DOCSTRING, 1286 | ) 1287 | class CPTForQuestionAnswering(CPTPretrainedModel): 1288 | def __init__(self, config: CPTConfig, cls_mode=1, **kwargs): 1289 | super().__init__(config, **kwargs) 1290 | config.num_labels = 2 1291 | self.num_labels = config.num_labels 1292 | 1293 | self.model = CPTModel(config) 1294 | 1295 | cls_mode = getattr(config, 'cls_mode', cls_mode) 1296 | if cls_mode == 1: 1297 | logger.info('Encoder for classification.') 1298 | cls_dim = config.d_model 1299 | elif cls_mode == 2: 1300 | logger.info('Decoder for classification.') 1301 | cls_dim = config.d_model 1302 | elif cls_mode == 3: 1303 | logger.info('Both encoder & decoder for classification.') 1304 | cls_dim = config.d_model * 2 1305 | else: 1306 | raise NotImplementedError 1307 | 1308 | self.qa_outputs = nn.Linear(cls_dim, config.num_labels) 1309 | self.model._init_weights(self.qa_outputs) 1310 | 1311 | self.cls_mode = cls_mode 1312 | config.cls_mode = cls_mode 1313 | 1314 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING) 1315 | @add_code_sample_docstrings( 1316 | checkpoint=_CHECKPOINT_FOR_DOC, 1317 | output_type=Seq2SeqSequenceClassifierOutput, 1318 | config_class=_CONFIG_FOR_DOC, 1319 | ) 1320 | def forward( 1321 | self, 1322 | input_ids=None, 1323 | attention_mask=None, 1324 | decoder_input_ids=None, 1325 | decoder_attention_mask=None, 1326 | head_mask=None, 1327 | decoder_head_mask=None, 1328 | encoder_outputs=None, 1329 | start_positions=None, 1330 | end_positions=None, 1331 | inputs_embeds=None, 1332 | decoder_inputs_embeds=None, 1333 | use_cache=None, 1334 | output_attentions=None, 1335 | output_hidden_states=None, 1336 | return_dict=None, 1337 | ): 1338 | r""" 1339 | start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1340 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1341 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 1342 | are not taken into account for computing the loss. 1343 | end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1344 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1345 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 1346 | are not taken into account for computing the loss. 
1347 | """ 1348 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1349 | 1350 | if input_ids is None and inputs_embeds is not None: 1351 | raise NotImplementedError( 1352 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}" 1353 | ) 1354 | 1355 | outputs = self.model( 1356 | input_ids, 1357 | attention_mask=attention_mask, 1358 | decoder_input_ids=decoder_input_ids, 1359 | decoder_attention_mask=decoder_attention_mask, 1360 | head_mask=head_mask, 1361 | decoder_head_mask=decoder_head_mask, 1362 | encoder_outputs=encoder_outputs, 1363 | inputs_embeds=inputs_embeds, 1364 | decoder_inputs_embeds=decoder_inputs_embeds, 1365 | use_cache=use_cache, 1366 | output_attentions=output_attentions, 1367 | output_hidden_states=output_hidden_states, 1368 | return_dict=True, 1369 | ) 1370 | 1371 | hidden_states = outputs.last_hidden_state 1372 | enc_hidden_states = outputs.encoder_last_hidden_state 1373 | 1374 | if self.cls_mode == 1: 1375 | logits = self.qa_outputs(enc_hidden_states) 1376 | elif self.cls_mode == 2: 1377 | logits = self.qa_outputs(hidden_states) 1378 | elif self.cls_mode == 3: 1379 | rep = torch.cat([enc_hidden_states, hidden_states], dim=-1) 1380 | logits = self.qa_outputs(rep) 1381 | else: 1382 | raise NotImplementedError 1383 | 1384 | start_logits, end_logits = logits.split(1, dim=-1) 1385 | start_logits = start_logits.squeeze(-1) 1386 | end_logits = end_logits.squeeze(-1) 1387 | 1388 | total_loss = None 1389 | if start_positions is not None and end_positions is not None: 1390 | # If we are on multi-GPU, split add a dimension 1391 | if len(start_positions.size()) > 1: 1392 | start_positions = start_positions.squeeze(-1) 1393 | if len(end_positions.size()) > 1: 1394 | end_positions = end_positions.squeeze(-1) 1395 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1396 | ignored_index = start_logits.size(1) 1397 | start_positions.clamp_(0, ignored_index) 1398 | end_positions.clamp_(0, ignored_index) 1399 | 1400 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1401 | start_loss = loss_fct(start_logits, start_positions) 1402 | end_loss = loss_fct(end_logits, end_positions) 1403 | total_loss = (start_loss + end_loss) / 2 1404 | 1405 | if not return_dict: 1406 | output = ( 1407 | start_logits, 1408 | end_logits, 1409 | ) + outputs[1:] 1410 | return ((total_loss,) + output) if total_loss is not None else output 1411 | 1412 | return Seq2SeqQuestionAnsweringModelOutput( 1413 | loss=total_loss, 1414 | start_logits=start_logits, 1415 | end_logits=end_logits, 1416 | past_key_values=outputs.past_key_values, 1417 | decoder_hidden_states=outputs.decoder_hidden_states, 1418 | decoder_attentions=outputs.decoder_attentions, 1419 | cross_attentions=outputs.cross_attentions, 1420 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1421 | encoder_hidden_states=outputs.encoder_hidden_states, 1422 | encoder_attentions=outputs.encoder_attentions, 1423 | ) 1424 | 1425 | 1426 | class CPTForMaskedLM(CPTPretrainedModel): 1427 | _keys_to_ignore_on_load_missing = [ 1428 | r"final_logits_bias", 1429 | r"encoder\.version", 1430 | r"decoder\.version", 1431 | r"lm_head\.weight", 1432 | ] 1433 | def __init__(self, config, **kwargs): 1434 | super().__init__(config) 1435 | self.model = CPTModel(config) 1436 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 1437 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, 
bias=False) 1438 | 1439 | self.init_weights() 1440 | 1441 | def get_encoder(self): 1442 | return self.model.get_encoder() 1443 | 1444 | def get_decoder(self): 1445 | return self.model.get_decoder() 1446 | 1447 | def get_output_embeddings(self): 1448 | return self.lm_head 1449 | 1450 | def forward( 1451 | self, 1452 | input_ids=None, 1453 | attention_mask=None, 1454 | decoder_input_ids=None, 1455 | decoder_attention_mask=None, 1456 | head_mask=None, 1457 | decoder_head_mask=None, 1458 | encoder_outputs=None, 1459 | inputs_embeds=None, 1460 | decoder_inputs_embeds=None, 1461 | use_cache=None, 1462 | output_attentions=None, 1463 | output_hidden_states=None, 1464 | return_dict=None, 1465 | ): 1466 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1467 | 1468 | if input_ids is None and inputs_embeds is not None: 1469 | raise NotImplementedError( 1470 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}" 1471 | ) 1472 | 1473 | outputs = self.model( 1474 | input_ids, 1475 | attention_mask=attention_mask, 1476 | decoder_input_ids=decoder_input_ids, 1477 | decoder_attention_mask=decoder_attention_mask, 1478 | head_mask=head_mask, 1479 | decoder_head_mask=decoder_head_mask, 1480 | encoder_outputs=encoder_outputs, 1481 | inputs_embeds=inputs_embeds, 1482 | decoder_inputs_embeds=decoder_inputs_embeds, 1483 | use_cache=use_cache, 1484 | output_attentions=output_attentions, 1485 | output_hidden_states=output_hidden_states, 1486 | return_dict=True, 1487 | ) 1488 | 1489 | hidden_states = outputs.last_hidden_state 1490 | enc_hidden_states = outputs.encoder_last_hidden_state 1491 | 1492 | dec_logits = self.lm_head(hidden_states) + self.final_logits_bias 1493 | enc_logits = self.lm_head(enc_hidden_states) + self.final_logits_bias 1494 | 1495 | if not return_dict: 1496 | logits = (enc_logits, dec_logits) 1497 | output = (logits,) + outputs[1:] 1498 | return output 1499 | 1500 | return Seq2SeqLMOutput( 1501 | loss=None, 1502 | logits=(enc_logits, dec_logits), 1503 | past_key_values=outputs.past_key_values, 1504 | decoder_hidden_states=outputs.decoder_hidden_states, 1505 | decoder_attentions=outputs.decoder_attentions, 1506 | cross_attentions=outputs.cross_attentions, 1507 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1508 | encoder_hidden_states=outputs.encoder_hidden_states, 1509 | encoder_attentions=outputs.encoder_attentions, 1510 | ) -------------------------------------------------------------------------------- /Summary/modeling_cpt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch CPT model. 
modified from transformers==4.4.1""" 16 | import copy 17 | import math 18 | import random 19 | import warnings 20 | from typing import Optional, Tuple 21 | 22 | import torch 23 | import torch.nn.functional as F 24 | from torch.nn.modules.loss import KLDivLoss, NLLLoss 25 | import torch.utils.checkpoint 26 | from torch import nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | from transformers.activations import ACT2FN 30 | from transformers.file_utils import ( 31 | add_code_sample_docstrings, 32 | add_end_docstrings, 33 | add_start_docstrings, 34 | add_start_docstrings_to_model_forward, 35 | replace_return_docstrings, 36 | ) 37 | from transformers.modeling_outputs import ( 38 | BaseModelOutput, 39 | BaseModelOutputWithPastAndCrossAttentions, 40 | CausalLMOutputWithCrossAttentions, 41 | Seq2SeqLMOutput, 42 | Seq2SeqModelOutput, 43 | Seq2SeqQuestionAnsweringModelOutput, 44 | Seq2SeqSequenceClassifierOutput, 45 | ) 46 | from transformers.modeling_utils import PreTrainedModel 47 | from transformers.utils import logging 48 | from transformers import BartConfig as CPTConfig 49 | from transformers import BertModel, BertConfig 50 | 51 | from torch.nn import LayerNorm 52 | 53 | logger = logging.get_logger(__name__) 54 | 55 | _CHECKPOINT_FOR_DOC = "fnlp/cpt-large" 56 | _CONFIG_FOR_DOC = "CPTConfig" 57 | _TOKENIZER_FOR_DOC = "CPTTokenizer" 58 | 59 | 60 | CPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ 61 | "fnlp/cpt-large", 62 | ] 63 | 64 | 65 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 66 | """ 67 | Shift input ids one token to the right. 68 | """ 69 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 70 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 71 | shifted_input_ids[:, 0] = decoder_start_token_id 72 | 73 | assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 74 | # replace possible -100 values in labels by `pad_token_id` 75 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 76 | 77 | return shifted_input_ids 78 | 79 | 80 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): 81 | """ 82 | Make causal mask used for bi-directional self-attention. 83 | """ 84 | bsz, tgt_len = input_ids_shape 85 | mask = torch.full((tgt_len, tgt_len), float("-inf")) 86 | mask_cond = torch.arange(mask.size(-1)) 87 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 88 | mask = mask.to(dtype) 89 | 90 | if past_key_values_length > 0: 91 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) 92 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 93 | 94 | 95 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 96 | """ 97 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
98 | """ 99 | bsz, src_len = mask.size() 100 | tgt_len = tgt_len if tgt_len is not None else src_len 101 | 102 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 103 | 104 | inverted_mask = 1.0 - expanded_mask 105 | 106 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) 107 | 108 | def attention_mask_func(attention_scores, attention_mask): 109 | return attention_scores + attention_mask 110 | 111 | def init_method(std): 112 | def init_(tensor): 113 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 114 | 115 | return init_ 116 | 117 | class CPTLearnedPositionalEmbedding(nn.Embedding): 118 | """ 119 | This module learns positional embeddings up to a fixed maximum size. 120 | """ 121 | 122 | def __init__(self, num_embeddings: int, embedding_dim: int): 123 | # CPT is set up so that if padding_idx is specified then offset the embedding ids by 2 124 | # and adjust num_embeddings appropriately. Other models dont have this hack 125 | self.offset = 2 126 | super().__init__(num_embeddings + self.offset, embedding_dim) 127 | 128 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): 129 | """`input_ids_shape` is expected to be [bsz x seqlen].""" 130 | bsz, seq_len = input_ids_shape[:2] 131 | positions = torch.arange( 132 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device 133 | ) 134 | return super().forward(positions + self.offset) 135 | 136 | 137 | class CPTAttention(nn.Module): 138 | """Multi-headed attention from 'Attention Is All You Need' paper""" 139 | 140 | def __init__( 141 | self, 142 | embed_dim: int, 143 | num_heads: int, 144 | dropout: float = 0.0, 145 | is_decoder: bool = False, 146 | bias: bool = True, 147 | ): 148 | super().__init__() 149 | self.embed_dim = embed_dim 150 | self.num_heads = num_heads 151 | self.dropout = dropout 152 | self.head_dim = embed_dim // num_heads 153 | assert ( 154 | self.head_dim * num_heads == self.embed_dim 155 | ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." 
156 | self.scaling = self.head_dim ** -0.5 157 | self.is_decoder = is_decoder 158 | 159 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 160 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 161 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 162 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 163 | 164 | 165 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 166 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 167 | 168 | def forward( 169 | self, 170 | hidden_states: torch.Tensor, 171 | key_value_states: Optional[torch.Tensor] = None, 172 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 173 | attention_mask: Optional[torch.Tensor] = None, 174 | layer_head_mask: Optional[torch.Tensor] = None, 175 | output_attentions: bool = False, 176 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 177 | """Input shape: Batch x Time x Channel""" 178 | 179 | # if key_value_states are provided this layer is used as a cross-attention layer 180 | # for the decoder 181 | is_cross_attention = key_value_states is not None 182 | bsz, tgt_len, embed_dim = hidden_states.size() 183 | 184 | # get query proj 185 | query_states = self.q_proj(hidden_states) * self.scaling 186 | # get key, value proj 187 | if is_cross_attention and past_key_value is not None: 188 | # reuse k,v, cross_attentions 189 | key_states = past_key_value[0] 190 | value_states = past_key_value[1] 191 | elif is_cross_attention: 192 | # cross_attentions 193 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 194 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 195 | elif past_key_value is not None: 196 | # reuse k, v, self_attention 197 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 198 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 199 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 200 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 201 | else: 202 | # self_attention 203 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 204 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 205 | 206 | if self.is_decoder: 207 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 208 | # Further calls to cross_attention layer can then reuse all cross-attention 209 | # key/value_states (first "if" case) 210 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 211 | # all previous decoder key/value_states. 
Further calls to uni-directional self-attention 212 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 213 | # if encoder bi-directional self-attention `past_key_value` is always `None` 214 | past_key_value = (key_states, value_states) 215 | 216 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 217 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 218 | key_states = key_states.view(*proj_shape) 219 | value_states = value_states.view(*proj_shape) 220 | 221 | src_len = key_states.size(1) 222 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 223 | 224 | assert attn_weights.size() == ( 225 | bsz * self.num_heads, 226 | tgt_len, 227 | src_len, 228 | ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 229 | 230 | if attention_mask is not None: 231 | assert attention_mask.size() == ( 232 | bsz, 233 | 1, 234 | tgt_len, 235 | src_len, 236 | ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 237 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 238 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 239 | 240 | attn_weights = F.softmax(attn_weights, dim=-1) 241 | 242 | if layer_head_mask is not None: 243 | assert layer_head_mask.size() == ( 244 | self.num_heads, 245 | ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" 246 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 247 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 248 | 249 | if output_attentions: 250 | # this operation is a bit awkward, but it's required to 251 | # make sure that attn_weights keeps its gradient.
252 | # In order to do so, attn_weights have to be reshaped 253 | # twice and have to be reused in the following 254 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 255 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 256 | else: 257 | attn_weights_reshaped = None 258 | 259 | # with mpu.get_cuda_rng_tracker().fork(): 260 | attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) 261 | 262 | attn_output = torch.bmm(attn_probs, value_states) 263 | 264 | assert attn_output.size() == ( 265 | bsz * self.num_heads, 266 | tgt_len, 267 | self.head_dim, 268 | ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 269 | 270 | attn_output = ( 271 | attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 272 | .transpose(1, 2) 273 | .reshape(bsz, tgt_len, embed_dim) 274 | ) 275 | 276 | attn_output = self.out_proj(attn_output) 277 | 278 | return attn_output, attn_weights_reshaped, past_key_value 279 | 280 | class CPTDecoderLayer(nn.Module): 281 | def __init__(self, config: CPTConfig): 282 | super().__init__() 283 | self.embed_dim = config.d_model 284 | 285 | self.self_attn = CPTAttention( 286 | embed_dim=self.embed_dim, 287 | num_heads=config.decoder_attention_heads, 288 | dropout=config.attention_dropout, 289 | is_decoder=True, 290 | ) 291 | self.dropout = config.dropout 292 | self.activation_fn = ACT2FN[config.activation_function] 293 | self.activation_dropout = config.activation_dropout 294 | 295 | self.self_attn_layer_norm = LayerNorm(self.embed_dim) 296 | self.encoder_attn = CPTAttention( 297 | self.embed_dim, 298 | config.decoder_attention_heads, 299 | dropout=config.attention_dropout, 300 | is_decoder=True, 301 | ) 302 | self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) 303 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 304 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 305 | self.final_layer_norm = LayerNorm(self.embed_dim) 306 | 307 | def forward( 308 | self, 309 | hidden_states: torch.Tensor, 310 | attention_mask: Optional[torch.Tensor] = None, 311 | encoder_hidden_states: Optional[torch.Tensor] = None, 312 | encoder_attention_mask: Optional[torch.Tensor] = None, 313 | layer_head_mask: Optional[torch.Tensor] = None, 314 | encoder_layer_head_mask: Optional[torch.Tensor] = None, 315 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 316 | output_attentions: Optional[bool] = False, 317 | use_cache: Optional[bool] = True, 318 | ): 319 | """ 320 | Args: 321 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` 322 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size 323 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 324 | encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` 325 | encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size 326 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 327 | layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size 328 | `(config.encoder_attention_heads,)`. 329 | encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of 330 | size `(config.encoder_attention_heads,)`.
331 | past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states 332 | output_attentions (:obj:`bool`, `optional`): 333 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 334 | returned tensors for more detail. 335 | """ 336 | residual = hidden_states 337 | 338 | # Self Attention 339 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 340 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 341 | # add present self-attn cache to positions 1,2 of present_key_value tuple 342 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 343 | hidden_states=hidden_states, 344 | past_key_value=self_attn_past_key_value, 345 | attention_mask=attention_mask, 346 | layer_head_mask=layer_head_mask, 347 | output_attentions=output_attentions, 348 | ) 349 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 350 | hidden_states = residual + hidden_states 351 | hidden_states = self.self_attn_layer_norm(hidden_states) 352 | 353 | # Cross-Attention Block 354 | cross_attn_present_key_value = None 355 | cross_attn_weights = None 356 | if encoder_hidden_states is not None: 357 | residual = hidden_states 358 | 359 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple 360 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 361 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( 362 | hidden_states=hidden_states, 363 | key_value_states=encoder_hidden_states, 364 | attention_mask=encoder_attention_mask, 365 | layer_head_mask=encoder_layer_head_mask, 366 | past_key_value=cross_attn_past_key_value, 367 | output_attentions=output_attentions, 368 | ) 369 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 370 | hidden_states = residual + hidden_states 371 | hidden_states = self.encoder_attn_layer_norm(hidden_states) 372 | 373 | # add cross-attn to positions 3,4 of present_key_value tuple 374 | present_key_value = present_key_value + cross_attn_present_key_value 375 | 376 | # Fully Connected 377 | residual = hidden_states 378 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 379 | hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) 380 | hidden_states = self.fc2(hidden_states) 381 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 382 | hidden_states = residual + hidden_states 383 | hidden_states = self.final_layer_norm(hidden_states) 384 | 385 | outputs = (hidden_states,) 386 | 387 | if output_attentions: 388 | outputs += (self_attn_weights, cross_attn_weights) 389 | 390 | if use_cache: 391 | outputs += (present_key_value,) 392 | 393 | return outputs 394 | 395 | 396 | class CPTClassificationHead(nn.Module): 397 | """Head for sentence-level classification tasks.""" 398 | 399 | def __init__( 400 | self, 401 | input_dim: int, 402 | inner_dim: int, 403 | num_classes: int, 404 | pooler_dropout: float, 405 | ): 406 | super().__init__() 407 | self.dense = nn.Linear(input_dim, inner_dim) 408 | self.dropout = nn.Dropout(p=pooler_dropout) 409 | self.out_proj = nn.Linear(inner_dim, num_classes) 410 | 411 | def forward(self, hidden_states: torch.Tensor): 412 | hidden_states = self.dropout(hidden_states) 413 | hidden_states = self.dense(hidden_states) 414 | hidden_states = torch.tanh(hidden_states) 415 | hidden_states = 
self.dropout(hidden_states) 416 | hidden_states = self.out_proj(hidden_states) 417 | return hidden_states 418 | 419 | 420 | class CPTPretrainedModel(PreTrainedModel): 421 | config_class = CPTConfig 422 | base_model_prefix = "model" 423 | 424 | def _init_weights(self, module): 425 | std = self.config.init_std 426 | if isinstance(module, nn.Linear): 427 | module.weight.data.normal_(mean=0.0, std=std) 428 | if module.bias is not None: 429 | module.bias.data.zero_() 430 | elif isinstance(module, nn.Embedding): 431 | module.weight.data.normal_(mean=0.0, std=std) 432 | if module.padding_idx is not None: 433 | module.weight.data[module.padding_idx].zero_() 434 | 435 | @property 436 | def dummy_inputs(self): 437 | pad_token = self.config.pad_token_id 438 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 439 | dummy_inputs = { 440 | "attention_mask": input_ids.ne(pad_token), 441 | "input_ids": input_ids, 442 | } 443 | return dummy_inputs 444 | 445 | CPT_START_DOCSTRING = r""" 446 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 447 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 448 | pruning heads etc.) 449 | 450 | This model is also a PyTorch `torch.nn.Module `__ 451 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 452 | general usage and behavior. 453 | 454 | Parameters: 455 | config (:class:`~transformers.CPTConfig`): 456 | Model configuration class with all the parameters of the model. Initializing with a config file does not 457 | load the weights associated with the model, only the configuration. Check out the 458 | :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. 459 | """ 460 | 461 | CPT_INPUTS_DOCSTRING = r""" 462 | Args: 463 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 464 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 465 | it. 466 | 467 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See 468 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 469 | details. 470 | 471 | `What are input IDs? <../glossary.html#input-ids>`__ 472 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 473 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 474 | 475 | - 1 for tokens that are **not masked**, 476 | - 0 for tokens that are **masked**. 477 | 478 | `What are attention masks? <../glossary.html#attention-mask>`__ 479 | decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): 480 | Indices of decoder input sequence tokens in the vocabulary. 481 | 482 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See 483 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 484 | details. 485 | 486 | `What are input IDs? <../glossary.html#input-ids>`__ 487 | 488 | CPT uses the :obj:`eos_token_id` as the starting token for :obj:`decoder_input_ids` generation. If 489 | :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see 490 | :obj:`past_key_values`). 
491 | 492 | For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no 493 | :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to 494 | the right for denoising pre-training following the paper. 495 | decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): 496 | Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will 497 | also be used by default. 498 | 499 | If you want to change padding behavior, you should read :func:`modeling_cpt._prepare_decoder_inputs` and 500 | modify to your needs. See diagram 1 in `the paper `__ for more 501 | information on the default strategy. 502 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 503 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: 504 | 505 | - 1 indicates the head is **not masked**, 506 | - 0 indicates the head is **masked**. 507 | 508 | decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 509 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: 510 | 511 | - 1 indicates the head is **not masked**, 512 | - 0 indicates the head is **masked**. 513 | 514 | encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): 515 | Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: 516 | :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, 517 | `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the 518 | cross-attention of the decoder. 519 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 520 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. 521 | 522 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 523 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 524 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 525 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 526 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 527 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 528 | vectors than the model's internal embedding lookup matrix. 529 | decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): 530 | Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded 531 | representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` 532 | have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert 533 | :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
534 | 535 | If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` 536 | takes the value of :obj:`inputs_embeds`. 537 | use_cache (:obj:`bool`, `optional`): 538 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 539 | decoding (see :obj:`past_key_values`). 540 | output_attentions (:obj:`bool`, `optional`): 541 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 542 | tensors for more detail. 543 | output_hidden_states (:obj:`bool`, `optional`): 544 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 545 | more detail. 546 | return_dict (:obj:`bool`, `optional`): 547 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 548 | """ 549 | 550 | class CPTDecoder(CPTPretrainedModel): 551 | """ 552 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`CPTDecoderLayer` 553 | 554 | Args: 555 | config: CPTConfig 556 | embed_tokens (torch.nn.Embedding): output embedding 557 | """ 558 | 559 | def __init__(self, config: CPTConfig, embed_tokens: Optional[nn.Embedding] = None): 560 | super().__init__(config) 561 | self.dropout = config.dropout 562 | self.layerdrop = config.decoder_layerdrop 563 | self.padding_idx = config.pad_token_id 564 | self.max_target_positions = config.max_position_embeddings 565 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 566 | 567 | if embed_tokens is not None: 568 | self.embed_tokens = embed_tokens 569 | else: 570 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) 571 | 572 | self.embed_positions = CPTLearnedPositionalEmbedding( 573 | config.max_position_embeddings, 574 | config.d_model, 575 | ) 576 | self.layers = nn.ModuleList([CPTDecoderLayer(config) for _ in range(config.decoder_layers)]) 577 | self.layernorm_embedding = LayerNorm(config.d_model) 578 | 579 | self.init_weights() 580 | 581 | def get_input_embeddings(self): 582 | return self.embed_tokens 583 | 584 | def set_input_embeddings(self, value): 585 | self.embed_tokens = value 586 | 587 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 588 | # create causal mask 589 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 590 | combined_attention_mask = None 591 | if input_shape[-1] > 1: 592 | combined_attention_mask = _make_causal_mask( 593 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length 594 | ).to(self.device) 595 | 596 | if attention_mask is not None: 597 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 598 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 599 | combined_attention_mask = ( 600 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 601 | ) 602 | 603 | return combined_attention_mask 604 | 605 | def forward( 606 | self, 607 | input_ids=None, 608 | attention_mask=None, 609 | encoder_hidden_states=None, 610 | encoder_attention_mask=None, 611 | head_mask=None, 612 | encoder_head_mask=None, 613 | past_key_values=None, 614 | inputs_embeds=None, 615 | use_cache=None, 616 | output_attentions=None, 617 | output_hidden_states=None, 618 | return_dict=None, 619 | ): 620 | r""" 621 | Args: 622 | input_ids (:obj:`torch.LongTensor` of shape 
:obj:`(batch_size, sequence_length)`): 623 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 624 | provide it. 625 | 626 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See 627 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` 628 | for details. 629 | 630 | `What are input IDs? <../glossary.html#input-ids>`__ 631 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 632 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 633 | 634 | - 1 for tokens that are **not masked**, 635 | - 0 for tokens that are **masked**. 636 | 637 | `What are attention masks? <../glossary.html#attention-mask>`__ 638 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): 639 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 640 | of the decoder. 641 | encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): 642 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 643 | selected in ``[0, 1]``: 644 | 645 | - 1 for tokens that are **not masked**, 646 | - 0 for tokens that are **masked**. 647 | 648 | `What are attention masks? <../glossary.html#attention-mask>`__ 649 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 650 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: 651 | 652 | - 1 indicates the head is **not masked**, 653 | - 0 indicates the head is **masked**. 654 | 655 | encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 656 | Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention 657 | on hidden heads. Mask values selected in ``[0, 1]``: 658 | 659 | - 1 indicates the head is **not masked**, 660 | - 0 indicates the head is **masked**. 661 | 662 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 663 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up 664 | decoding. 665 | 666 | If :obj:`past_key_values` are used, the user can optionally input only the last 667 | :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of 668 | shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, 669 | sequence_length)`. 670 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 671 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded 672 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices 673 | into associated vectors than the model's internal embedding lookup matrix. 674 | output_attentions (:obj:`bool`, `optional`): 675 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 676 | returned tensors for more detail.
677 | output_hidden_states (:obj:`bool`, `optional`): 678 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors 679 | for more detail. 680 | return_dict (:obj:`bool`, `optional`): 681 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 682 | """ 683 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 684 | output_hidden_states = ( 685 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 686 | ) 687 | use_cache = use_cache if use_cache is not None else self.config.use_cache 688 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 689 | 690 | # retrieve input_ids and inputs_embeds 691 | if input_ids is not None and inputs_embeds is not None: 692 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 693 | elif input_ids is not None: 694 | input_shape = input_ids.size() 695 | input_ids = input_ids.view(-1, input_shape[-1]) 696 | elif inputs_embeds is not None: 697 | input_shape = inputs_embeds.size()[:-1] 698 | else: 699 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 700 | 701 | # past_key_values_length 702 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 703 | 704 | if inputs_embeds is None: 705 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 706 | 707 | attention_mask = self._prepare_decoder_attention_mask( 708 | attention_mask, input_shape, inputs_embeds, past_key_values_length 709 | ) 710 | 711 | # expand encoder attention mask 712 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 713 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 714 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 715 | 716 | # embed positions 717 | positions = self.embed_positions(input_shape, past_key_values_length) 718 | 719 | hidden_states = inputs_embeds + positions 720 | hidden_states = self.layernorm_embedding(hidden_states) 721 | 722 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 723 | 724 | # decoder layers 725 | all_hidden_states = () if output_hidden_states else None 726 | all_self_attns = () if output_attentions else None 727 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 728 | next_decoder_cache = () if use_cache else None 729 | 730 | # check if head_mask has a correct number of layers specified if desired 731 | if head_mask is not None: 732 | assert head_mask.size()[0] == ( 733 | len(self.layers) 734 | ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
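# The loop below iterates over the stacked CPTDecoderLayer modules: it applies LayerDrop
# (randomly skipping a whole layer with probability config.decoder_layerdrop during training),
# optionally re-computes each layer under gradient checkpointing to save activation memory,
# and collects the per-layer self-/cross-attention maps and the key-value cache when requested.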
735 | for idx, decoder_layer in enumerate(self.layers):
736 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
737 | if output_hidden_states:
738 | all_hidden_states += (hidden_states,)
739 | dropout_probability = random.uniform(0, 1)
740 | if self.training and (dropout_probability < self.layerdrop):
741 | continue
742 | 
743 | past_key_value = past_key_values[idx] if past_key_values is not None else None
744 | 
745 | if getattr(self.config, "gradient_checkpointing", False) and self.training:
746 | 
747 | if use_cache:
748 | logger.warn(
749 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
750 | "`use_cache=False`..."
751 | )
752 | use_cache = False
753 | 
754 | def create_custom_forward(module):
755 | def custom_forward(*inputs):
756 | # None for past_key_value
757 | return module(*inputs, output_attentions, use_cache)
758 | 
759 | return custom_forward
760 | 
761 | # layer_outputs = mpu.checkpoint(
762 | layer_outputs = torch.utils.checkpoint.checkpoint(
763 | create_custom_forward(decoder_layer),
764 | hidden_states,
765 | attention_mask,
766 | encoder_hidden_states,
767 | encoder_attention_mask,
768 | head_mask[idx] if head_mask is not None else None,
769 | encoder_head_mask[idx] if encoder_head_mask is not None else None,
770 | None,
771 | )
772 | else:
773 | 
774 | layer_outputs = decoder_layer(
775 | hidden_states,
776 | attention_mask=attention_mask,
777 | encoder_hidden_states=encoder_hidden_states,
778 | encoder_attention_mask=encoder_attention_mask,
779 | layer_head_mask=(head_mask[idx] if head_mask is not None else None),
780 | encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
781 | past_key_value=past_key_value,
782 | output_attentions=output_attentions,
783 | use_cache=use_cache,
784 | )
785 | hidden_states = layer_outputs[0]
786 | 
787 | if use_cache:
788 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
789 | 
790 | if output_attentions:
791 | all_self_attns += (layer_outputs[1],)
792 | 
793 | if encoder_hidden_states is not None:
794 | all_cross_attentions += (layer_outputs[2],)
795 | 
796 | # add hidden states from the last decoder layer
797 | if output_hidden_states:
798 | all_hidden_states += (hidden_states,)
799 | 
800 | next_cache = next_decoder_cache if use_cache else None
801 | if not return_dict:
802 | return tuple(
803 | v
804 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
805 | if v is not None
806 | )
807 | return BaseModelOutputWithPastAndCrossAttentions(
808 | last_hidden_state=hidden_states,
809 | past_key_values=next_cache,
810 | hidden_states=all_hidden_states,
811 | attentions=all_self_attns,
812 | cross_attentions=all_cross_attentions,
813 | )
814 | 
815 | 
816 | @add_start_docstrings(
817 | "The bare CPT Model outputting raw hidden-states without any specific head on top.",
818 | CPT_START_DOCSTRING,
819 | )
820 | class CPTModel(CPTPretrainedModel):
821 | def __init__(self, config: CPTConfig):
822 | super().__init__(config)
823 | encoder_config = BertConfig(
824 | vocab_size=config.vocab_size,
825 | hidden_size=config.d_model,
826 | num_hidden_layers=config.encoder_layers,
827 | num_attention_heads=config.encoder_attention_heads,
828 | intermediate_size=config.encoder_ffn_dim,
829 | hidden_dropout_prob=config.activation_dropout,
830 | attention_probs_dropout_prob=config.attention_dropout,
831 | )
832 | config.vocab_size = encoder_config.vocab_size
833 | self.encoder = BertModel(encoder_config,
add_pooling_layer=False) 834 | self.shared = self.encoder.get_input_embeddings() 835 | self.decoder = CPTDecoder(config, self.shared) 836 | self.num_decoder_layers = config.decoder_layers 837 | self.init_weights() 838 | 839 | def get_input_embeddings(self): 840 | return self.shared 841 | 842 | def set_input_embeddings(self, value): 843 | self.shared = value 844 | self.encoder.embed_tokens = self.shared 845 | self.decoder.embed_tokens = self.shared 846 | 847 | def get_encoder(self): 848 | class _Encoder(torch.nn.Module): 849 | def __init__(self, encoder): 850 | super().__init__() 851 | self.encoder = encoder 852 | 853 | def forward(self, *args, **kwargs): 854 | kwargs['output_hidden_states'] = True 855 | return self.encoder(*args, **kwargs) 856 | return _Encoder(self.encoder) 857 | 858 | def get_decoder(self): 859 | return self.decoder 860 | 861 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING) 862 | @add_code_sample_docstrings( 863 | checkpoint=_CHECKPOINT_FOR_DOC, 864 | output_type=Seq2SeqModelOutput, 865 | ) 866 | def forward( 867 | self, 868 | input_ids=None, 869 | token_type_ids=None, 870 | attention_mask=None, 871 | decoder_input_ids=None, 872 | decoder_attention_mask=None, 873 | head_mask=None, 874 | decoder_head_mask=None, 875 | encoder_outputs=None, 876 | past_key_values=None, 877 | inputs_embeds=None, 878 | decoder_inputs_embeds=None, 879 | use_cache=None, 880 | output_attentions=None, 881 | output_hidden_states=None, 882 | return_dict=None, 883 | ): 884 | 885 | # different to other models, CPT automatically creates decoder_input_ids from 886 | # input_ids if no decoder_input_ids are provided 887 | if decoder_input_ids is None and decoder_inputs_embeds is None: 888 | decoder_input_ids = shift_tokens_right( 889 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id 890 | ) 891 | 892 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 893 | output_hidden_states = ( 894 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 895 | ) 896 | use_cache = use_cache if use_cache is not None else self.config.use_cache 897 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 898 | 899 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 900 | # mpu.reset_checkpointed_activations_memory_buffer() 901 | use_cache = False 902 | 903 | if encoder_outputs is None: 904 | encoder_outputs = self.encoder( 905 | input_ids=input_ids, 906 | attention_mask=attention_mask, 907 | token_type_ids=token_type_ids, 908 | head_mask=head_mask, 909 | inputs_embeds=inputs_embeds, 910 | output_attentions=output_attentions, 911 | output_hidden_states=True, 912 | return_dict=return_dict, 913 | ) 914 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 915 | elif return_dict and isinstance(encoder_outputs, (tuple, list)): 916 | encoder_outputs = BaseModelOutput( 917 | last_hidden_state=encoder_outputs[0], 918 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 919 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 920 | ) 921 | 922 | if isinstance(encoder_outputs, (torch.Tensor)): 923 | encoder_hidden_states = encoder_outputs 924 | else: 925 | encoder_hidden_states = encoder_outputs[1][-self.num_decoder_layers - 1] 926 | 927 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 928 | decoder_outputs = 
self.decoder( 929 | input_ids=decoder_input_ids, 930 | attention_mask=decoder_attention_mask, 931 | encoder_hidden_states=encoder_hidden_states, 932 | encoder_attention_mask=attention_mask, 933 | head_mask=decoder_head_mask, 934 | encoder_head_mask=head_mask, 935 | past_key_values=past_key_values, 936 | inputs_embeds=decoder_inputs_embeds, 937 | use_cache=use_cache, 938 | output_attentions=output_attentions, 939 | output_hidden_states=output_hidden_states, 940 | return_dict=return_dict, 941 | ) 942 | 943 | if not return_dict: 944 | return decoder_outputs + encoder_outputs 945 | 946 | return Seq2SeqModelOutput( 947 | last_hidden_state=decoder_outputs.last_hidden_state, 948 | past_key_values=decoder_outputs.past_key_values, 949 | decoder_hidden_states=decoder_outputs.hidden_states, 950 | decoder_attentions=decoder_outputs.attentions, 951 | cross_attentions=decoder_outputs.cross_attentions, 952 | encoder_last_hidden_state=encoder_outputs.last_hidden_state if isinstance(encoder_outputs, dict) else None, 953 | encoder_hidden_states=encoder_outputs.hidden_states if isinstance(encoder_outputs, dict) else None, 954 | encoder_attentions=encoder_outputs.attentions if isinstance(encoder_outputs, dict) else None, 955 | ) 956 | 957 | 958 | @add_start_docstrings( 959 | "The CPT Model with a language modeling head. Can be used for summarization.", CPT_START_DOCSTRING 960 | ) 961 | class CPTForConditionalGeneration(CPTPretrainedModel): 962 | base_model_prefix = "model" 963 | _keys_to_ignore_on_load_missing = [ 964 | r"final_logits_bias", 965 | r"encoder\.version", 966 | r"decoder\.version", 967 | r"lm_head\.weight", 968 | ] 969 | 970 | def __init__(self, config): 971 | super().__init__(config) 972 | self.model = CPTModel(config) 973 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 974 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 975 | 976 | self.init_weights() 977 | 978 | def get_encoder(self): 979 | return self.model.get_encoder() 980 | 981 | def get_decoder(self): 982 | return self.model.get_decoder() 983 | 984 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: 985 | new_embeddings = super().resize_token_embeddings(new_num_tokens) 986 | self._resize_final_logits_bias(new_num_tokens) 987 | return new_embeddings 988 | 989 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None: 990 | old_num_tokens = self.final_logits_bias.shape[-1] 991 | if new_num_tokens <= old_num_tokens: 992 | new_bias = self.final_logits_bias[:, :new_num_tokens] 993 | else: 994 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 995 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 996 | self.register_buffer("final_logits_bias", new_bias) 997 | 998 | def get_output_embeddings(self): 999 | return self.lm_head 1000 | 1001 | def set_output_embeddings(self, new_embeddings): 1002 | self.lm_head = new_embeddings 1003 | 1004 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING) 1005 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 1006 | def forward( 1007 | self, 1008 | input_ids=None, 1009 | attention_mask=None, 1010 | token_type_ids=None, 1011 | decoder_input_ids=None, 1012 | decoder_attention_mask=None, 1013 | head_mask=None, 1014 | decoder_head_mask=None, 1015 | encoder_outputs=None, 1016 | past_key_values=None, 1017 | inputs_embeds=None, 1018 | decoder_inputs_embeds=None, 1019 | labels=None, 1020 | 
use_cache=None, 1021 | output_attentions=None, 1022 | output_hidden_states=None, 1023 | return_dict=None, 1024 | ): 1025 | r""" 1026 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1027 | Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., 1028 | config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored 1029 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. 1030 | 1031 | Returns: 1032 | """ 1033 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1034 | 1035 | if labels is not None: 1036 | if decoder_input_ids is None: 1037 | decoder_input_ids = shift_tokens_right( 1038 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 1039 | ) 1040 | 1041 | outputs = self.model( 1042 | input_ids, 1043 | attention_mask=attention_mask, 1044 | decoder_input_ids=decoder_input_ids, 1045 | encoder_outputs=encoder_outputs, 1046 | token_type_ids=token_type_ids, 1047 | decoder_attention_mask=decoder_attention_mask, 1048 | head_mask=head_mask, 1049 | decoder_head_mask=decoder_head_mask, 1050 | past_key_values=past_key_values, 1051 | inputs_embeds=inputs_embeds, 1052 | decoder_inputs_embeds=decoder_inputs_embeds, 1053 | use_cache=use_cache, 1054 | output_attentions=output_attentions, 1055 | output_hidden_states=output_hidden_states, 1056 | return_dict=return_dict, 1057 | ) 1058 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 1059 | 1060 | masked_lm_loss = None 1061 | if labels is not None: 1062 | loss_fct = CrossEntropyLoss() 1063 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 1064 | 1065 | if not return_dict: 1066 | output = (lm_logits,) + outputs[1:] 1067 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1068 | 1069 | return Seq2SeqLMOutput( 1070 | loss=masked_lm_loss, 1071 | logits=lm_logits, 1072 | past_key_values=outputs.past_key_values, 1073 | decoder_hidden_states=outputs.decoder_hidden_states, 1074 | decoder_attentions=outputs.decoder_attentions, 1075 | cross_attentions=outputs.cross_attentions, 1076 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1077 | encoder_hidden_states=outputs.encoder_hidden_states, 1078 | encoder_attentions=outputs.encoder_attentions, 1079 | ) 1080 | 1081 | def prepare_inputs_for_generation( 1082 | self, 1083 | decoder_input_ids, 1084 | past=None, 1085 | attention_mask=None, 1086 | head_mask=None, 1087 | use_cache=None, 1088 | encoder_outputs=None, 1089 | **kwargs 1090 | ): 1091 | # cut decoder_input_ids if past is used 1092 | if past is not None: 1093 | decoder_input_ids = decoder_input_ids[:, -1:] 1094 | 1095 | return { 1096 | "input_ids": None, # encoder_outputs is defined. 
input_ids not needed 1097 | "encoder_outputs": encoder_outputs, 1098 | "past_key_values": past, 1099 | "decoder_input_ids": decoder_input_ids, 1100 | "attention_mask": attention_mask, 1101 | "head_mask": head_mask, 1102 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 1103 | } 1104 | 1105 | @staticmethod 1106 | def _expand_inputs_for_generation( 1107 | input_ids: torch.LongTensor, 1108 | expand_size: int = 1, 1109 | is_encoder_decoder: bool = False, 1110 | attention_mask: torch.LongTensor = None, 1111 | encoder_outputs = None, 1112 | **model_kwargs, 1113 | ): 1114 | expanded_return_idx = ( 1115 | torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) 1116 | ) 1117 | input_ids = input_ids.index_select(0, expanded_return_idx) 1118 | 1119 | if "token_type_ids" in model_kwargs: 1120 | token_type_ids = model_kwargs["token_type_ids"] 1121 | model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) 1122 | 1123 | if attention_mask is not None: 1124 | model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) 1125 | 1126 | if is_encoder_decoder: 1127 | assert encoder_outputs is not None 1128 | device = encoder_outputs.last_hidden_state.device 1129 | encoder_outputs["hidden_states"] = tuple(h.index_select(0, expanded_return_idx.to(device)) \ 1130 | for h in encoder_outputs["hidden_states"]) 1131 | model_kwargs["encoder_outputs"] = encoder_outputs 1132 | return input_ids, model_kwargs 1133 | 1134 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 1135 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 1136 | 1137 | @staticmethod 1138 | def _reorder_cache(past, beam_idx): 1139 | reordered_past = () 1140 | for layer_past in past: 1141 | # cached cross_attention states don't have to be reordered -> they are always the same 1142 | reordered_past += ( 1143 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], 1144 | ) 1145 | return reordered_past 1146 | 1147 | 1148 | @add_start_docstrings( 1149 | """ 1150 | CPT model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE 1151 | tasks. 
1152 | """, 1153 | CPT_START_DOCSTRING, 1154 | ) 1155 | class CPTForSequenceClassification(CPTPretrainedModel): 1156 | def __init__(self, config: CPTConfig, cls_mode=1, **kwargs): 1157 | super().__init__(config, **kwargs) 1158 | self.model = CPTModel(config) 1159 | cls_mode = getattr(config, 'cls_mode', cls_mode) 1160 | if cls_mode == 1: 1161 | logger.info('Encoder for classification.') 1162 | cls_dim = config.d_model 1163 | elif cls_mode == 2: 1164 | logger.info('Decoder for classification.') 1165 | cls_dim = config.d_model 1166 | elif cls_mode == 3: 1167 | logger.info('Both encoder & decoder for classification.') 1168 | cls_dim = config.d_model * 2 1169 | else: 1170 | raise NotImplementedError 1171 | 1172 | self.cls_head = CPTClassificationHead( 1173 | cls_dim, 1174 | cls_dim, 1175 | config.num_labels, 1176 | config.classifier_dropout, 1177 | ) 1178 | self.model._init_weights(self.cls_head.dense) 1179 | self.model._init_weights(self.cls_head.out_proj) 1180 | self.cls_mode = cls_mode 1181 | config.cls_mode = cls_mode 1182 | 1183 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING) 1184 | @add_code_sample_docstrings( 1185 | checkpoint=_CHECKPOINT_FOR_DOC, 1186 | output_type=Seq2SeqSequenceClassifierOutput, 1187 | config_class=_CONFIG_FOR_DOC, 1188 | ) 1189 | def forward( 1190 | self, 1191 | input_ids=None, 1192 | attention_mask=None, 1193 | decoder_input_ids=None, 1194 | decoder_attention_mask=None, 1195 | head_mask=None, 1196 | decoder_head_mask=None, 1197 | encoder_outputs=None, 1198 | inputs_embeds=None, 1199 | decoder_inputs_embeds=None, 1200 | labels=None, 1201 | use_cache=None, 1202 | output_attentions=None, 1203 | output_hidden_states=None, 1204 | return_dict=None, 1205 | ): 1206 | r""" 1207 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1208 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 1209 | config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1210 | """ 1211 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1212 | if labels is not None: 1213 | use_cache = False 1214 | 1215 | if input_ids is None and inputs_embeds is not None: 1216 | raise NotImplementedError( 1217 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}" 1218 | ) 1219 | 1220 | outputs = self.model( 1221 | input_ids, 1222 | attention_mask=attention_mask, 1223 | decoder_input_ids=decoder_input_ids, 1224 | decoder_attention_mask=decoder_attention_mask, 1225 | head_mask=head_mask, 1226 | decoder_head_mask=decoder_head_mask, 1227 | encoder_outputs=encoder_outputs, 1228 | inputs_embeds=inputs_embeds, 1229 | decoder_inputs_embeds=decoder_inputs_embeds, 1230 | use_cache=use_cache, 1231 | output_attentions=output_attentions, 1232 | output_hidden_states=output_hidden_states, 1233 | return_dict=True, 1234 | ) 1235 | 1236 | hidden_states = outputs.last_hidden_state 1237 | enc_hidden_states = outputs.encoder_last_hidden_state 1238 | enc_rep = enc_hidden_states[:, 0] 1239 | 1240 | eos_mask = input_ids.eq(self.config.eos_token_id) 1241 | 1242 | if len(torch.unique(eos_mask.sum(1))) > 1: 1243 | raise ValueError("All examples must have the same number of tokens.") 1244 | dec_rep = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ 1245 | :, -1, : 1246 | ] 1247 | 1248 | if self.cls_mode == 1: 1249 | logits = self.cls_head(enc_rep) 1250 | elif self.cls_mode == 2: 1251 | logits = self.cls_head(dec_rep) 1252 | elif self.cls_mode == 3: 1253 | rep = torch.cat([enc_rep, dec_rep], dim=-1) 1254 | logits = self.cls_head(rep) 1255 | else: 1256 | raise NotImplementedError 1257 | 1258 | loss = None 1259 | if labels is not None: 1260 | loss_fct = CrossEntropyLoss() 1261 | loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) 1262 | 1263 | if not return_dict: 1264 | output = (logits,) + outputs[1:] 1265 | return ((loss,) + output) if loss is not None else output 1266 | 1267 | return Seq2SeqSequenceClassifierOutput( 1268 | loss=loss, 1269 | logits=logits, 1270 | past_key_values=outputs.past_key_values, 1271 | decoder_hidden_states=outputs.decoder_hidden_states, 1272 | decoder_attentions=outputs.decoder_attentions, 1273 | cross_attentions=outputs.cross_attentions, 1274 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1275 | encoder_hidden_states=outputs.encoder_hidden_states, 1276 | encoder_attentions=outputs.encoder_attentions, 1277 | ) 1278 | 1279 | 1280 | @add_start_docstrings( 1281 | """ 1282 | CPT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear 1283 | layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
1284 | """, 1285 | CPT_START_DOCSTRING, 1286 | ) 1287 | class CPTForQuestionAnswering(CPTPretrainedModel): 1288 | def __init__(self, config: CPTConfig, cls_mode=1, **kwargs): 1289 | super().__init__(config, **kwargs) 1290 | config.num_labels = 2 1291 | self.num_labels = config.num_labels 1292 | 1293 | self.model = CPTModel(config) 1294 | 1295 | cls_mode = getattr(config, 'cls_mode', cls_mode) 1296 | if cls_mode == 1: 1297 | logger.info('Encoder for classification.') 1298 | cls_dim = config.d_model 1299 | elif cls_mode == 2: 1300 | logger.info('Decoder for classification.') 1301 | cls_dim = config.d_model 1302 | elif cls_mode == 3: 1303 | logger.info('Both encoder & decoder for classification.') 1304 | cls_dim = config.d_model * 2 1305 | else: 1306 | raise NotImplementedError 1307 | 1308 | self.qa_outputs = nn.Linear(cls_dim, config.num_labels) 1309 | self.model._init_weights(self.qa_outputs) 1310 | 1311 | self.cls_mode = cls_mode 1312 | config.cls_mode = cls_mode 1313 | 1314 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING) 1315 | @add_code_sample_docstrings( 1316 | checkpoint=_CHECKPOINT_FOR_DOC, 1317 | output_type=Seq2SeqSequenceClassifierOutput, 1318 | config_class=_CONFIG_FOR_DOC, 1319 | ) 1320 | def forward( 1321 | self, 1322 | input_ids=None, 1323 | attention_mask=None, 1324 | decoder_input_ids=None, 1325 | decoder_attention_mask=None, 1326 | head_mask=None, 1327 | decoder_head_mask=None, 1328 | encoder_outputs=None, 1329 | start_positions=None, 1330 | end_positions=None, 1331 | inputs_embeds=None, 1332 | decoder_inputs_embeds=None, 1333 | use_cache=None, 1334 | output_attentions=None, 1335 | output_hidden_states=None, 1336 | return_dict=None, 1337 | ): 1338 | r""" 1339 | start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1340 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1341 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 1342 | are not taken into account for computing the loss. 1343 | end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1344 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1345 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 1346 | are not taken into account for computing the loss. 
1347 | """ 1348 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1349 | 1350 | if input_ids is None and inputs_embeds is not None: 1351 | raise NotImplementedError( 1352 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}" 1353 | ) 1354 | 1355 | outputs = self.model( 1356 | input_ids, 1357 | attention_mask=attention_mask, 1358 | decoder_input_ids=decoder_input_ids, 1359 | decoder_attention_mask=decoder_attention_mask, 1360 | head_mask=head_mask, 1361 | decoder_head_mask=decoder_head_mask, 1362 | encoder_outputs=encoder_outputs, 1363 | inputs_embeds=inputs_embeds, 1364 | decoder_inputs_embeds=decoder_inputs_embeds, 1365 | use_cache=use_cache, 1366 | output_attentions=output_attentions, 1367 | output_hidden_states=output_hidden_states, 1368 | return_dict=True, 1369 | ) 1370 | 1371 | hidden_states = outputs.last_hidden_state 1372 | enc_hidden_states = outputs.encoder_last_hidden_state 1373 | 1374 | if self.cls_mode == 1: 1375 | logits = self.qa_outputs(enc_hidden_states) 1376 | elif self.cls_mode == 2: 1377 | logits = self.qa_outputs(hidden_states) 1378 | elif self.cls_mode == 3: 1379 | rep = torch.cat([enc_hidden_states, hidden_states], dim=-1) 1380 | logits = self.qa_outputs(rep) 1381 | else: 1382 | raise NotImplementedError 1383 | 1384 | start_logits, end_logits = logits.split(1, dim=-1) 1385 | start_logits = start_logits.squeeze(-1) 1386 | end_logits = end_logits.squeeze(-1) 1387 | 1388 | total_loss = None 1389 | if start_positions is not None and end_positions is not None: 1390 | # If we are on multi-GPU, split add a dimension 1391 | if len(start_positions.size()) > 1: 1392 | start_positions = start_positions.squeeze(-1) 1393 | if len(end_positions.size()) > 1: 1394 | end_positions = end_positions.squeeze(-1) 1395 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1396 | ignored_index = start_logits.size(1) 1397 | start_positions.clamp_(0, ignored_index) 1398 | end_positions.clamp_(0, ignored_index) 1399 | 1400 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1401 | start_loss = loss_fct(start_logits, start_positions) 1402 | end_loss = loss_fct(end_logits, end_positions) 1403 | total_loss = (start_loss + end_loss) / 2 1404 | 1405 | if not return_dict: 1406 | output = ( 1407 | start_logits, 1408 | end_logits, 1409 | ) + outputs[1:] 1410 | return ((total_loss,) + output) if total_loss is not None else output 1411 | 1412 | return Seq2SeqQuestionAnsweringModelOutput( 1413 | loss=total_loss, 1414 | start_logits=start_logits, 1415 | end_logits=end_logits, 1416 | past_key_values=outputs.past_key_values, 1417 | decoder_hidden_states=outputs.decoder_hidden_states, 1418 | decoder_attentions=outputs.decoder_attentions, 1419 | cross_attentions=outputs.cross_attentions, 1420 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1421 | encoder_hidden_states=outputs.encoder_hidden_states, 1422 | encoder_attentions=outputs.encoder_attentions, 1423 | ) 1424 | 1425 | 1426 | class CPTForMaskedLM(CPTPretrainedModel): 1427 | _keys_to_ignore_on_load_missing = [ 1428 | r"final_logits_bias", 1429 | r"encoder\.version", 1430 | r"decoder\.version", 1431 | r"lm_head\.weight", 1432 | ] 1433 | def __init__(self, config, **kwargs): 1434 | super().__init__(config) 1435 | self.model = CPTModel(config) 1436 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 1437 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, 
bias=False) 1438 | 1439 | self.init_weights() 1440 | 1441 | def get_encoder(self): 1442 | return self.model.get_encoder() 1443 | 1444 | def get_decoder(self): 1445 | return self.model.get_decoder() 1446 | 1447 | def get_output_embeddings(self): 1448 | return self.lm_head 1449 | 1450 | def forward( 1451 | self, 1452 | input_ids=None, 1453 | attention_mask=None, 1454 | decoder_input_ids=None, 1455 | decoder_attention_mask=None, 1456 | head_mask=None, 1457 | decoder_head_mask=None, 1458 | encoder_outputs=None, 1459 | inputs_embeds=None, 1460 | decoder_inputs_embeds=None, 1461 | use_cache=None, 1462 | output_attentions=None, 1463 | output_hidden_states=None, 1464 | return_dict=None, 1465 | ): 1466 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1467 | 1468 | if input_ids is None and inputs_embeds is not None: 1469 | raise NotImplementedError( 1470 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}" 1471 | ) 1472 | 1473 | outputs = self.model( 1474 | input_ids, 1475 | attention_mask=attention_mask, 1476 | decoder_input_ids=decoder_input_ids, 1477 | decoder_attention_mask=decoder_attention_mask, 1478 | head_mask=head_mask, 1479 | decoder_head_mask=decoder_head_mask, 1480 | encoder_outputs=encoder_outputs, 1481 | inputs_embeds=inputs_embeds, 1482 | decoder_inputs_embeds=decoder_inputs_embeds, 1483 | use_cache=use_cache, 1484 | output_attentions=output_attentions, 1485 | output_hidden_states=output_hidden_states, 1486 | return_dict=True, 1487 | ) 1488 | 1489 | hidden_states = outputs.last_hidden_state 1490 | enc_hidden_states = outputs.encoder_last_hidden_state 1491 | 1492 | dec_logits = self.lm_head(hidden_states) + self.final_logits_bias 1493 | enc_logits = self.lm_head(enc_hidden_states) + self.final_logits_bias 1494 | 1495 | if not return_dict: 1496 | logits = (enc_logits, dec_logits) 1497 | output = (logits,) + outputs[1:] 1498 | return output 1499 | 1500 | return Seq2SeqLMOutput( 1501 | loss=None, 1502 | logits=(enc_logits, dec_logits), 1503 | past_key_values=outputs.past_key_values, 1504 | decoder_hidden_states=outputs.decoder_hidden_states, 1505 | decoder_attentions=outputs.decoder_attentions, 1506 | cross_attentions=outputs.cross_attentions, 1507 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1508 | encoder_hidden_states=outputs.encoder_hidden_states, 1509 | encoder_attentions=outputs.encoder_attentions, 1510 | ) --------------------------------------------------------------------------------
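A minimal usage sketch for the CPTForConditionalGeneration class defined above. This is illustrative only and not part of the repository: the checkpoint name "fnlp/cpt-base" (the publicly released base CPT model) and the generation settings are assumptions, and may differ from the fine-tuned weights and decoding setup this project actually uses.

import torch
from transformers import BertTokenizer
from modeling_cpt import CPTForConditionalGeneration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Assumption: a CPT checkpoint with a BERT-style vocabulary, e.g. the public "fnlp/cpt-base";
# substitute the project's own fine-tuned weights as appropriate.
tokenizer = BertTokenizer.from_pretrained('fnlp/cpt-base')
model = CPTForConditionalGeneration.from_pretrained('fnlp/cpt-base').to(device)
model.eval()

# Encode a sample Chinese prompt ("I have a headache and a fever, what should I do?").
inputs = tokenizer('我头痛发烧,该怎么办?', max_length=512, return_tensors='pt').to(device)

# Because the class implements get_encoder/get_decoder, prepare_inputs_for_generation and
# _reorder_cache, the standard transformers generate() loop (greedy or beam search) works directly.
with torch.no_grad():
    output_ids = model.generate(inputs['input_ids'],
                                attention_mask=inputs['attention_mask'],
                                num_beams=4,
                                max_length=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))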