├── image
│   ├── bl.png
│   ├── fz.png
│   ├── main.png
│   ├── pj.png
│   ├── sc.png
│   ├── tl.png
│   ├── zj.png
│   ├── Triage.png
│   ├── Dagnosis.png
│   └── Summary.png
├── Dagnosis
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── modeling_cpt.cpython-37.pyc
│   └── __init__.py
├── Summary
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── modeling_cpt.cpython-37.pyc
│   ├── __init__.py
│   └── modeling_cpt.py
├── Triage
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── modeling_cpt.cpython-37.pyc
│   └── __init__.py
├── main.py
├── README_ZH.md
├── readme.md
└── modeling_cpt.py
/image/bl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/bl.png
--------------------------------------------------------------------------------
/image/fz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/fz.png
--------------------------------------------------------------------------------
/image/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/main.png
--------------------------------------------------------------------------------
/image/pj.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/pj.png
--------------------------------------------------------------------------------
/image/sc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/sc.png
--------------------------------------------------------------------------------
/image/tl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/tl.png
--------------------------------------------------------------------------------
/image/zj.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/zj.png
--------------------------------------------------------------------------------
/image/Triage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/Triage.png
--------------------------------------------------------------------------------
/image/Dagnosis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/Dagnosis.png
--------------------------------------------------------------------------------
/image/Summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/image/Summary.png
--------------------------------------------------------------------------------
/Dagnosis/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Dagnosis/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/Summary/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Summary/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/Triage/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Triage/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/Triage/__pycache__/modeling_cpt.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Triage/__pycache__/modeling_cpt.cpython-37.pyc
--------------------------------------------------------------------------------
/Dagnosis/__pycache__/modeling_cpt.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Dagnosis/__pycache__/modeling_cpt.cpython-37.pyc
--------------------------------------------------------------------------------
/Summary/__pycache__/modeling_cpt.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WENGSYX/LingYi/HEAD/Summary/__pycache__/modeling_cpt.cpython-37.pyc
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForSequenceClassification,BertTokenizer
3 | from modeling_cpt import CPTForConditionalGeneration
4 | from Triage import *
5 | from Summary import *
6 | from Dagnosis import *
7 |
8 |
9 |
10 | if __name__ == '__main__':
11 | import argparse
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--type", default='Dagnosis',type=str)
14 | parser.add_argument("--mode", default='interactive', type=str)
15 | parser.add_argument("--file_name",default=None,type=str)
16 | parser.add_argument("--result_file_name", default='result.csv', type=str)
17 | parser.add_argument("--message", default=None, type=str)
18 | parser = parser.parse_args()
19 |
20 |     exec('RSA = {}(parser)'.format(parser.type))  # dispatch to Triage / Summary / Dagnosis based on --type
21 | print(RSA)
22 |
--------------------------------------------------------------------------------
/Triage/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForSequenceClassification,BertTokenizer
3 | from modeling_cpt import CPTForConditionalGeneration
4 | from Triage import *
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 | FENZHEN_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_BERT'
7 | CMDD_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT'
8 | BL_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT'
9 | fenzhen_model = AutoModelForSequenceClassification.from_pretrained(FENZHEN_MODEL_NAME, num_labels=6).to(device) # 模型
10 | fenzhen_tokenizer = BertTokenizer.from_pretrained(FENZHEN_MODEL_NAME)
11 |
12 | def Triage(parser):
13 |
14 | if parser.mode == 'interactive':
15 | print('我们将为您分配科室')
16 | while True:
17 | message = input('请输入您想询问的症状(退出请输入\033[0;35m退出\033[0m)')
18 | if message == '退出':
19 | return ''
20 | text = fenzhen_tokenizer(message, max_length=512, return_tensors='pt')
21 | input_ids = text['input_ids'].to(device)
22 | fenzhen_prodict = fenzhen_model(input_ids)[0]
23 | ks = {0:'男科',1:'内科',2:'妇科',3:'肿瘤科',4:'儿科',5:'外科'}[int(fenzhen_prodict.argmax())]
24 | print('您可能需要前往: \033[0;32m{}\033[0m'.format(ks))
25 | elif parser.mode == 'batch':
26 | with open(parser.file_name,'r',encoding='utf-8') as f:
27 | data = f.readlines()
28 | result = []
29 | for message in data:
30 | text = fenzhen_tokenizer(message.replace('\n',''), max_length=512, return_tensors='pt')
31 | input_ids = text['input_ids'].to(device)
32 | fenzhen_prodict = fenzhen_model(input_ids)[0]
33 | ks = {0:'男科',1:'内科',2:'妇科',3:'肿瘤科',4:'儿科',5:'外科'}[int(fenzhen_prodict.argmax())]
34 | result.append(ks)
35 | with open(parser.result_file_name,'w',encoding='utf-8') as f:
36 | for i in result:
37 | f.write(i+'\n')
38 |
39 | else:
40 |         assert parser.message is not None, '请传入文本'
41 | text = fenzhen_tokenizer(parser.message, max_length=512, return_tensors='pt')
42 | input_ids = text['input_ids'].to(device)
43 | fenzhen_prodict = fenzhen_model(input_ids)[0]
44 | ks = {0: '男科', 1: '内科', 2: '妇科', 3: '肿瘤科', 4: '儿科', 5: '外科'}[int(fenzhen_prodict.argmax())]
45 | return ks
46 |
--------------------------------------------------------------------------------
/README_ZH.md:
--------------------------------------------------------------------------------
1 | # LingYi Technical Documentation
2 | 
3 | 
4 | 
5 | ### You can preview our system at [灵医小智 (LingYi)](http://kg.wengsyx.com).
6 | 
7 | 
8 | 
9 | #### An accurate and effective intelligent consultation system that communicates with patients in real time and collects their information can resolve the trust and accuracy problems of online healthcare, improve the efficiency of clinical diagnosis, reduce doctors' workload, and encourage follow-up visits. This would effectively alleviate the shortage of doctors, the uneven and insufficient distribution of medical care, and the public's doubts about online medical information.
10 | 
11 | #### Taking the Chinese medical domain as an example, we built a personal automated consultation system named "灵医小智" (LingYi).
12 | 
13 | #### LingYi supports building an intelligent pre-consultation desk in hospitals, a consultation-and-prescription robot in pharmacies, and a private-doctor application for individuals. Backed by massive medical knowledge and large-scale pre-trained language models, it serves patients' online and offline medical needs across all scenarios, improving doctors' efficiency and the patient experience.
14 |
15 |
16 |
17 |
18 | Triage |
19 | Reasoning |
20 | Generation |
21 | Data |
22 | Medical Record
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | Overall Pipeline
31 |
32 |
33 | #### We built eight modules that solve users' medical problems quickly and accurately through dialogue. Compared with smart-healthcare products on the market such as "灵医智慧" and "左手医生", our personalized private-doctor project LingYi aims to deliver proactive, friendly healthcare covering the pre-consultation, in-consultation, and post-consultation stages. Built on the WuDao model's strong generation and image-recognition capabilities, it perceives the patient's condition multimodally and communicates freely in flexible, human-like language, which greatly improves the user experience, relieves pressure on offline hospital services, and supports patients' at-home recovery management.
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | ### Triage
42 | 
43 | ##### Based on real-world scenarios and survey requirements, a triage module is introduced to narrow down the range of candidate diseases and improve consultation accuracy.
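A minimal sketch of this triage step, assuming the classifier checkpoint loaded in `Triage/__init__.py`; the six-department label map follows that file, and the example symptom is illustrative:

```python
import torch
from transformers import AutoModelForSequenceClassification, BertTokenizer

MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_BERT'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=6)

# Label map taken from Triage/__init__.py.
DEPARTMENTS = {0: '男科', 1: '内科', 2: '妇科', 3: '肿瘤科', 4: '儿科', 5: '外科'}

inputs = tokenizer('我最近总是腹痛,还有点恶心', max_length=512, return_tensors='pt')
with torch.no_grad():
    logits = model(inputs['input_ids'])[0]
print(DEPARTMENTS[int(logits.argmax())])  # prints the recommended department
```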
44 |
45 |
46 |
47 |
48 |
49 | ### Reasoning
50 | 
51 | ##### Entity disambiguation helps normalize the user's utterances. We use [ADBCMM](https://github.com/WENGSYX/ADBCMM) so that our model can accurately resolve the full meaning of the entities the user mentions.
52 | 
53 | ##### For the disambiguated entities, we apply knowledge-graph reasoning together with state-of-the-art knowledge-based question-answering techniques. Through entity linking, multi-hop reasoning, and path ranking, and by introducing process points, we search for the next target entity in a more controllable way, such as other symptoms the patient may have, the examinations needed for further diagnosis, and the drugs required to treat the disease.
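A minimal sketch of this next-entity search, assuming a tiny hypothetical knowledge-graph fragment and a trivial stand-in for path ranking; the real system reasons over the full CMKG:

```python
# Hypothetical CMKG fragment: (head entity, relation) -> candidate tail entities.
CMKG = {
    ('腹痛', '伴随症状'): ['恶心', '腹泻', '发热'],
    ('腹痛', '检查项目'): ['腹部B超', '血常规'],
    ('腹痛', '治疗药物'): ['奥美拉唑', '蒙脱石散'],
}

# Process point -> relation to follow at this step of the consultation flow.
PROCESS_POINT_RELATION = {'问症状': '伴随症状', '开检查': '检查项目', '开药': '治疗药物'}

def next_entities(entity, process_point, top_k=2):
    """Return the top-k candidate target entities for the current process point."""
    relation = PROCESS_POINT_RELATION[process_point]
    candidates = CMKG.get((entity, relation), [])
    return candidates[:top_k]  # stand-in for path ranking

print(next_entities('腹痛', '问症状'))  # ['恶心', '腹泻']
print(next_entities('腹痛', '开检查'))  # ['腹部B超', '血常规']
```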
54 |
55 |
56 |
57 |
58 |
59 | ### Generation
60 | 
61 | ##### This step uses a pre-trained generation model. Because such models are pre-trained on general text, their inference ability in the specialized medical domain is limited. We therefore use an Entity-Prompt-Learning method: during training, the entities of the next utterance and the process points are fed to the model together with the dialogue as input.
62 | 
63 | ##### The dialogue context and the entity information are fused through the prompt, so that the final response contains the predicted entity information.
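A minimal sketch of this entity-prompt fusion, reusing the CPT checkpoint loaded elsewhere in this repository; the prompt template and the predicted entities below are illustrative assumptions, not the exact training format:

```python
from transformers import BertTokenizer
from modeling_cpt import CPTForConditionalGeneration  # run from the repository root

MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = CPTForConditionalGeneration.from_pretrained(MODEL_NAME)

history = '患者:我最近腹痛,还有点恶心。'
process_point = '开检查'                    # from the flow controller
predicted_entities = ['腹部B超', '血常规']   # from the knowledge-graph reasoning step

# Fuse dialogue context, process point and predicted entities into one prompt,
# so the generated reply is conditioned on the entities it should mention.
prompt = '{}[SEP]流程:{}[SEP]实体:{}'.format(history, process_point, ','.join(predicted_entities))

inputs = tokenizer(prompt, truncation=True, max_length=512, return_tensors='pt')
reply_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], max_length=256)
print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0])
```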
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 | ### Medical Record
72 | 
73 | ##### After the automated consultation, a medical report describing the patient's overall condition needs to be written. Based on the WuDao CPM model and causal language modeling, our approach outperforms the previous state of the art (SOTA) on the CCL dataset by an average of 2.05 points.
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 | # Project Evaluation Metrics
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 | # Project Summary
92 |
93 | **Ease of use**: Pre-consultation through a hospital robot's voice interface, with image-based consultation supported; friendly to use.
94 | 
95 | **Functionality**: Rich system features with eight medical modules.
96 | 
97 | **Collaboration**: Industry-university-research cooperation with a clear division of labor and an accurate grasp of requirements.
98 | 
99 | **Commercial value**: The multimodal module recommends drugs with images and text for one-click purchase.
100 | 
101 | **Coverage**: Full coverage of thirty-five key departments.
102 | 
103 | **Professionalism**: Entities from knowledge-graph reasoning are injected through prompts, which is more professional than direct generation.
104 | 
105 | **State of the art**: Adopts SOTA solutions from this year's NLP competitions to further improve accuracy.
106 | 
107 | **Controllability**: Entity-Prompt and process-point reasoning increase controllability.
108 | 
109 | **Public benefit**: Effectively alleviates the shortage of doctors and the uneven, insufficient distribution of medical care.
110 |
111 |
112 |
113 |
114 |
115 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## MedConQA: Medical Conversational Question Answering System based on Knowledge Graphs
2 |
3 | 
4 |
5 | ### How to use
6 |
7 | ##### We offer three experience types:
8 |
9 | *1. Triage* Give your symptoms as input, and we will triage them to the appropriate department.
10 | 
11 | *2. Dagnosis* Ask any medical question and receive a generated reply.
12 | 
13 | *3. Summary* Provide the doctor-patient dialogue as input, and a medical record will be generated.
14 | 
15 | ##### We offer three input modes:
16 | 
17 | *1. interactive* An interactive human-computer demo.
18 | 
19 | *2. api* An API-style interface for direct calls.
20 | 
21 | *3. batch* Batch processing of an input file, with results written to an output file.
22 |
23 |
24 |
25 | ### Example
26 |
27 | ```
28 | python main.py --mode api --type Dagnosis --message 我肚子好疼
29 | ```
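The same `api` call can also be made from Python. A minimal sketch, assuming it is run from the repository root; the `Namespace` fields mirror the arguments defined in `main.py`:

```python
from argparse import Namespace
from Dagnosis import Dagnosis

# Mirrors the command-line arguments parsed in main.py; only `mode` and
# `message` matter for the api branch.
args = Namespace(type='Dagnosis', mode='api', file_name=None,
                 result_file_name='result.csv', message='我肚子好疼')

reply = Dagnosis(args)  # returns the generated doctor reply as a string
print(reply)
```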
30 |
31 |
32 |
33 | ```
34 | python main.py --mode interactive --type Dagnosis
35 | ```
36 |
37 |
38 |
39 | ```
40 | python main.py --mode batch --type Summary --file_name input.csv --result_file_name result.csv
41 | ```
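In `batch` mode each module reads the input file with `readlines()`, so the format assumed here is one dialogue or question per line. A minimal sketch of preparing such a file:

```python
# One record per line; batch mode writes one generated result per line
# to the file given by --result_file_name.
records = [
    '患者:我肚子好疼。医生:疼了多久了?患者:大概两天了。',
    '患者:最近总是头晕,还有点恶心。',
]
with open('input.csv', 'w', encoding='utf-8') as f:
    f.write('\n'.join(records) + '\n')
```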
42 |
43 |
44 |
45 |
46 | ## Cite
47 | ```
48 | @inproceedings{xia-etal-2022-medconqa,
49 | title = "{M}ed{C}on{QA}: Medical Conversational Question Answering System based on Knowledge Graphs",
50 | author = "Xia, Fei and
51 | Li, Bin and
52 | Weng, Yixuan and
53 | He, Shizhu and
54 | Liu, Kang and
55 | Sun, Bin and
56 | Li, Shutao and
57 | Zhao, Jun",
58 | booktitle = "Proceedings of the The 2022 Conference on Empirical Methods in Natural Language Processing: System Demonstrations",
59 | month = dec,
60 | year = "2022",
61 | address = "Abu Dhabi, UAE",
62 | publisher = "Association for Computational Linguistics",
63 | url = "https://aclanthology.org/2022.emnlp-demos.15",
64 | pages = "148--158",
65 | abstract = "The medical conversational system can relieve doctors{'} burden and improve healthcare efficiency, especially during the COVID-19 pandemic. However, the existing medical dialogue systems have the problems of weak scalability, insufficient knowledge, and poor controllability. Thus, we propose a medical conversational question-answering (CQA) system based on the knowledge graph, namely MedConQA, which is designed as a pipeline framework to maintain high flexibility. Our system utilizes automated medical procedures, including medical triage, consultation, image-text drug recommendation, and record. Each module has been open-sourced as a tool, which can be used alone or in combination, with robust scalability. Besides, to conduct knowledge-grounded dialogues with users, we first construct a Chinese Medical Knowledge Graph (CMKG) and collect a large-scale Chinese Medical CQA (CMCQA) dataset, and we design a series of methods for reasoning more intellectually. Finally, we use several state-of-the-art (SOTA) techniques to keep the final generated response more controllable, which is further assured by hospital and professional evaluations. We have open-sourced related code, datasets, web pages, and tools, hoping to advance future research.",
66 | }
67 | ```
68 |
--------------------------------------------------------------------------------
/Dagnosis/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForSequenceClassification,BertTokenizer
3 | from modeling_cpt import CPTForConditionalGeneration
4 | from Triage import *
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 | FENZHEN_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_BERT'
7 | CMDD_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT'
8 | BL_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT'
9 |
10 | cmdd_model = CPTForConditionalGeneration.from_pretrained(CMDD_MODEL_NAME)
11 | cmdd_model = cmdd_model.to(device)
12 | berttokenizer = BertTokenizer.from_pretrained(CMDD_MODEL_NAME)
13 |
14 |
15 | def Dagnosis(parser):
16 |
17 | if parser.mode == 'interactive':
18 | while True:
19 | message = input('医生:\033[0;34m您好,有什么我能帮您?\033[0m(退出请输入\033[0;35m退出\033[0m)')
20 | if message == '退出':
21 | return ''
22 | text = berttokenizer('患者:'+message,padding='max_length', truncation=True, max_length=512,return_tensors='pt')
23 | input_ids = text['input_ids'].to(device)
24 | attention_mask = text['attention_mask'].to(device)
25 | token_type_ids = text['token_type_ids'].to(device)
26 |
27 | out = berttokenizer.batch_decode(
28 | cmdd_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0]
29 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:',
30 | '')
31 | print('医生: \033[0;34m{}\033[0m'.format(out))
32 | elif parser.mode == 'batch':
33 | with open(parser.file_name,'r',encoding='utf-8') as f:
34 | data = f.readlines()
35 | result = []
36 | for message in data:
37 | text = berttokenizer('患者:'+message,padding='max_length', truncation=True, max_length=512,return_tensors='pt')
38 | input_ids = text['input_ids'].to(device)
39 | attention_mask = text['attention_mask'].to(device)
40 | token_type_ids = text['token_type_ids'].to(device)
41 |
42 | out = berttokenizer.batch_decode(
43 | cmdd_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0]
44 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:',
45 | '')
46 | result.append(out)
47 | with open(parser.result_file_name,'w',encoding='utf-8') as f:
48 | for i in result:
49 | f.write(i+'\n')
50 |
51 | else:
52 |         assert parser.message is not None, '请传入文本'
53 | text = berttokenizer('患者:' + parser.message, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
54 | input_ids = text['input_ids'].to(device)
55 | attention_mask = text['attention_mask'].to(device)
56 | token_type_ids = text['token_type_ids'].to(device)
57 |
58 | out = berttokenizer.batch_decode(
59 | cmdd_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0]
60 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:',
61 | '')
62 | return out
63 |
--------------------------------------------------------------------------------
/Summary/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForSequenceClassification,BertTokenizer
3 | from .modeling_cpt import CPTForConditionalGeneration
4 | from Triage import *
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 | FENZHEN_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_BERT'
7 | CMDD_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT'
8 | BL_MODEL_NAME = 'WENGSYX/Dagnosis_Chinese_CPT'
9 | fenzhen_model = AutoModelForSequenceClassification.from_pretrained(FENZHEN_MODEL_NAME, num_labels=6).to(device) # 模型
10 | fenzhen_tokenizer = BertTokenizer.from_pretrained(FENZHEN_MODEL_NAME)
11 |
12 | bl_model = CPTForConditionalGeneration.from_pretrained(BL_MODEL_NAME)
13 | bl_model = bl_model.to(device)
14 | berttokenizer = BertTokenizer.from_pretrained(CMDD_MODEL_NAME)
15 |
16 | def Summary(parser):
17 |
18 | if parser.mode == 'interactive':
19 | while True:
20 | message = input('请输入对话历史,让我来帮您记录病例信息:(退出请输入\033[0;35m退出\033[0m)')
21 | if message == '退出':
22 | return ''
23 | text = berttokenizer(message,padding='max_length', truncation=True, max_length=512,return_tensors='pt')
24 | input_ids = text['input_ids'].to(device)
25 | attention_mask = text['attention_mask'].to(device)
26 | token_type_ids = text['token_type_ids'].to(device)
27 |
28 | out = berttokenizer.batch_decode(
29 | bl_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0]
30 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:',
31 | '')
32 | print('医生: \033[0;34m{}\033[0m'.format(out))
33 | elif parser.mode == 'batch':
34 | with open(parser.file_name,'r',encoding='utf-8') as f:
35 | data = f.readlines()
36 | result = []
37 | for message in data:
38 | text = berttokenizer(message,padding='max_length', truncation=True, max_length=512,return_tensors='pt')
39 | input_ids = text['input_ids'].to(device)
40 | attention_mask = text['attention_mask'].to(device)
41 | token_type_ids = text['token_type_ids'].to(device)
42 |
43 | out = berttokenizer.batch_decode(
44 | bl_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0]
45 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:',
46 | '')
47 | result.append(out)
48 | with open(parser.result_file_name,'w',encoding='utf-8') as f:
49 | for i in result:
50 | f.write(i+'\n')
51 |
52 | else:
53 |         assert parser.message is not None, '请传入文本'
54 | text = berttokenizer(parser.message, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
55 | input_ids = text['input_ids'].to(device)
56 | attention_mask = text['attention_mask'].to(device)
57 | token_type_ids = text['token_type_ids'].to(device)
58 |
59 | out = berttokenizer.batch_decode(
60 | bl_model.generate(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, max_length=256))[0]
61 | out = out.replace('[SEP]', '').replace('[CLS]', '').replace(' ', '').replace('指导意见:', '').replace('病情分析:','')
62 | return out
63 |
--------------------------------------------------------------------------------
/modeling_cpt.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch CPT model. modified from transformers==4.4.1"""
16 | import copy
17 | import math
18 | import random
19 | import warnings
20 | from typing import Optional, Tuple
21 |
22 | import torch
23 | import torch.nn.functional as F
24 | from torch.nn.modules.loss import KLDivLoss, NLLLoss
25 | import torch.utils.checkpoint
26 | from torch import nn
27 | from torch.nn import CrossEntropyLoss
28 |
29 | from transformers.activations import ACT2FN
30 | from transformers.file_utils import (
31 | add_code_sample_docstrings,
32 | add_end_docstrings,
33 | add_start_docstrings,
34 | add_start_docstrings_to_model_forward,
35 | replace_return_docstrings,
36 | )
37 | from transformers.modeling_outputs import (
38 | BaseModelOutput,
39 | BaseModelOutputWithPastAndCrossAttentions,
40 | CausalLMOutputWithCrossAttentions,
41 | Seq2SeqLMOutput,
42 | Seq2SeqModelOutput,
43 | Seq2SeqQuestionAnsweringModelOutput,
44 | Seq2SeqSequenceClassifierOutput,
45 | )
46 | from transformers.modeling_utils import PreTrainedModel
47 | from transformers.utils import logging
48 | from transformers import BartConfig as CPTConfig
49 | from transformers import BertModel, BertConfig
50 |
51 | from torch.nn import LayerNorm
52 |
53 | logger = logging.get_logger(__name__)
54 |
55 | _CHECKPOINT_FOR_DOC = "fnlp/cpt-large"
56 | _CONFIG_FOR_DOC = "CPTConfig"
57 | _TOKENIZER_FOR_DOC = "CPTTokenizer"
58 |
59 |
60 | CPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
61 | "fnlp/cpt-large",
62 | ]
63 |
64 |
65 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
66 | """
67 | Shift input ids one token to the right.
68 | """
69 | shifted_input_ids = input_ids.new_zeros(input_ids.shape)
70 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
71 | shifted_input_ids[:, 0] = decoder_start_token_id
72 |
73 | assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
74 | # replace possible -100 values in labels by `pad_token_id`
75 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
76 |
77 | return shifted_input_ids
78 |
79 |
80 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
81 | """
82 | Make causal mask used for bi-directional self-attention.
83 | """
84 | bsz, tgt_len = input_ids_shape
85 | mask = torch.full((tgt_len, tgt_len), float("-inf"))
86 | mask_cond = torch.arange(mask.size(-1))
87 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
88 | mask = mask.to(dtype)
89 |
90 | if past_key_values_length > 0:
91 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
92 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
93 |
94 |
95 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
96 | """
97 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
98 | """
99 | bsz, src_len = mask.size()
100 | tgt_len = tgt_len if tgt_len is not None else src_len
101 |
102 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
103 |
104 | inverted_mask = 1.0 - expanded_mask
105 |
106 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
107 |
108 | def attention_mask_func(attention_scores, attention_mask):
109 | return attention_scores + attention_mask
110 |
111 | def init_method(std):
112 | def init_(tensor):
113 | return torch.nn.init.normal_(tensor, mean=0.0, std=std)
114 |
115 | return init_
116 |
117 | class CPTLearnedPositionalEmbedding(nn.Embedding):
118 | """
119 | This module learns positional embeddings up to a fixed maximum size.
120 | """
121 |
122 | def __init__(self, num_embeddings: int, embedding_dim: int):
123 | # CPT is set up so that if padding_idx is specified then offset the embedding ids by 2
124 |         # and adjust num_embeddings appropriately. Other models don't have this hack
125 | self.offset = 2
126 | super().__init__(num_embeddings + self.offset, embedding_dim)
127 |
128 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
129 | """`input_ids_shape` is expected to be [bsz x seqlen]."""
130 | bsz, seq_len = input_ids_shape[:2]
131 | positions = torch.arange(
132 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
133 | )
134 | return super().forward(positions + self.offset)
135 |
136 |
137 | class CPTAttention(nn.Module):
138 | """Multi-headed attention from 'Attention Is All You Need' paper"""
139 |
140 | def __init__(
141 | self,
142 | embed_dim: int,
143 | num_heads: int,
144 | dropout: float = 0.0,
145 | is_decoder: bool = False,
146 | bias: bool = True,
147 | ):
148 | super().__init__()
149 | self.embed_dim = embed_dim
150 | self.num_heads = num_heads
151 | self.dropout = dropout
152 | self.head_dim = embed_dim // num_heads
153 | assert (
154 | self.head_dim * num_heads == self.embed_dim
155 | ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
156 | self.scaling = self.head_dim ** -0.5
157 | self.is_decoder = is_decoder
158 |
159 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
160 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
161 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
162 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
163 |
164 |
165 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
166 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
167 |
168 | def forward(
169 | self,
170 | hidden_states: torch.Tensor,
171 | key_value_states: Optional[torch.Tensor] = None,
172 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
173 | attention_mask: Optional[torch.Tensor] = None,
174 | layer_head_mask: Optional[torch.Tensor] = None,
175 | output_attentions: bool = False,
176 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
177 | """Input shape: Batch x Time x Channel"""
178 |
179 | # if key_value_states are provided this layer is used as a cross-attention layer
180 | # for the decoder
181 | is_cross_attention = key_value_states is not None
182 | bsz, tgt_len, embed_dim = hidden_states.size()
183 |
184 | # get query proj
185 | query_states = self.q_proj(hidden_states) * self.scaling
186 | # get key, value proj
187 | if is_cross_attention and past_key_value is not None:
188 | # reuse k,v, cross_attentions
189 | key_states = past_key_value[0]
190 | value_states = past_key_value[1]
191 | elif is_cross_attention:
192 | # cross_attentions
193 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
194 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
195 | elif past_key_value is not None:
196 | # reuse k, v, self_attention
197 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
198 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
199 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
200 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
201 | else:
202 | # self_attention
203 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
204 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
205 |
206 | if self.is_decoder:
207 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
208 | # Further calls to cross_attention layer can then reuse all cross-attention
209 | # key/value_states (first "if" case)
210 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
211 | # all previous decoder key/value_states. Further calls to uni-directional self-attention
212 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
213 | # if encoder bi-directional self-attention `past_key_value` is always `None`
214 | past_key_value = (key_states, value_states)
215 |
216 | proj_shape = (bsz * self.num_heads, -1, self.head_dim)
217 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
218 | key_states = key_states.view(*proj_shape)
219 | value_states = value_states.view(*proj_shape)
220 |
221 | src_len = key_states.size(1)
222 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
223 |
224 | assert attn_weights.size() == (
225 | bsz * self.num_heads,
226 | tgt_len,
227 | src_len,
228 | ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
229 |
230 | if attention_mask is not None:
231 | assert attention_mask.size() == (
232 | bsz,
233 | 1,
234 | tgt_len,
235 | src_len,
236 | ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
237 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
238 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
239 |
240 | attn_weights = F.softmax(attn_weights, dim=-1)
241 |
242 | if layer_head_mask is not None:
243 | assert layer_head_mask.size() == (
244 | self.num_heads,
245 | ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
246 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
247 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
248 |
249 | if output_attentions:
250 |             # this operation is a bit awkward, but it's required to
251 |             # make sure that attn_weights keeps its gradient.
252 |             # In order to do so, attn_weights have to be reshaped
253 |             # twice and have to be reused in the following
254 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
255 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
256 | else:
257 | attn_weights_reshaped = None
258 |
259 | # with mpu.get_cuda_rng_tracker().fork():
260 | attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
261 |
262 | attn_output = torch.bmm(attn_probs, value_states)
263 |
264 | assert attn_output.size() == (
265 | bsz * self.num_heads,
266 | tgt_len,
267 | self.head_dim,
268 | ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
269 |
270 | attn_output = (
271 | attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
272 | .transpose(1, 2)
273 | .reshape(bsz, tgt_len, embed_dim)
274 | )
275 |
276 | attn_output = self.out_proj(attn_output)
277 |
278 | return attn_output, attn_weights_reshaped, past_key_value
279 |
280 | class CPTDecoderLayer(nn.Module):
281 | def __init__(self, config: CPTConfig):
282 | super().__init__()
283 | self.embed_dim = config.d_model
284 |
285 | self.self_attn = CPTAttention(
286 | embed_dim=self.embed_dim,
287 | num_heads=config.decoder_attention_heads,
288 | dropout=config.attention_dropout,
289 | is_decoder=True,
290 | )
291 | self.dropout = config.dropout
292 | self.activation_fn = ACT2FN[config.activation_function]
293 | self.activation_dropout = config.activation_dropout
294 |
295 | self.self_attn_layer_norm = LayerNorm(self.embed_dim)
296 | self.encoder_attn = CPTAttention(
297 | self.embed_dim,
298 | config.decoder_attention_heads,
299 | dropout=config.attention_dropout,
300 | is_decoder=True,
301 | )
302 | self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
303 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
304 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
305 | self.final_layer_norm = LayerNorm(self.embed_dim)
306 |
307 | def forward(
308 | self,
309 | hidden_states: torch.Tensor,
310 | attention_mask: Optional[torch.Tensor] = None,
311 | encoder_hidden_states: Optional[torch.Tensor] = None,
312 | encoder_attention_mask: Optional[torch.Tensor] = None,
313 | layer_head_mask: Optional[torch.Tensor] = None,
314 | encoder_layer_head_mask: Optional[torch.Tensor] = None,
315 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
316 | output_attentions: Optional[bool] = False,
317 | use_cache: Optional[bool] = True,
318 | ):
319 | """
320 | Args:
321 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
322 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size
323 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
324 | encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
325 | encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
326 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
327 | layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
328 | `(config.encoder_attention_heads,)`.
329 | encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
330 | size `(config.encoder_attention_heads,)`.
331 | past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
332 | output_attentions (:obj:`bool`, `optional`):
333 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
334 | returned tensors for more detail.
335 | """
336 | residual = hidden_states
337 |
338 | # Self Attention
339 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
340 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
341 | # add present self-attn cache to positions 1,2 of present_key_value tuple
342 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
343 | hidden_states=hidden_states,
344 | past_key_value=self_attn_past_key_value,
345 | attention_mask=attention_mask,
346 | layer_head_mask=layer_head_mask,
347 | output_attentions=output_attentions,
348 | )
349 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
350 | hidden_states = residual + hidden_states
351 | hidden_states = self.self_attn_layer_norm(hidden_states)
352 |
353 | # Cross-Attention Block
354 | cross_attn_present_key_value = None
355 | cross_attn_weights = None
356 | if encoder_hidden_states is not None:
357 | residual = hidden_states
358 |
359 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
360 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
361 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
362 | hidden_states=hidden_states,
363 | key_value_states=encoder_hidden_states,
364 | attention_mask=encoder_attention_mask,
365 | layer_head_mask=encoder_layer_head_mask,
366 | past_key_value=cross_attn_past_key_value,
367 | output_attentions=output_attentions,
368 | )
369 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
370 | hidden_states = residual + hidden_states
371 | hidden_states = self.encoder_attn_layer_norm(hidden_states)
372 |
373 | # add cross-attn to positions 3,4 of present_key_value tuple
374 | present_key_value = present_key_value + cross_attn_present_key_value
375 |
376 | # Fully Connected
377 | residual = hidden_states
378 | hidden_states = self.activation_fn(self.fc1(hidden_states))
379 | hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
380 | hidden_states = self.fc2(hidden_states)
381 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
382 | hidden_states = residual + hidden_states
383 | hidden_states = self.final_layer_norm(hidden_states)
384 |
385 | outputs = (hidden_states,)
386 |
387 | if output_attentions:
388 | outputs += (self_attn_weights, cross_attn_weights)
389 |
390 | if use_cache:
391 | outputs += (present_key_value,)
392 |
393 | return outputs
394 |
395 |
396 | class CPTClassificationHead(nn.Module):
397 | """Head for sentence-level classification tasks."""
398 |
399 | def __init__(
400 | self,
401 | input_dim: int,
402 | inner_dim: int,
403 | num_classes: int,
404 | pooler_dropout: float,
405 | ):
406 | super().__init__()
407 | self.dense = nn.Linear(input_dim, inner_dim)
408 | self.dropout = nn.Dropout(p=pooler_dropout)
409 | self.out_proj = nn.Linear(inner_dim, num_classes)
410 |
411 | def forward(self, hidden_states: torch.Tensor):
412 | hidden_states = self.dropout(hidden_states)
413 | hidden_states = self.dense(hidden_states)
414 | hidden_states = torch.tanh(hidden_states)
415 | hidden_states = self.dropout(hidden_states)
416 | hidden_states = self.out_proj(hidden_states)
417 | return hidden_states
418 |
419 |
420 | class CPTPretrainedModel(PreTrainedModel):
421 | config_class = CPTConfig
422 | base_model_prefix = "model"
423 |
424 | def _init_weights(self, module):
425 | std = self.config.init_std
426 | if isinstance(module, nn.Linear):
427 | module.weight.data.normal_(mean=0.0, std=std)
428 | if module.bias is not None:
429 | module.bias.data.zero_()
430 | elif isinstance(module, nn.Embedding):
431 | module.weight.data.normal_(mean=0.0, std=std)
432 | if module.padding_idx is not None:
433 | module.weight.data[module.padding_idx].zero_()
434 |
435 | @property
436 | def dummy_inputs(self):
437 | pad_token = self.config.pad_token_id
438 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
439 | dummy_inputs = {
440 | "attention_mask": input_ids.ne(pad_token),
441 | "input_ids": input_ids,
442 | }
443 | return dummy_inputs
444 |
445 | CPT_START_DOCSTRING = r"""
446 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
447 |     methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
448 | pruning heads etc.)
449 |
450 |     This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
451 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
452 | general usage and behavior.
453 |
454 | Parameters:
455 | config (:class:`~transformers.CPTConfig`):
456 | Model configuration class with all the parameters of the model. Initializing with a config file does not
457 | load the weights associated with the model, only the configuration. Check out the
458 | :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
459 | """
460 |
461 | CPT_INPUTS_DOCSTRING = r"""
462 | Args:
463 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
464 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
465 | it.
466 |
467 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See
468 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
469 | details.
470 |
471 | `What are input IDs? <../glossary.html#input-ids>`__
472 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
473 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
474 |
475 | - 1 for tokens that are **not masked**,
476 | - 0 for tokens that are **masked**.
477 |
478 | `What are attention masks? <../glossary.html#attention-mask>`__
479 | decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
480 | Indices of decoder input sequence tokens in the vocabulary.
481 |
482 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See
483 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
484 | details.
485 |
486 | `What are input IDs? <../glossary.html#input-ids>`__
487 |
488 | CPT uses the :obj:`eos_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
489 | :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
490 | :obj:`past_key_values`).
491 |
492 | For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no
493 | :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to
494 | the right for denoising pre-training following the paper.
495 | decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
496 | Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
497 | also be used by default.
498 |
499 | If you want to change padding behavior, you should read :func:`modeling_cpt._prepare_decoder_inputs` and
500 |             modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
501 | information on the default strategy.
502 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
503 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
504 |
505 | - 1 indicates the head is **not masked**,
506 |             - 0 indicates the head is **masked**.
507 |
508 | decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
509 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
510 |
511 | - 1 indicates the head is **not masked**,
512 | - 0 indicates the head is **masked**.
513 |
514 | encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
515 | Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
516 | :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
517 | `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
518 | cross-attention of the decoder.
519 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
520 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
521 |
522 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
523 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
524 | instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
525 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
526 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
527 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
528 | vectors than the model's internal embedding lookup matrix.
529 | decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
530 | Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
531 | representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
532 | have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
533 | :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
534 |
535 | If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
536 | takes the value of :obj:`inputs_embeds`.
537 | use_cache (:obj:`bool`, `optional`):
538 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
539 | decoding (see :obj:`past_key_values`).
540 | output_attentions (:obj:`bool`, `optional`):
541 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
542 | tensors for more detail.
543 | output_hidden_states (:obj:`bool`, `optional`):
544 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
545 | more detail.
546 | return_dict (:obj:`bool`, `optional`):
547 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
548 | """
549 |
550 | class CPTDecoder(CPTPretrainedModel):
551 | """
552 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`CPTDecoderLayer`
553 |
554 | Args:
555 | config: CPTConfig
556 | embed_tokens (torch.nn.Embedding): output embedding
557 | """
558 |
559 | def __init__(self, config: CPTConfig, embed_tokens: Optional[nn.Embedding] = None):
560 | super().__init__(config)
561 | self.dropout = config.dropout
562 | self.layerdrop = config.decoder_layerdrop
563 | self.padding_idx = config.pad_token_id
564 | self.max_target_positions = config.max_position_embeddings
565 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
566 |
567 | if embed_tokens is not None:
568 | self.embed_tokens = embed_tokens
569 | else:
570 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
571 |
572 | self.embed_positions = CPTLearnedPositionalEmbedding(
573 | config.max_position_embeddings,
574 | config.d_model,
575 | )
576 | self.layers = nn.ModuleList([CPTDecoderLayer(config) for _ in range(config.decoder_layers)])
577 | self.layernorm_embedding = LayerNorm(config.d_model)
578 |
579 | self.init_weights()
580 |
581 | def get_input_embeddings(self):
582 | return self.embed_tokens
583 |
584 | def set_input_embeddings(self, value):
585 | self.embed_tokens = value
586 |
587 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
588 | # create causal mask
589 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
590 | combined_attention_mask = None
591 | if input_shape[-1] > 1:
592 | combined_attention_mask = _make_causal_mask(
593 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
594 | ).to(self.device)
595 |
596 | if attention_mask is not None:
597 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
598 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
599 | combined_attention_mask = (
600 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
601 | )
602 |
603 | return combined_attention_mask
604 |
605 | def forward(
606 | self,
607 | input_ids=None,
608 | attention_mask=None,
609 | encoder_hidden_states=None,
610 | encoder_attention_mask=None,
611 | head_mask=None,
612 | encoder_head_mask=None,
613 | past_key_values=None,
614 | inputs_embeds=None,
615 | use_cache=None,
616 | output_attentions=None,
617 | output_hidden_states=None,
618 | return_dict=None,
619 | ):
620 | r"""
621 | Args:
622 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
623 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
624 | provide it.
625 |
626 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See
627 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
628 | for details.
629 |
630 | `What are input IDs? <../glossary.html#input-ids>`__
631 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
632 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
633 |
634 | - 1 for tokens that are **not masked**,
635 | - 0 for tokens that are **masked**.
636 |
637 | `What are attention masks? <../glossary.html#attention-mask>`__
638 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
639 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
640 | of the decoder.
641 | encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`):
642 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
643 | selected in ``[0, 1]``:
644 |
645 | - 1 for tokens that are **not masked**,
646 | - 0 for tokens that are **masked**.
647 |
648 | `What are attention masks? <../glossary.html#attention-mask>`__
649 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
650 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
651 |
652 | - 1 indicates the head is **not masked**,
653 |                 - 0 indicates the head is **masked**.
654 |
655 | encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
656 | Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
657 | on hidden heads. Mask values selected in ``[0, 1]``:
658 |
659 | - 1 indicates the head is **not masked**,
660 |                 - 0 indicates the head is **masked**.
661 |
662 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
663 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
664 | decoding.
665 |
666 | If :obj:`past_key_values` are used, the user can optionally input only the last
667 | :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of
668 | shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size,
669 | sequence_length)`.
670 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
671 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
672 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
673 | into associated vectors than the model's internal embedding lookup matrix.
674 | output_attentions (:obj:`bool`, `optional`):
675 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
676 | returned tensors for more detail.
677 | output_hidden_states (:obj:`bool`, `optional`):
678 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
679 | for more detail.
680 | return_dict (:obj:`bool`, `optional`):
681 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
682 | """
683 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
684 | output_hidden_states = (
685 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
686 | )
687 | use_cache = use_cache if use_cache is not None else self.config.use_cache
688 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
689 |
690 | # retrieve input_ids and inputs_embeds
691 | if input_ids is not None and inputs_embeds is not None:
692 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
693 | elif input_ids is not None:
694 | input_shape = input_ids.size()
695 | input_ids = input_ids.view(-1, input_shape[-1])
696 | elif inputs_embeds is not None:
697 | input_shape = inputs_embeds.size()[:-1]
698 | else:
699 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
700 |
701 | # past_key_values_length
702 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
703 |
704 | if inputs_embeds is None:
705 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
706 |
707 | attention_mask = self._prepare_decoder_attention_mask(
708 | attention_mask, input_shape, inputs_embeds, past_key_values_length
709 | )
710 |
711 | # expand encoder attention mask
712 | if encoder_hidden_states is not None and encoder_attention_mask is not None:
713 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
714 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
715 |
716 | # embed positions
717 | positions = self.embed_positions(input_shape, past_key_values_length)
718 |
719 | hidden_states = inputs_embeds + positions
720 | hidden_states = self.layernorm_embedding(hidden_states)
721 |
722 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
723 |
724 | # decoder layers
725 | all_hidden_states = () if output_hidden_states else None
726 | all_self_attns = () if output_attentions else None
727 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
728 | next_decoder_cache = () if use_cache else None
729 |
730 | # check if head_mask has a correct number of layers specified if desired
731 | if head_mask is not None:
732 | assert head_mask.size()[0] == (
733 | len(self.layers)
734 | ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
735 | for idx, decoder_layer in enumerate(self.layers):
736 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
737 | if output_hidden_states:
738 | all_hidden_states += (hidden_states,)
739 | dropout_probability = random.uniform(0, 1)
740 | if self.training and (dropout_probability < self.layerdrop):
741 | continue
742 |
743 | past_key_value = past_key_values[idx] if past_key_values is not None else None
744 |
745 | if getattr(self.config, "gradient_checkpointing", False) and self.training:
746 |
747 | if use_cache:
748 | logger.warn(
749 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
750 | "`use_cache=False`..."
751 | )
752 | use_cache = False
753 |
754 | def create_custom_forward(module):
755 | def custom_forward(*inputs):
756 | # None for past_key_value
757 | return module(*inputs, output_attentions, use_cache)
758 |
759 | return custom_forward
760 |
761 | # layer_outputs = mpu.checkpoint(
762 |                 layer_outputs = torch.utils.checkpoint.checkpoint(
763 | create_custom_forward(decoder_layer),
764 | hidden_states,
765 | attention_mask,
766 | encoder_hidden_states,
767 | encoder_attention_mask,
768 | head_mask[idx] if head_mask is not None else None,
769 | encoder_head_mask[idx] if encoder_head_mask is not None else None,
770 | None,
771 | )
772 | else:
773 |
774 | layer_outputs = decoder_layer(
775 | hidden_states,
776 | attention_mask=attention_mask,
777 | encoder_hidden_states=encoder_hidden_states,
778 | encoder_attention_mask=encoder_attention_mask,
779 | layer_head_mask=(head_mask[idx] if head_mask is not None else None),
780 | encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
781 | past_key_value=past_key_value,
782 | output_attentions=output_attentions,
783 | use_cache=use_cache,
784 | )
785 | hidden_states = layer_outputs[0]
786 |
787 | if use_cache:
788 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
789 |
790 | if output_attentions:
791 | all_self_attns += (layer_outputs[1],)
792 |
793 | if encoder_hidden_states is not None:
794 | all_cross_attentions += (layer_outputs[2],)
795 |
796 | # add hidden states from the last decoder layer
797 | if output_hidden_states:
798 | all_hidden_states += (hidden_states,)
799 |
800 | next_cache = next_decoder_cache if use_cache else None
801 | if not return_dict:
802 | return tuple(
803 | v
804 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
805 | if v is not None
806 | )
807 | return BaseModelOutputWithPastAndCrossAttentions(
808 | last_hidden_state=hidden_states,
809 | past_key_values=next_cache,
810 | hidden_states=all_hidden_states,
811 | attentions=all_self_attns,
812 | cross_attentions=all_cross_attentions,
813 | )
814 |
815 |
816 | @add_start_docstrings(
817 | "The bare CPT Model outputting raw hidden-states without any specific head on top.",
818 | CPT_START_DOCSTRING,
819 | )
820 | class CPTModel(CPTPretrainedModel):
821 | def __init__(self, config: CPTConfig):
822 | super().__init__(config)
823 | encoder_config = BertConfig(
824 | vocab_size=config.vocab_size,
825 | hidden_size=config.d_model,
826 | num_hidden_layers=config.encoder_layers,
827 | num_attention_heads=config.encoder_attention_heads,
828 | intermediate_size=config.encoder_ffn_dim,
829 | hidden_dropout_prob=config.activation_dropout,
830 | attention_probs_dropout_prob=config.attention_dropout,
831 | )
832 | config.vocab_size = encoder_config.vocab_size
833 | self.encoder = BertModel(encoder_config, add_pooling_layer=False)
834 | self.shared = self.encoder.get_input_embeddings()
835 | self.decoder = CPTDecoder(config, self.shared)
836 | self.num_decoder_layers = config.decoder_layers
837 | self.init_weights()
838 |
839 | def get_input_embeddings(self):
840 | return self.shared
841 |
842 | def set_input_embeddings(self, value):
843 | self.shared = value
844 | self.encoder.embed_tokens = self.shared
845 | self.decoder.embed_tokens = self.shared
846 |
847 | def get_encoder(self):
848 | class _Encoder(torch.nn.Module):
849 | def __init__(self, encoder):
850 | super().__init__()
851 | self.encoder = encoder
852 |
853 | def forward(self, *args, **kwargs):
854 | kwargs['output_hidden_states'] = True
855 | return self.encoder(*args, **kwargs)
856 | return _Encoder(self.encoder)
857 |
858 | def get_decoder(self):
859 | return self.decoder
860 |
861 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
862 | @add_code_sample_docstrings(
863 | checkpoint=_CHECKPOINT_FOR_DOC,
864 | output_type=Seq2SeqModelOutput,
865 | )
866 | def forward(
867 | self,
868 | input_ids=None,
869 | token_type_ids=None,
870 | attention_mask=None,
871 | decoder_input_ids=None,
872 | decoder_attention_mask=None,
873 | head_mask=None,
874 | decoder_head_mask=None,
875 | encoder_outputs=None,
876 | past_key_values=None,
877 | inputs_embeds=None,
878 | decoder_inputs_embeds=None,
879 | use_cache=None,
880 | output_attentions=None,
881 | output_hidden_states=None,
882 | return_dict=None,
883 | ):
884 |
885 | # unlike other models, CPT automatically creates decoder_input_ids from
886 | # input_ids if no decoder_input_ids are provided
887 | if decoder_input_ids is None and decoder_inputs_embeds is None:
888 | decoder_input_ids = shift_tokens_right(
889 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
890 | )
891 |
892 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
893 | output_hidden_states = (
894 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
895 | )
896 | use_cache = use_cache if use_cache is not None else self.config.use_cache
897 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
898 |
899 | if getattr(self.config, "gradient_checkpointing", False) and self.training:
900 | # mpu.reset_checkpointed_activations_memory_buffer()
901 | use_cache = False
902 |
903 | if encoder_outputs is None:
904 | encoder_outputs = self.encoder(
905 | input_ids=input_ids,
906 | attention_mask=attention_mask,
907 | token_type_ids=token_type_ids,
908 | head_mask=head_mask,
909 | inputs_embeds=inputs_embeds,
910 | output_attentions=output_attentions,
911 | output_hidden_states=True,
912 | return_dict=return_dict,
913 | )
914 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
915 | elif return_dict and isinstance(encoder_outputs, (tuple, list)):
916 | encoder_outputs = BaseModelOutput(
917 | last_hidden_state=encoder_outputs[0],
918 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
919 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
920 | )
921 |
922 | if isinstance(encoder_outputs, torch.Tensor):
923 | encoder_hidden_states = encoder_outputs
924 | else:
925 | encoder_hidden_states = encoder_outputs[1][-self.num_decoder_layers - 1]
926 |
927 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
928 | decoder_outputs = self.decoder(
929 | input_ids=decoder_input_ids,
930 | attention_mask=decoder_attention_mask,
931 | encoder_hidden_states=encoder_hidden_states,
932 | encoder_attention_mask=attention_mask,
933 | head_mask=decoder_head_mask,
934 | encoder_head_mask=head_mask,
935 | past_key_values=past_key_values,
936 | inputs_embeds=decoder_inputs_embeds,
937 | use_cache=use_cache,
938 | output_attentions=output_attentions,
939 | output_hidden_states=output_hidden_states,
940 | return_dict=return_dict,
941 | )
942 |
943 | if not return_dict:
944 | return decoder_outputs + encoder_outputs
945 |
946 | return Seq2SeqModelOutput(
947 | last_hidden_state=decoder_outputs.last_hidden_state,
948 | past_key_values=decoder_outputs.past_key_values,
949 | decoder_hidden_states=decoder_outputs.hidden_states,
950 | decoder_attentions=decoder_outputs.attentions,
951 | cross_attentions=decoder_outputs.cross_attentions,
952 | encoder_last_hidden_state=encoder_outputs.last_hidden_state if isinstance(encoder_outputs, dict) else None,
953 | encoder_hidden_states=encoder_outputs.hidden_states if isinstance(encoder_outputs, dict) else None,
954 | encoder_attentions=encoder_outputs.attentions if isinstance(encoder_outputs, dict) else None,
955 | )
956 |
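# A minimal sketch of how CPTModel wires a deep BERT encoder to a shallow CPT decoder
# (toy sizes below are illustrative only; the helper is not part of the released model):
# the decoder cross-attends to the encoder hidden state taken `decoder_layers + 1` layers
# from the top, rather than to the final encoder layer.
def _toy_cpt_model_sketch():
    toy_config = CPTConfig(
        vocab_size=100, d_model=64,
        encoder_layers=4, decoder_layers=2,
        encoder_attention_heads=4, decoder_attention_heads=4,
        encoder_ffn_dim=128, decoder_ffn_dim=128,
    )
    toy_model = CPTModel(toy_config)
    ids = torch.tensor([[10, 11, 12, 13, 2]])
    # decoder_input_ids are derived automatically via shift_tokens_right(input_ids, ...)
    out = toy_model(input_ids=ids, attention_mask=torch.ones_like(ids))
    return out.last_hidden_state.shape  # torch.Size([1, 5, 64])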
957 |
958 | @add_start_docstrings(
959 | "The CPT Model with a language modeling head. Can be used for summarization.", CPT_START_DOCSTRING
960 | )
961 | class CPTForConditionalGeneration(CPTPretrainedModel):
962 | base_model_prefix = "model"
963 | _keys_to_ignore_on_load_missing = [
964 | r"final_logits_bias",
965 | r"encoder\.version",
966 | r"decoder\.version",
967 | r"lm_head\.weight",
968 | ]
969 |
970 | def __init__(self, config):
971 | super().__init__(config)
972 | self.model = CPTModel(config)
973 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
974 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
975 |
976 | self.init_weights()
977 |
978 | def get_encoder(self):
979 | return self.model.get_encoder()
980 |
981 | def get_decoder(self):
982 | return self.model.get_decoder()
983 |
984 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
985 | new_embeddings = super().resize_token_embeddings(new_num_tokens)
986 | self._resize_final_logits_bias(new_num_tokens)
987 | return new_embeddings
988 |
989 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
990 | old_num_tokens = self.final_logits_bias.shape[-1]
991 | if new_num_tokens <= old_num_tokens:
992 | new_bias = self.final_logits_bias[:, :new_num_tokens]
993 | else:
994 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
995 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
996 | self.register_buffer("final_logits_bias", new_bias)
997 |
998 | def get_output_embeddings(self):
999 | return self.lm_head
1000 |
1001 | def set_output_embeddings(self, new_embeddings):
1002 | self.lm_head = new_embeddings
1003 |
1004 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
1005 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1006 | def forward(
1007 | self,
1008 | input_ids=None,
1009 | attention_mask=None,
1010 | token_type_ids=None,
1011 | decoder_input_ids=None,
1012 | decoder_attention_mask=None,
1013 | head_mask=None,
1014 | decoder_head_mask=None,
1015 | encoder_outputs=None,
1016 | past_key_values=None,
1017 | inputs_embeds=None,
1018 | decoder_inputs_embeds=None,
1019 | labels=None,
1020 | use_cache=None,
1021 | output_attentions=None,
1022 | output_hidden_states=None,
1023 | return_dict=None,
1024 | ):
1025 | r"""
1026 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1027 | Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
1028 | config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
1029 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
1030 |
1031 | Returns:
1032 | """
1033 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1034 |
1035 | if labels is not None:
1036 | if decoder_input_ids is None:
1037 | decoder_input_ids = shift_tokens_right(
1038 | labels, self.config.pad_token_id, self.config.decoder_start_token_id
1039 | )
1040 |
1041 | outputs = self.model(
1042 | input_ids,
1043 | attention_mask=attention_mask,
1044 | decoder_input_ids=decoder_input_ids,
1045 | encoder_outputs=encoder_outputs,
1046 | token_type_ids=token_type_ids,
1047 | decoder_attention_mask=decoder_attention_mask,
1048 | head_mask=head_mask,
1049 | decoder_head_mask=decoder_head_mask,
1050 | past_key_values=past_key_values,
1051 | inputs_embeds=inputs_embeds,
1052 | decoder_inputs_embeds=decoder_inputs_embeds,
1053 | use_cache=use_cache,
1054 | output_attentions=output_attentions,
1055 | output_hidden_states=output_hidden_states,
1056 | return_dict=return_dict,
1057 | )
1058 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1059 |
1060 | masked_lm_loss = None
1061 | if labels is not None:
1062 | loss_fct = CrossEntropyLoss()
1063 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1064 |
1065 | if not return_dict:
1066 | output = (lm_logits,) + outputs[1:]
1067 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1068 |
1069 | return Seq2SeqLMOutput(
1070 | loss=masked_lm_loss,
1071 | logits=lm_logits,
1072 | past_key_values=outputs.past_key_values,
1073 | decoder_hidden_states=outputs.decoder_hidden_states,
1074 | decoder_attentions=outputs.decoder_attentions,
1075 | cross_attentions=outputs.cross_attentions,
1076 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1077 | encoder_hidden_states=outputs.encoder_hidden_states,
1078 | encoder_attentions=outputs.encoder_attentions,
1079 | )
1080 |
1081 | def prepare_inputs_for_generation(
1082 | self,
1083 | decoder_input_ids,
1084 | past=None,
1085 | attention_mask=None,
1086 | head_mask=None,
1087 | use_cache=None,
1088 | encoder_outputs=None,
1089 | **kwargs
1090 | ):
1091 | # cut decoder_input_ids if past is used
1092 | if past is not None:
1093 | decoder_input_ids = decoder_input_ids[:, -1:]
1094 |
1095 | return {
1096 | "input_ids": None, # encoder_outputs is defined. input_ids not needed
1097 | "encoder_outputs": encoder_outputs,
1098 | "past_key_values": past,
1099 | "decoder_input_ids": decoder_input_ids,
1100 | "attention_mask": attention_mask,
1101 | "head_mask": head_mask,
1102 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1103 | }
1104 |
1105 | @staticmethod
1106 | def _expand_inputs_for_generation(
1107 | input_ids: torch.LongTensor,
1108 | expand_size: int = 1,
1109 | is_encoder_decoder: bool = False,
1110 | attention_mask: torch.LongTensor = None,
1111 | encoder_outputs = None,
1112 | **model_kwargs,
1113 | ):
1114 | expanded_return_idx = (
1115 | torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
1116 | )
1117 | input_ids = input_ids.index_select(0, expanded_return_idx)
1118 |
1119 | if "token_type_ids" in model_kwargs:
1120 | token_type_ids = model_kwargs["token_type_ids"]
1121 | model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
1122 |
1123 | if attention_mask is not None:
1124 | model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
1125 |
1126 | if is_encoder_decoder:
1127 | assert encoder_outputs is not None
1128 | device = encoder_outputs.last_hidden_state.device
1129 | encoder_outputs["hidden_states"] = tuple(h.index_select(0, expanded_return_idx.to(device)) \
1130 | for h in encoder_outputs["hidden_states"])
1131 | model_kwargs["encoder_outputs"] = encoder_outputs
1132 | return input_ids, model_kwargs
1133 |
1134 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1135 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1136 |
1137 | @staticmethod
1138 | def _reorder_cache(past, beam_idx):
1139 | reordered_past = ()
1140 | for layer_past in past:
1141 | # cached cross_attention states don't have to be reordered -> they are always the same
1142 | reordered_past += (
1143 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1144 | )
1145 | return reordered_past
1146 |
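# A minimal generation sketch. It assumes the `fnlp/cpt-large` checkpoint and a BERT-style
# tokenizer (both assumptions for illustration; the helper is not part of the original file).
# `generate` goes through the overridden get_encoder()/_expand_inputs_for_generation above,
# so every encoder hidden state is kept and expanded for beam search.
def _conditional_generation_sketch():
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained("fnlp/cpt-large")
    model = CPTForConditionalGeneration.from_pretrained("fnlp/cpt-large")
    src_text = "这是一段需要摘要的长文本。"  # placeholder Chinese source text
    input_ids = tokenizer.encode(src_text, return_tensors="pt")
    summary_ids = model.generate(input_ids, num_beams=4, max_length=64)
    return tokenizer.convert_ids_to_tokens(summary_ids[0])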
1147 |
1148 | @add_start_docstrings(
1149 | """
1150 | CPT model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE
1151 | tasks.
1152 | """,
1153 | CPT_START_DOCSTRING,
1154 | )
1155 | class CPTForSequenceClassification(CPTPretrainedModel):
1156 | def __init__(self, config: CPTConfig, cls_mode=1, **kwargs):
1157 | super().__init__(config, **kwargs)
1158 | self.model = CPTModel(config)
1159 | cls_mode = getattr(config, 'cls_mode', cls_mode)
1160 | if cls_mode == 1:
1161 | logger.info('Encoder for classification.')
1162 | cls_dim = config.d_model
1163 | elif cls_mode == 2:
1164 | logger.info('Decoder for classification.')
1165 | cls_dim = config.d_model
1166 | elif cls_mode == 3:
1167 | logger.info('Both encoder & decoder for classification.')
1168 | cls_dim = config.d_model * 2
1169 | else:
1170 | raise NotImplementedError
1171 |
1172 | self.cls_head = CPTClassificationHead(
1173 | cls_dim,
1174 | cls_dim,
1175 | config.num_labels,
1176 | config.classifier_dropout,
1177 | )
1178 | self.model._init_weights(self.cls_head.dense)
1179 | self.model._init_weights(self.cls_head.out_proj)
1180 | self.cls_mode = cls_mode
1181 | config.cls_mode = cls_mode
1182 |
1183 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
1184 | @add_code_sample_docstrings(
1185 | checkpoint=_CHECKPOINT_FOR_DOC,
1186 | output_type=Seq2SeqSequenceClassifierOutput,
1187 | config_class=_CONFIG_FOR_DOC,
1188 | )
1189 | def forward(
1190 | self,
1191 | input_ids=None,
1192 | attention_mask=None,
1193 | decoder_input_ids=None,
1194 | decoder_attention_mask=None,
1195 | head_mask=None,
1196 | decoder_head_mask=None,
1197 | encoder_outputs=None,
1198 | inputs_embeds=None,
1199 | decoder_inputs_embeds=None,
1200 | labels=None,
1201 | use_cache=None,
1202 | output_attentions=None,
1203 | output_hidden_states=None,
1204 | return_dict=None,
1205 | ):
1206 | r"""
1207 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1208 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1209 | config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1210 | """
1211 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1212 | if labels is not None:
1213 | use_cache = False
1214 |
1215 | if input_ids is None and inputs_embeds is not None:
1216 | raise NotImplementedError(
1217 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1218 | )
1219 |
1220 | outputs = self.model(
1221 | input_ids,
1222 | attention_mask=attention_mask,
1223 | decoder_input_ids=decoder_input_ids,
1224 | decoder_attention_mask=decoder_attention_mask,
1225 | head_mask=head_mask,
1226 | decoder_head_mask=decoder_head_mask,
1227 | encoder_outputs=encoder_outputs,
1228 | inputs_embeds=inputs_embeds,
1229 | decoder_inputs_embeds=decoder_inputs_embeds,
1230 | use_cache=use_cache,
1231 | output_attentions=output_attentions,
1232 | output_hidden_states=output_hidden_states,
1233 | return_dict=True,
1234 | )
1235 |
1236 | hidden_states = outputs.last_hidden_state
1237 | enc_hidden_states = outputs.encoder_last_hidden_state
1238 | enc_rep = enc_hidden_states[:, 0]
1239 |
1240 | eos_mask = input_ids.eq(self.config.eos_token_id)
1241 |
1242 | if len(torch.unique(eos_mask.sum(1))) > 1:
1243 | raise ValueError("All examples must have the same number of <eos> tokens.")
1244 | dec_rep = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1245 | :, -1, :
1246 | ]
1247 |
1248 | if self.cls_mode == 1:
1249 | logits = self.cls_head(enc_rep)
1250 | elif self.cls_mode == 2:
1251 | logits = self.cls_head(dec_rep)
1252 | elif self.cls_mode == 3:
1253 | rep = torch.cat([enc_rep, dec_rep], dim=-1)
1254 | logits = self.cls_head(rep)
1255 | else:
1256 | raise NotImplementedError
1257 |
1258 | loss = None
1259 | if labels is not None:
1260 | loss_fct = CrossEntropyLoss()
1261 | loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1262 |
1263 | if not return_dict:
1264 | output = (logits,) + outputs[1:]
1265 | return ((loss,) + output) if loss is not None else output
1266 |
1267 | return Seq2SeqSequenceClassifierOutput(
1268 | loss=loss,
1269 | logits=logits,
1270 | past_key_values=outputs.past_key_values,
1271 | decoder_hidden_states=outputs.decoder_hidden_states,
1272 | decoder_attentions=outputs.decoder_attentions,
1273 | cross_attentions=outputs.cross_attentions,
1274 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1275 | encoder_hidden_states=outputs.encoder_hidden_states,
1276 | encoder_attentions=outputs.encoder_attentions,
1277 | )
1278 |
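# A small sketch of the three `cls_mode` options (toy sizes, illustrative only; the helper
# is not part of the released model): 1 classifies from the encoder state at position 0,
# 2 from the decoder state at the final <eos> token, and 3 from the concatenation of both,
# which is why cls_head's input dimension becomes 2 * d_model.
def _cls_mode_sketch():
    toy_config = CPTConfig(
        vocab_size=100, d_model=64,
        encoder_layers=2, decoder_layers=2,
        encoder_attention_heads=4, decoder_attention_heads=4,
        encoder_ffn_dim=128, decoder_ffn_dim=128,
        num_labels=3,
    )
    clf = CPTForSequenceClassification(toy_config, cls_mode=3)
    ids = torch.tensor([[10, 11, 12, 13, toy_config.eos_token_id]])  # must end with <eos>
    out = clf(input_ids=ids, attention_mask=torch.ones_like(ids))
    return out.logits.shape  # torch.Size([1, 3])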
1279 |
1280 | @add_start_docstrings(
1281 | """
1282 | CPT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1283 | layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1284 | """,
1285 | CPT_START_DOCSTRING,
1286 | )
1287 | class CPTForQuestionAnswering(CPTPretrainedModel):
1288 | def __init__(self, config: CPTConfig, cls_mode=1, **kwargs):
1289 | super().__init__(config, **kwargs)
1290 | config.num_labels = 2
1291 | self.num_labels = config.num_labels
1292 |
1293 | self.model = CPTModel(config)
1294 |
1295 | cls_mode = getattr(config, 'cls_mode', cls_mode)
1296 | if cls_mode == 1:
1297 | logger.info('Encoder for classification.')
1298 | cls_dim = config.d_model
1299 | elif cls_mode == 2:
1300 | logger.info('Decoder for classification.')
1301 | cls_dim = config.d_model
1302 | elif cls_mode == 3:
1303 | logger.info('Both encoder & decoder for classification.')
1304 | cls_dim = config.d_model * 2
1305 | else:
1306 | raise NotImplementedError
1307 |
1308 | self.qa_outputs = nn.Linear(cls_dim, config.num_labels)
1309 | self.model._init_weights(self.qa_outputs)
1310 |
1311 | self.cls_mode = cls_mode
1312 | config.cls_mode = cls_mode
1313 |
1314 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
1315 | @add_code_sample_docstrings(
1316 | checkpoint=_CHECKPOINT_FOR_DOC,
1317 | output_type=Seq2SeqSequenceClassifierOutput,
1318 | config_class=_CONFIG_FOR_DOC,
1319 | )
1320 | def forward(
1321 | self,
1322 | input_ids=None,
1323 | attention_mask=None,
1324 | decoder_input_ids=None,
1325 | decoder_attention_mask=None,
1326 | head_mask=None,
1327 | decoder_head_mask=None,
1328 | encoder_outputs=None,
1329 | start_positions=None,
1330 | end_positions=None,
1331 | inputs_embeds=None,
1332 | decoder_inputs_embeds=None,
1333 | use_cache=None,
1334 | output_attentions=None,
1335 | output_hidden_states=None,
1336 | return_dict=None,
1337 | ):
1338 | r"""
1339 | start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1340 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
1341 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1342 | are not taken into account for computing the loss.
1343 | end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1344 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
1345 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1346 | are not taken into account for computing the loss.
1347 | """
1348 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1349 |
1350 | if input_ids is None and inputs_embeds is not None:
1351 | raise NotImplementedError(
1352 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1353 | )
1354 |
1355 | outputs = self.model(
1356 | input_ids,
1357 | attention_mask=attention_mask,
1358 | decoder_input_ids=decoder_input_ids,
1359 | decoder_attention_mask=decoder_attention_mask,
1360 | head_mask=head_mask,
1361 | decoder_head_mask=decoder_head_mask,
1362 | encoder_outputs=encoder_outputs,
1363 | inputs_embeds=inputs_embeds,
1364 | decoder_inputs_embeds=decoder_inputs_embeds,
1365 | use_cache=use_cache,
1366 | output_attentions=output_attentions,
1367 | output_hidden_states=output_hidden_states,
1368 | return_dict=True,
1369 | )
1370 |
1371 | hidden_states = outputs.last_hidden_state
1372 | enc_hidden_states = outputs.encoder_last_hidden_state
1373 |
1374 | if self.cls_mode == 1:
1375 | logits = self.qa_outputs(enc_hidden_states)
1376 | elif self.cls_mode == 2:
1377 | logits = self.qa_outputs(hidden_states)
1378 | elif self.cls_mode == 3:
1379 | rep = torch.cat([enc_hidden_states, hidden_states], dim=-1)
1380 | logits = self.qa_outputs(rep)
1381 | else:
1382 | raise NotImplementedError
1383 |
1384 | start_logits, end_logits = logits.split(1, dim=-1)
1385 | start_logits = start_logits.squeeze(-1)
1386 | end_logits = end_logits.squeeze(-1)
1387 |
1388 | total_loss = None
1389 | if start_positions is not None and end_positions is not None:
1390 | # If we are on multi-GPU, the start/end positions may have an extra dimension; squeeze it
1391 | if len(start_positions.size()) > 1:
1392 | start_positions = start_positions.squeeze(-1)
1393 | if len(end_positions.size()) > 1:
1394 | end_positions = end_positions.squeeze(-1)
1395 | # sometimes the start/end positions are outside our model inputs, we ignore these terms
1396 | ignored_index = start_logits.size(1)
1397 | start_positions.clamp_(0, ignored_index)
1398 | end_positions.clamp_(0, ignored_index)
1399 |
1400 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1401 | start_loss = loss_fct(start_logits, start_positions)
1402 | end_loss = loss_fct(end_logits, end_positions)
1403 | total_loss = (start_loss + end_loss) / 2
1404 |
1405 | if not return_dict:
1406 | output = (
1407 | start_logits,
1408 | end_logits,
1409 | ) + outputs[1:]
1410 | return ((total_loss,) + output) if total_loss is not None else output
1411 |
1412 | return Seq2SeqQuestionAnsweringModelOutput(
1413 | loss=total_loss,
1414 | start_logits=start_logits,
1415 | end_logits=end_logits,
1416 | past_key_values=outputs.past_key_values,
1417 | decoder_hidden_states=outputs.decoder_hidden_states,
1418 | decoder_attentions=outputs.decoder_attentions,
1419 | cross_attentions=outputs.cross_attentions,
1420 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1421 | encoder_hidden_states=outputs.encoder_hidden_states,
1422 | encoder_attentions=outputs.encoder_attentions,
1423 | )
1424 |
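# A small sketch of span prediction (toy sizes, illustrative only; the helper is not part
# of the released model): `qa_outputs` maps every position to two scores, which are split
# into start and end logits and decoded here by a simple argmax.
def _qa_span_sketch():
    toy_config = CPTConfig(
        vocab_size=100, d_model=64,
        encoder_layers=2, decoder_layers=2,
        encoder_attention_heads=4, decoder_attention_heads=4,
        encoder_ffn_dim=128, decoder_ffn_dim=128,
    )
    qa = CPTForQuestionAnswering(toy_config)
    ids = torch.tensor([[10, 11, 12, 13, 2]])
    out = qa(input_ids=ids, attention_mask=torch.ones_like(ids))
    start = out.start_logits.argmax(-1)  # (batch,) highest-scoring start positions
    end = out.end_logits.argmax(-1)      # (batch,) highest-scoring end positions
    return start, end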
1425 |
1426 | class CPTForMaskedLM(CPTPretrainedModel):
1427 | _keys_to_ignore_on_load_missing = [
1428 | r"final_logits_bias",
1429 | r"encoder\.version",
1430 | r"decoder\.version",
1431 | r"lm_head\.weight",
1432 | ]
1433 | def __init__(self, config, **kwargs):
1434 | super().__init__(config)
1435 | self.model = CPTModel(config)
1436 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1437 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
1438 |
1439 | self.init_weights()
1440 |
1441 | def get_encoder(self):
1442 | return self.model.get_encoder()
1443 |
1444 | def get_decoder(self):
1445 | return self.model.get_decoder()
1446 |
1447 | def get_output_embeddings(self):
1448 | return self.lm_head
1449 |
1450 | def forward(
1451 | self,
1452 | input_ids=None,
1453 | attention_mask=None,
1454 | decoder_input_ids=None,
1455 | decoder_attention_mask=None,
1456 | head_mask=None,
1457 | decoder_head_mask=None,
1458 | encoder_outputs=None,
1459 | inputs_embeds=None,
1460 | decoder_inputs_embeds=None,
1461 | use_cache=None,
1462 | output_attentions=None,
1463 | output_hidden_states=None,
1464 | return_dict=None,
1465 | ):
1466 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1467 |
1468 | if input_ids is None and inputs_embeds is not None:
1469 | raise NotImplementedError(
1470 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1471 | )
1472 |
1473 | outputs = self.model(
1474 | input_ids,
1475 | attention_mask=attention_mask,
1476 | decoder_input_ids=decoder_input_ids,
1477 | decoder_attention_mask=decoder_attention_mask,
1478 | head_mask=head_mask,
1479 | decoder_head_mask=decoder_head_mask,
1480 | encoder_outputs=encoder_outputs,
1481 | inputs_embeds=inputs_embeds,
1482 | decoder_inputs_embeds=decoder_inputs_embeds,
1483 | use_cache=use_cache,
1484 | output_attentions=output_attentions,
1485 | output_hidden_states=output_hidden_states,
1486 | return_dict=True,
1487 | )
1488 |
1489 | hidden_states = outputs.last_hidden_state
1490 | enc_hidden_states = outputs.encoder_last_hidden_state
1491 |
1492 | dec_logits = self.lm_head(hidden_states) + self.final_logits_bias
1493 | enc_logits = self.lm_head(enc_hidden_states) + self.final_logits_bias
1494 |
1495 | if not return_dict:
1496 | logits = (enc_logits, dec_logits)
1497 | output = (logits,) + outputs[1:]
1498 | return output
1499 |
1500 | return Seq2SeqLMOutput(
1501 | loss=None,
1502 | logits=(enc_logits, dec_logits),
1503 | past_key_values=outputs.past_key_values,
1504 | decoder_hidden_states=outputs.decoder_hidden_states,
1505 | decoder_attentions=outputs.decoder_attentions,
1506 | cross_attentions=outputs.cross_attentions,
1507 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1508 | encoder_hidden_states=outputs.encoder_hidden_states,
1509 | encoder_attentions=outputs.encoder_attentions,
1510 | )
--------------------------------------------------------------------------------
/Summary/modeling_cpt.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch CPT model. modified from transformers==4.4.1"""
16 | import copy
17 | import math
18 | import random
19 | import warnings
20 | from typing import Optional, Tuple
21 |
22 | import torch
23 | import torch.nn.functional as F
24 | from torch.nn.modules.loss import KLDivLoss, NLLLoss
25 | import torch.utils.checkpoint
26 | from torch import nn
27 | from torch.nn import CrossEntropyLoss
28 |
29 | from transformers.activations import ACT2FN
30 | from transformers.file_utils import (
31 | add_code_sample_docstrings,
32 | add_end_docstrings,
33 | add_start_docstrings,
34 | add_start_docstrings_to_model_forward,
35 | replace_return_docstrings,
36 | )
37 | from transformers.modeling_outputs import (
38 | BaseModelOutput,
39 | BaseModelOutputWithPastAndCrossAttentions,
40 | CausalLMOutputWithCrossAttentions,
41 | Seq2SeqLMOutput,
42 | Seq2SeqModelOutput,
43 | Seq2SeqQuestionAnsweringModelOutput,
44 | Seq2SeqSequenceClassifierOutput,
45 | )
46 | from transformers.modeling_utils import PreTrainedModel
47 | from transformers.utils import logging
48 | from transformers import BartConfig as CPTConfig
49 | from transformers import BertModel, BertConfig
50 |
51 | from torch.nn import LayerNorm
52 |
53 | logger = logging.get_logger(__name__)
54 |
55 | _CHECKPOINT_FOR_DOC = "fnlp/cpt-large"
56 | _CONFIG_FOR_DOC = "CPTConfig"
57 | _TOKENIZER_FOR_DOC = "CPTTokenizer"
58 |
59 |
60 | CPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
61 | "fnlp/cpt-large",
62 | ]
63 |
64 |
65 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
66 | """
67 | Shift input ids one token to the right.
68 | """
69 | shifted_input_ids = input_ids.new_zeros(input_ids.shape)
70 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
71 | shifted_input_ids[:, 0] = decoder_start_token_id
72 |
73 | assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
74 | # replace possible -100 values in labels by `pad_token_id`
75 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
76 |
77 | return shifted_input_ids
78 |
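# A small worked example (values illustrative): the decoder start token is prepended,
# the labels move one position to the right, and any -100 markers become pad_token_id.
def _shift_tokens_right_example():
    labels = torch.tensor([[5, 6, -100, -100]])
    shifted = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=101)
    return shifted  # tensor([[101, 5, 6, 0]])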
79 |
80 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
81 | """
82 | Make causal mask used for bi-directional self-attention.
83 | """
84 | bsz, tgt_len = input_ids_shape
85 | mask = torch.full((tgt_len, tgt_len), float("-inf"))
86 | mask_cond = torch.arange(mask.size(-1))
87 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
88 | mask = mask.to(dtype)
89 |
90 | if past_key_values_length > 0:
91 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
92 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
93 |
94 |
95 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
96 | """
97 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
98 | """
99 | bsz, src_len = mask.size()
100 | tgt_len = tgt_len if tgt_len is not None else src_len
101 |
102 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
103 |
104 | inverted_mask = 1.0 - expanded_mask
105 |
106 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
107 |
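# Small worked examples of the two mask helpers (toy shapes, illustrative only):
# _make_causal_mask builds an additive lower-triangular mask (0 on and below the diagonal,
# -inf above it), and _expand_mask turns a [bsz, src_len] padding mask into an additive
# bias of shape [bsz, 1, tgt_len, src_len] with the minimum float at padded positions.
def _mask_helpers_example():
    causal = _make_causal_mask(torch.Size([1, 3]), torch.float32)  # shape (1, 1, 3, 3)
    # causal[0, 0] == [[0., -inf, -inf],
    #                  [0.,   0., -inf],
    #                  [0.,   0.,   0.]]
    padding = _expand_mask(torch.tensor([[1, 1, 0]]), torch.float32, tgt_len=2)  # (1, 1, 2, 3)
    return causal, padding
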
108 | def attention_mask_func(attention_scores, attention_mask):
109 | return attention_scores + attention_mask
110 |
111 | def init_method(std):
112 | def init_(tensor):
113 | return torch.nn.init.normal_(tensor, mean=0.0, std=std)
114 |
115 | return init_
116 |
117 | class CPTLearnedPositionalEmbedding(nn.Embedding):
118 | """
119 | This module learns positional embeddings up to a fixed maximum size.
120 | """
121 |
122 | def __init__(self, num_embeddings: int, embedding_dim: int):
123 | # CPT is set up so that if padding_idx is specified, the embedding ids are offset by 2
124 | # and num_embeddings is adjusted accordingly. Other models don't have this hack.
125 | self.offset = 2
126 | super().__init__(num_embeddings + self.offset, embedding_dim)
127 |
128 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
129 | """`input_ids_shape` is expected to be [bsz x seqlen]."""
130 | bsz, seq_len = input_ids_shape[:2]
131 | positions = torch.arange(
132 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
133 | )
134 | return super().forward(positions + self.offset)
135 |
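# A small sketch of the position-offset hack (sizes illustrative only): the embedding table
# gets two extra rows and every position index is shifted by `offset = 2`, so position 0
# reads row 2 of the weight matrix; `past_key_values_length` lets incremental decoding
# continue from the next position.
def _positional_embedding_example():
    emb = CPTLearnedPositionalEmbedding(512, 64)                   # weight shape: (514, 64)
    prompt_pos = emb(torch.Size([1, 5]))                           # rows 2..6 -> shape (5, 64)
    step_pos = emb(torch.Size([1, 1]), past_key_values_length=5)   # row 7 -> shape (1, 64)
    return prompt_pos, step_pos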
136 |
137 | class CPTAttention(nn.Module):
138 | """Multi-headed attention from 'Attention Is All You Need' paper"""
139 |
140 | def __init__(
141 | self,
142 | embed_dim: int,
143 | num_heads: int,
144 | dropout: float = 0.0,
145 | is_decoder: bool = False,
146 | bias: bool = True,
147 | ):
148 | super().__init__()
149 | self.embed_dim = embed_dim
150 | self.num_heads = num_heads
151 | self.dropout = dropout
152 | self.head_dim = embed_dim // num_heads
153 | assert (
154 | self.head_dim * num_heads == self.embed_dim
155 | ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
156 | self.scaling = self.head_dim ** -0.5
157 | self.is_decoder = is_decoder
158 |
159 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
160 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
161 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
162 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
163 |
164 |
165 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
166 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
167 |
168 | def forward(
169 | self,
170 | hidden_states: torch.Tensor,
171 | key_value_states: Optional[torch.Tensor] = None,
172 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
173 | attention_mask: Optional[torch.Tensor] = None,
174 | layer_head_mask: Optional[torch.Tensor] = None,
175 | output_attentions: bool = False,
176 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
177 | """Input shape: Batch x Time x Channel"""
178 |
179 | # if key_value_states are provided this layer is used as a cross-attention layer
180 | # for the decoder
181 | is_cross_attention = key_value_states is not None
182 | bsz, tgt_len, embed_dim = hidden_states.size()
183 |
184 | # get query proj
185 | query_states = self.q_proj(hidden_states) * self.scaling
186 | # get key, value proj
187 | if is_cross_attention and past_key_value is not None:
188 | # reuse k,v, cross_attentions
189 | key_states = past_key_value[0]
190 | value_states = past_key_value[1]
191 | elif is_cross_attention:
192 | # cross_attentions
193 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
194 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
195 | elif past_key_value is not None:
196 | # reuse k, v, self_attention
197 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
198 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
199 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
200 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
201 | else:
202 | # self_attention
203 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
204 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
205 |
206 | if self.is_decoder:
207 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
208 | # Further calls to cross_attention layer can then reuse all cross-attention
209 | # key/value_states (first "if" case)
210 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
211 | # all previous decoder key/value_states. Further calls to uni-directional self-attention
212 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
213 | # if encoder bi-directional self-attention `past_key_value` is always `None`
214 | past_key_value = (key_states, value_states)
215 |
216 | proj_shape = (bsz * self.num_heads, -1, self.head_dim)
217 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
218 | key_states = key_states.view(*proj_shape)
219 | value_states = value_states.view(*proj_shape)
220 |
221 | src_len = key_states.size(1)
222 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
223 |
224 | assert attn_weights.size() == (
225 | bsz * self.num_heads,
226 | tgt_len,
227 | src_len,
228 | ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
229 |
230 | if attention_mask is not None:
231 | assert attention_mask.size() == (
232 | bsz,
233 | 1,
234 | tgt_len,
235 | src_len,
236 | ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
237 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
238 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
239 |
240 | attn_weights = F.softmax(attn_weights, dim=-1)
241 |
242 | if layer_head_mask is not None:
243 | assert layer_head_mask.size() == (
244 | self.num_heads,
245 | ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
246 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
247 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
248 |
249 | if output_attentions:
250 | # this operation is a bit awkward, but it's required to
251 | # make sure that attn_weights keeps its gradient.
252 | # In order to do so, attn_weights have to be reshaped
253 | # twice and have to be reused in the following
254 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
255 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
256 | else:
257 | attn_weights_reshaped = None
258 |
259 | # with mpu.get_cuda_rng_tracker().fork():
260 | attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
261 |
262 | attn_output = torch.bmm(attn_probs, value_states)
263 |
264 | assert attn_output.size() == (
265 | bsz * self.num_heads,
266 | tgt_len,
267 | self.head_dim,
268 | ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
269 |
270 | attn_output = (
271 | attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
272 | .transpose(1, 2)
273 | .reshape(bsz, tgt_len, embed_dim)
274 | )
275 |
276 | attn_output = self.out_proj(attn_output)
277 |
278 | return attn_output, attn_weights_reshaped, past_key_value
279 |
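# A minimal sketch (toy sizes, illustrative only) of how the decoder self-attention cache is
# reused: the first call returns `past_key_value`, and feeding it back lets the next step
# attend over all previously projected keys/values while only projecting one new position.
def _attention_cache_sketch():
    attn = CPTAttention(embed_dim=64, num_heads=4, is_decoder=True)
    prompt = torch.randn(1, 3, 64)                        # (batch, seq_len, embed_dim)
    _, _, past = attn(prompt)                             # cache holds keys/values for 3 positions
    next_step = torch.randn(1, 1, 64)
    out, _, past = attn(next_step, past_key_value=past)   # now attends over 3 + 1 positions
    return out.shape, past[0].shape                       # (1, 1, 64), (1, 4, 4, 16)
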
280 | class CPTDecoderLayer(nn.Module):
281 | def __init__(self, config: CPTConfig):
282 | super().__init__()
283 | self.embed_dim = config.d_model
284 |
285 | self.self_attn = CPTAttention(
286 | embed_dim=self.embed_dim,
287 | num_heads=config.decoder_attention_heads,
288 | dropout=config.attention_dropout,
289 | is_decoder=True,
290 | )
291 | self.dropout = config.dropout
292 | self.activation_fn = ACT2FN[config.activation_function]
293 | self.activation_dropout = config.activation_dropout
294 |
295 | self.self_attn_layer_norm = LayerNorm(self.embed_dim)
296 | self.encoder_attn = CPTAttention(
297 | self.embed_dim,
298 | config.decoder_attention_heads,
299 | dropout=config.attention_dropout,
300 | is_decoder=True,
301 | )
302 | self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
303 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
304 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
305 | self.final_layer_norm = LayerNorm(self.embed_dim)
306 |
307 | def forward(
308 | self,
309 | hidden_states: torch.Tensor,
310 | attention_mask: Optional[torch.Tensor] = None,
311 | encoder_hidden_states: Optional[torch.Tensor] = None,
312 | encoder_attention_mask: Optional[torch.Tensor] = None,
313 | layer_head_mask: Optional[torch.Tensor] = None,
314 | encoder_layer_head_mask: Optional[torch.Tensor] = None,
315 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
316 | output_attentions: Optional[bool] = False,
317 | use_cache: Optional[bool] = True,
318 | ):
319 | """
320 | Args:
321 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
322 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size
323 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
324 | encoder_hidden_states (:obj:`torch.FloatTensor`): cross-attention input to the layer of shape `(batch, seq_len, embed_dim)`
325 | encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
326 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
327 | layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
328 | `(config.encoder_attention_heads,)`.
329 | encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
330 | size `(config.encoder_attention_heads,)`.
331 | past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
332 | output_attentions (:obj:`bool`, `optional`):
333 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
334 | returned tensors for more detail.
335 | """
336 | residual = hidden_states
337 |
338 | # Self Attention
339 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
340 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
341 | # add present self-attn cache to positions 1,2 of present_key_value tuple
342 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
343 | hidden_states=hidden_states,
344 | past_key_value=self_attn_past_key_value,
345 | attention_mask=attention_mask,
346 | layer_head_mask=layer_head_mask,
347 | output_attentions=output_attentions,
348 | )
349 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
350 | hidden_states = residual + hidden_states
351 | hidden_states = self.self_attn_layer_norm(hidden_states)
352 |
353 | # Cross-Attention Block
354 | cross_attn_present_key_value = None
355 | cross_attn_weights = None
356 | if encoder_hidden_states is not None:
357 | residual = hidden_states
358 |
359 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
360 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
361 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
362 | hidden_states=hidden_states,
363 | key_value_states=encoder_hidden_states,
364 | attention_mask=encoder_attention_mask,
365 | layer_head_mask=encoder_layer_head_mask,
366 | past_key_value=cross_attn_past_key_value,
367 | output_attentions=output_attentions,
368 | )
369 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
370 | hidden_states = residual + hidden_states
371 | hidden_states = self.encoder_attn_layer_norm(hidden_states)
372 |
373 | # add cross-attn to positions 3,4 of present_key_value tuple
374 | present_key_value = present_key_value + cross_attn_present_key_value
375 |
376 | # Fully Connected
377 | residual = hidden_states
378 | hidden_states = self.activation_fn(self.fc1(hidden_states))
379 | hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
380 | hidden_states = self.fc2(hidden_states)
381 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
382 | hidden_states = residual + hidden_states
383 | hidden_states = self.final_layer_norm(hidden_states)
384 |
385 | outputs = (hidden_states,)
386 |
387 | if output_attentions:
388 | outputs += (self_attn_weights, cross_attn_weights)
389 |
390 | if use_cache:
391 | outputs += (present_key_value,)
392 |
393 | return outputs
394 |
395 |
396 | class CPTClassificationHead(nn.Module):
397 | """Head for sentence-level classification tasks."""
398 |
399 | def __init__(
400 | self,
401 | input_dim: int,
402 | inner_dim: int,
403 | num_classes: int,
404 | pooler_dropout: float,
405 | ):
406 | super().__init__()
407 | self.dense = nn.Linear(input_dim, inner_dim)
408 | self.dropout = nn.Dropout(p=pooler_dropout)
409 | self.out_proj = nn.Linear(inner_dim, num_classes)
410 |
411 | def forward(self, hidden_states: torch.Tensor):
412 | hidden_states = self.dropout(hidden_states)
413 | hidden_states = self.dense(hidden_states)
414 | hidden_states = torch.tanh(hidden_states)
415 | hidden_states = self.dropout(hidden_states)
416 | hidden_states = self.out_proj(hidden_states)
417 | return hidden_states
418 |
419 |
420 | class CPTPretrainedModel(PreTrainedModel):
421 | config_class = CPTConfig
422 | base_model_prefix = "model"
423 |
424 | def _init_weights(self, module):
425 | std = self.config.init_std
426 | if isinstance(module, nn.Linear):
427 | module.weight.data.normal_(mean=0.0, std=std)
428 | if module.bias is not None:
429 | module.bias.data.zero_()
430 | elif isinstance(module, nn.Embedding):
431 | module.weight.data.normal_(mean=0.0, std=std)
432 | if module.padding_idx is not None:
433 | module.weight.data[module.padding_idx].zero_()
434 |
435 | @property
436 | def dummy_inputs(self):
437 | pad_token = self.config.pad_token_id
438 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
439 | dummy_inputs = {
440 | "attention_mask": input_ids.ne(pad_token),
441 | "input_ids": input_ids,
442 | }
443 | return dummy_inputs
444 |
445 | CPT_START_DOCSTRING = r"""
446 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
447 | methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
448 | pruning heads etc.)
449 |
450 | This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
451 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
452 | general usage and behavior.
453 |
454 | Parameters:
455 | config (:class:`~transformers.CPTConfig`):
456 | Model configuration class with all the parameters of the model. Initializing with a config file does not
457 | load the weights associated with the model, only the configuration. Check out the
458 | :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
459 | """
460 |
461 | CPT_INPUTS_DOCSTRING = r"""
462 | Args:
463 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
464 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
465 | it.
466 |
467 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See
468 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
469 | details.
470 |
471 | `What are input IDs? <../glossary.html#input-ids>`__
472 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
473 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
474 |
475 | - 1 for tokens that are **not masked**,
476 | - 0 for tokens that are **masked**.
477 |
478 | `What are attention masks? <../glossary.html#attention-mask>`__
479 | decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
480 | Indices of decoder input sequence tokens in the vocabulary.
481 |
482 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See
483 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
484 | details.
485 |
486 | `What are input IDs? <../glossary.html#input-ids>`__
487 |
488 | CPT uses the :obj:`eos_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
489 | :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
490 | :obj:`past_key_values`).
491 |
492 | For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no
493 | :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to
494 | the right for denoising pre-training following the paper.
495 | decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
496 | Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
497 | also be used by default.
498 |
499 | If you want to change padding behavior, you should read :func:`modeling_cpt._prepare_decoder_inputs` and
500 | modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
501 | information on the default strategy.
502 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
503 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
504 |
505 | - 1 indicates the head is **not masked**,
506 | - 0 indicates the head is **masked**.
507 |
508 | decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
509 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
510 |
511 | - 1 indicates the head is **not masked**,
512 | - 0 indicates the head is **masked**.
513 |
514 | encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
515 | Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
516 | :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
517 | `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
518 | cross-attention of the decoder.
519 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
520 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
521 |
522 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
523 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
524 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
525 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
526 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
527 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
528 | vectors than the model's internal embedding lookup matrix.
529 | decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
530 | Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
531 | representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
532 | have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
533 | :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
534 |
535 | If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
536 | takes the value of :obj:`inputs_embeds`.
537 | use_cache (:obj:`bool`, `optional`):
538 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
539 | decoding (see :obj:`past_key_values`).
540 | output_attentions (:obj:`bool`, `optional`):
541 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
542 | tensors for more detail.
543 | output_hidden_states (:obj:`bool`, `optional`):
544 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
545 | more detail.
546 | return_dict (:obj:`bool`, `optional`):
547 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
548 | """
549 |
550 | class CPTDecoder(CPTPretrainedModel):
551 | """
552 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`CPTDecoderLayer`
553 |
554 | Args:
555 | config: CPTConfig
556 | embed_tokens (torch.nn.Embedding): output embedding
557 | """
558 |
559 | def __init__(self, config: CPTConfig, embed_tokens: Optional[nn.Embedding] = None):
560 | super().__init__(config)
561 | self.dropout = config.dropout
562 | self.layerdrop = config.decoder_layerdrop
563 | self.padding_idx = config.pad_token_id
564 | self.max_target_positions = config.max_position_embeddings
565 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
566 |
567 | if embed_tokens is not None:
568 | self.embed_tokens = embed_tokens
569 | else:
570 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
571 |
572 | self.embed_positions = CPTLearnedPositionalEmbedding(
573 | config.max_position_embeddings,
574 | config.d_model,
575 | )
576 | self.layers = nn.ModuleList([CPTDecoderLayer(config) for _ in range(config.decoder_layers)])
577 | self.layernorm_embedding = LayerNorm(config.d_model)
578 |
579 | self.init_weights()
580 |
581 | def get_input_embeddings(self):
582 | return self.embed_tokens
583 |
584 | def set_input_embeddings(self, value):
585 | self.embed_tokens = value
586 |
587 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
588 | # create causal mask
589 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
590 | combined_attention_mask = None
591 | if input_shape[-1] > 1:
592 | combined_attention_mask = _make_causal_mask(
593 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
594 | ).to(self.device)
595 |
596 | if attention_mask is not None:
597 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
598 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
599 | combined_attention_mask = (
600 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
601 | )
602 |
603 | return combined_attention_mask
604 |
605 | def forward(
606 | self,
607 | input_ids=None,
608 | attention_mask=None,
609 | encoder_hidden_states=None,
610 | encoder_attention_mask=None,
611 | head_mask=None,
612 | encoder_head_mask=None,
613 | past_key_values=None,
614 | inputs_embeds=None,
615 | use_cache=None,
616 | output_attentions=None,
617 | output_hidden_states=None,
618 | return_dict=None,
619 | ):
620 | r"""
621 | Args:
622 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
623 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
624 | provide it.
625 |
626 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See
627 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
628 | for details.
629 |
630 | `What are input IDs? <../glossary.html#input-ids>`__
631 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
632 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
633 |
634 | - 1 for tokens that are **not masked**,
635 | - 0 for tokens that are **masked**.
636 |
637 | `What are attention masks? <../glossary.html#attention-mask>`__
638 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
639 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
640 | of the decoder.
641 | encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`):
642 | Mask to avoid performing cross-attention on padding token indices of encoder input_ids. Mask values
643 | selected in ``[0, 1]``:
644 |
645 | - 1 for tokens that are **not masked**,
646 | - 0 for tokens that are **masked**.
647 |
648 | `What are attention masks? <../glossary.html#attention-mask>`__
649 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
650 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
651 |
652 | - 1 indicates the head is **not masked**,
653 | - 0 indicates the head is **masked**.
654 |
655 | encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
656 | Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention
657 | on hidden heads. Mask values selected in ``[0, 1]``:
658 |
659 | - 1 indicates the head is **not masked**,
660 | - 0 indicates the head is **masked**.
661 |
662 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
663 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
664 | decoding.
665 |
666 | If :obj:`past_key_values` are used, the user can optionally input only the last
667 | :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of
668 |             shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size,
669 | sequence_length)`.
670 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
671 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
672 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
673 | into associated vectors than the model's internal embedding lookup matrix.
674 | output_attentions (:obj:`bool`, `optional`):
675 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
676 | returned tensors for more detail.
677 | output_hidden_states (:obj:`bool`, `optional`):
678 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
679 | for more detail.
680 | return_dict (:obj:`bool`, `optional`):
681 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
682 | """
683 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
684 | output_hidden_states = (
685 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
686 | )
687 | use_cache = use_cache if use_cache is not None else self.config.use_cache
688 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
689 |
690 | # retrieve input_ids and inputs_embeds
691 | if input_ids is not None and inputs_embeds is not None:
692 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
693 | elif input_ids is not None:
694 | input_shape = input_ids.size()
695 | input_ids = input_ids.view(-1, input_shape[-1])
696 | elif inputs_embeds is not None:
697 | input_shape = inputs_embeds.size()[:-1]
698 | else:
699 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
700 |
701 | # past_key_values_length
702 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
703 |
704 | if inputs_embeds is None:
705 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
706 |
707 | attention_mask = self._prepare_decoder_attention_mask(
708 | attention_mask, input_shape, inputs_embeds, past_key_values_length
709 | )
710 |
711 | # expand encoder attention mask
712 | if encoder_hidden_states is not None and encoder_attention_mask is not None:
713 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
714 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
715 |
716 | # embed positions
717 | positions = self.embed_positions(input_shape, past_key_values_length)
718 |
719 | hidden_states = inputs_embeds + positions
720 | hidden_states = self.layernorm_embedding(hidden_states)
721 |
722 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
723 |
724 | # decoder layers
725 | all_hidden_states = () if output_hidden_states else None
726 | all_self_attns = () if output_attentions else None
727 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
728 | next_decoder_cache = () if use_cache else None
729 |
730 | # check if head_mask has a correct number of layers specified if desired
731 | if head_mask is not None:
732 | assert head_mask.size()[0] == (
733 | len(self.layers)
734 | ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
735 | for idx, decoder_layer in enumerate(self.layers):
736 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
737 | if output_hidden_states:
738 | all_hidden_states += (hidden_states,)
739 | dropout_probability = random.uniform(0, 1)
740 | if self.training and (dropout_probability < self.layerdrop):
741 | continue
742 |
743 | past_key_value = past_key_values[idx] if past_key_values is not None else None
744 |
745 | if getattr(self.config, "gradient_checkpointing", False) and self.training:
746 |
747 | if use_cache:
748 |                     logger.warning(
749 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
750 | "`use_cache=False`..."
751 | )
752 | use_cache = False
753 |
754 | def create_custom_forward(module):
755 | def custom_forward(*inputs):
756 | # None for past_key_value
757 | return module(*inputs, output_attentions, use_cache)
758 |
759 | return custom_forward
760 |
761 | # layer_outputs = mpu.checkpoint(
762 |                 layer_outputs = torch.utils.checkpoint.checkpoint(
763 | create_custom_forward(decoder_layer),
764 | hidden_states,
765 | attention_mask,
766 | encoder_hidden_states,
767 | encoder_attention_mask,
768 | head_mask[idx] if head_mask is not None else None,
769 | encoder_head_mask[idx] if encoder_head_mask is not None else None,
770 | None,
771 | )
772 | else:
773 |
774 | layer_outputs = decoder_layer(
775 | hidden_states,
776 | attention_mask=attention_mask,
777 | encoder_hidden_states=encoder_hidden_states,
778 | encoder_attention_mask=encoder_attention_mask,
779 | layer_head_mask=(head_mask[idx] if head_mask is not None else None),
780 | encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
781 | past_key_value=past_key_value,
782 | output_attentions=output_attentions,
783 | use_cache=use_cache,
784 | )
785 | hidden_states = layer_outputs[0]
786 |
787 | if use_cache:
788 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
789 |
790 | if output_attentions:
791 | all_self_attns += (layer_outputs[1],)
792 |
793 | if encoder_hidden_states is not None:
794 | all_cross_attentions += (layer_outputs[2],)
795 |
796 | # add hidden states from the last decoder layer
797 | if output_hidden_states:
798 | all_hidden_states += (hidden_states,)
799 |
800 | next_cache = next_decoder_cache if use_cache else None
801 | if not return_dict:
802 | return tuple(
803 | v
804 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
805 | if v is not None
806 | )
807 | return BaseModelOutputWithPastAndCrossAttentions(
808 | last_hidden_state=hidden_states,
809 | past_key_values=next_cache,
810 | hidden_states=all_hidden_states,
811 | attentions=all_self_attns,
812 | cross_attentions=all_cross_attentions,
813 | )
814 |
815 |
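# LayerDrop, used in the decoder loop above, skips each decoder layer independently
# with probability `layerdrop` during training and never at inference time. A minimal
# standalone sketch of that rule (illustrative only; the decoder applies it inline
# rather than through a helper like this):
def _layerdrop_keeps_layer(training: bool, layerdrop: float) -> bool:
    import random
    # Keep the layer unless we are training and the coin flip falls below `layerdrop`.
    return not (training and random.uniform(0, 1) < layerdrop)
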
816 | @add_start_docstrings(
817 | "The bare CPT Model outputting raw hidden-states without any specific head on top.",
818 | CPT_START_DOCSTRING,
819 | )
820 | class CPTModel(CPTPretrainedModel):
821 | def __init__(self, config: CPTConfig):
822 | super().__init__(config)
823 | encoder_config = BertConfig(
824 | vocab_size=config.vocab_size,
825 | hidden_size=config.d_model,
826 | num_hidden_layers=config.encoder_layers,
827 | num_attention_heads=config.encoder_attention_heads,
828 | intermediate_size=config.encoder_ffn_dim,
829 | hidden_dropout_prob=config.activation_dropout,
830 | attention_probs_dropout_prob=config.attention_dropout,
831 | )
832 | config.vocab_size = encoder_config.vocab_size
833 | self.encoder = BertModel(encoder_config, add_pooling_layer=False)
834 | self.shared = self.encoder.get_input_embeddings()
835 | self.decoder = CPTDecoder(config, self.shared)
836 | self.num_decoder_layers = config.decoder_layers
837 | self.init_weights()
838 |
839 | def get_input_embeddings(self):
840 | return self.shared
841 |
842 | def set_input_embeddings(self, value):
843 | self.shared = value
844 | self.encoder.embed_tokens = self.shared
845 | self.decoder.embed_tokens = self.shared
846 |
847 | def get_encoder(self):
848 | class _Encoder(torch.nn.Module):
849 | def __init__(self, encoder):
850 | super().__init__()
851 | self.encoder = encoder
852 |
853 | def forward(self, *args, **kwargs):
854 | kwargs['output_hidden_states'] = True
855 | return self.encoder(*args, **kwargs)
856 | return _Encoder(self.encoder)
857 |
858 | def get_decoder(self):
859 | return self.decoder
860 |
861 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
862 | @add_code_sample_docstrings(
863 | checkpoint=_CHECKPOINT_FOR_DOC,
864 | output_type=Seq2SeqModelOutput,
865 | )
866 | def forward(
867 | self,
868 | input_ids=None,
869 | token_type_ids=None,
870 | attention_mask=None,
871 | decoder_input_ids=None,
872 | decoder_attention_mask=None,
873 | head_mask=None,
874 | decoder_head_mask=None,
875 | encoder_outputs=None,
876 | past_key_values=None,
877 | inputs_embeds=None,
878 | decoder_inputs_embeds=None,
879 | use_cache=None,
880 | output_attentions=None,
881 | output_hidden_states=None,
882 | return_dict=None,
883 | ):
884 |
885 |         # unlike other models, CPT automatically creates decoder_input_ids from
886 | # input_ids if no decoder_input_ids are provided
887 | if decoder_input_ids is None and decoder_inputs_embeds is None:
888 | decoder_input_ids = shift_tokens_right(
889 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
890 | )
891 |
892 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
893 | output_hidden_states = (
894 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
895 | )
896 | use_cache = use_cache if use_cache is not None else self.config.use_cache
897 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
898 |
899 | if getattr(self.config, "gradient_checkpointing", False) and self.training:
900 | # mpu.reset_checkpointed_activations_memory_buffer()
901 | use_cache = False
902 |
903 | if encoder_outputs is None:
904 | encoder_outputs = self.encoder(
905 | input_ids=input_ids,
906 | attention_mask=attention_mask,
907 | token_type_ids=token_type_ids,
908 | head_mask=head_mask,
909 | inputs_embeds=inputs_embeds,
910 | output_attentions=output_attentions,
911 | output_hidden_states=True,
912 | return_dict=return_dict,
913 | )
914 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
915 | elif return_dict and isinstance(encoder_outputs, (tuple, list)):
916 | encoder_outputs = BaseModelOutput(
917 | last_hidden_state=encoder_outputs[0],
918 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
919 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
920 | )
921 |
922 |         if isinstance(encoder_outputs, torch.Tensor):
923 | encoder_hidden_states = encoder_outputs
924 | else:
925 | encoder_hidden_states = encoder_outputs[1][-self.num_decoder_layers - 1]
926 |
927 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
928 | decoder_outputs = self.decoder(
929 | input_ids=decoder_input_ids,
930 | attention_mask=decoder_attention_mask,
931 | encoder_hidden_states=encoder_hidden_states,
932 | encoder_attention_mask=attention_mask,
933 | head_mask=decoder_head_mask,
934 | encoder_head_mask=head_mask,
935 | past_key_values=past_key_values,
936 | inputs_embeds=decoder_inputs_embeds,
937 | use_cache=use_cache,
938 | output_attentions=output_attentions,
939 | output_hidden_states=output_hidden_states,
940 | return_dict=return_dict,
941 | )
942 |
943 | if not return_dict:
944 | return decoder_outputs + encoder_outputs
945 |
946 | return Seq2SeqModelOutput(
947 | last_hidden_state=decoder_outputs.last_hidden_state,
948 | past_key_values=decoder_outputs.past_key_values,
949 | decoder_hidden_states=decoder_outputs.hidden_states,
950 | decoder_attentions=decoder_outputs.attentions,
951 | cross_attentions=decoder_outputs.cross_attentions,
952 | encoder_last_hidden_state=encoder_outputs.last_hidden_state if isinstance(encoder_outputs, dict) else None,
953 | encoder_hidden_states=encoder_outputs.hidden_states if isinstance(encoder_outputs, dict) else None,
954 | encoder_attentions=encoder_outputs.attentions if isinstance(encoder_outputs, dict) else None,
955 | )
956 |
957 |
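# CPTModel pairs a full BERT encoder with the shallow transformer decoder above. Note
# that the decoder does not read the encoder's final layer: `hidden_states[-num_decoder_layers - 1]`
# selects the encoder layer sitting `num_decoder_layers` below the top (index 0 of
# `hidden_states` is the embedding output), so the top encoder layers and the decoder
# act as two parallel heads over a shared trunk. A minimal instantiation sketch,
# assuming CPTConfig accepts the usual BART-style size arguments; all values are
# illustrative only:
def _example_tiny_cpt_model():
    config = CPTConfig(
        vocab_size=21128,             # roughly BERT-Chinese sized, purely illustrative
        d_model=128,
        encoder_layers=4,
        decoder_layers=2,
        encoder_attention_heads=4,
        decoder_attention_heads=4,
        encoder_ffn_dim=256,
        decoder_ffn_dim=256,
    )
    return CPTModel(config)
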
958 | @add_start_docstrings(
959 | "The CPT Model with a language modeling head. Can be used for summarization.", CPT_START_DOCSTRING
960 | )
961 | class CPTForConditionalGeneration(CPTPretrainedModel):
962 | base_model_prefix = "model"
963 | _keys_to_ignore_on_load_missing = [
964 | r"final_logits_bias",
965 | r"encoder\.version",
966 | r"decoder\.version",
967 | r"lm_head\.weight",
968 | ]
969 |
970 | def __init__(self, config):
971 | super().__init__(config)
972 | self.model = CPTModel(config)
973 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
974 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
975 |
976 | self.init_weights()
977 |
978 | def get_encoder(self):
979 | return self.model.get_encoder()
980 |
981 | def get_decoder(self):
982 | return self.model.get_decoder()
983 |
984 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
985 | new_embeddings = super().resize_token_embeddings(new_num_tokens)
986 | self._resize_final_logits_bias(new_num_tokens)
987 | return new_embeddings
988 |
989 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
990 | old_num_tokens = self.final_logits_bias.shape[-1]
991 | if new_num_tokens <= old_num_tokens:
992 | new_bias = self.final_logits_bias[:, :new_num_tokens]
993 | else:
994 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
995 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
996 | self.register_buffer("final_logits_bias", new_bias)
997 |
998 | def get_output_embeddings(self):
999 | return self.lm_head
1000 |
1001 | def set_output_embeddings(self, new_embeddings):
1002 | self.lm_head = new_embeddings
1003 |
1004 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
1005 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1006 | def forward(
1007 | self,
1008 | input_ids=None,
1009 | attention_mask=None,
1010 | token_type_ids=None,
1011 | decoder_input_ids=None,
1012 | decoder_attention_mask=None,
1013 | head_mask=None,
1014 | decoder_head_mask=None,
1015 | encoder_outputs=None,
1016 | past_key_values=None,
1017 | inputs_embeds=None,
1018 | decoder_inputs_embeds=None,
1019 | labels=None,
1020 | use_cache=None,
1021 | output_attentions=None,
1022 | output_hidden_states=None,
1023 | return_dict=None,
1024 | ):
1025 | r"""
1026 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1027 | Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
1028 | config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
1029 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
1030 |
1031 | Returns:
1032 | """
1033 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1034 |
1035 | if labels is not None:
1036 | if decoder_input_ids is None:
1037 | decoder_input_ids = shift_tokens_right(
1038 | labels, self.config.pad_token_id, self.config.decoder_start_token_id
1039 | )
1040 |
1041 | outputs = self.model(
1042 | input_ids,
1043 | attention_mask=attention_mask,
1044 | decoder_input_ids=decoder_input_ids,
1045 | encoder_outputs=encoder_outputs,
1046 | token_type_ids=token_type_ids,
1047 | decoder_attention_mask=decoder_attention_mask,
1048 | head_mask=head_mask,
1049 | decoder_head_mask=decoder_head_mask,
1050 | past_key_values=past_key_values,
1051 | inputs_embeds=inputs_embeds,
1052 | decoder_inputs_embeds=decoder_inputs_embeds,
1053 | use_cache=use_cache,
1054 | output_attentions=output_attentions,
1055 | output_hidden_states=output_hidden_states,
1056 | return_dict=return_dict,
1057 | )
1058 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1059 |
1060 | masked_lm_loss = None
1061 | if labels is not None:
1062 | loss_fct = CrossEntropyLoss()
1063 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1064 |
1065 | if not return_dict:
1066 | output = (lm_logits,) + outputs[1:]
1067 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1068 |
1069 | return Seq2SeqLMOutput(
1070 | loss=masked_lm_loss,
1071 | logits=lm_logits,
1072 | past_key_values=outputs.past_key_values,
1073 | decoder_hidden_states=outputs.decoder_hidden_states,
1074 | decoder_attentions=outputs.decoder_attentions,
1075 | cross_attentions=outputs.cross_attentions,
1076 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1077 | encoder_hidden_states=outputs.encoder_hidden_states,
1078 | encoder_attentions=outputs.encoder_attentions,
1079 | )
1080 |
1081 | def prepare_inputs_for_generation(
1082 | self,
1083 | decoder_input_ids,
1084 | past=None,
1085 | attention_mask=None,
1086 | head_mask=None,
1087 | use_cache=None,
1088 | encoder_outputs=None,
1089 | **kwargs
1090 | ):
1091 | # cut decoder_input_ids if past is used
1092 | if past is not None:
1093 | decoder_input_ids = decoder_input_ids[:, -1:]
1094 |
1095 | return {
1096 | "input_ids": None, # encoder_outputs is defined. input_ids not needed
1097 | "encoder_outputs": encoder_outputs,
1098 | "past_key_values": past,
1099 | "decoder_input_ids": decoder_input_ids,
1100 | "attention_mask": attention_mask,
1101 | "head_mask": head_mask,
1102 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1103 | }
1104 |
1105 | @staticmethod
1106 | def _expand_inputs_for_generation(
1107 | input_ids: torch.LongTensor,
1108 | expand_size: int = 1,
1109 | is_encoder_decoder: bool = False,
1110 | attention_mask: torch.LongTensor = None,
1111 | encoder_outputs = None,
1112 | **model_kwargs,
1113 | ):
1114 | expanded_return_idx = (
1115 | torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
1116 | )
1117 | input_ids = input_ids.index_select(0, expanded_return_idx)
1118 |
1119 | if "token_type_ids" in model_kwargs:
1120 | token_type_ids = model_kwargs["token_type_ids"]
1121 | model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
1122 |
1123 | if attention_mask is not None:
1124 | model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
1125 |
1126 | if is_encoder_decoder:
1127 | assert encoder_outputs is not None
1128 | device = encoder_outputs.last_hidden_state.device
1129 | encoder_outputs["hidden_states"] = tuple(h.index_select(0, expanded_return_idx.to(device)) \
1130 | for h in encoder_outputs["hidden_states"])
1131 | model_kwargs["encoder_outputs"] = encoder_outputs
1132 | return input_ids, model_kwargs
1133 |
1134 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1135 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1136 |
1137 | @staticmethod
1138 | def _reorder_cache(past, beam_idx):
1139 | reordered_past = ()
1140 | for layer_past in past:
1141 | # cached cross_attention states don't have to be reordered -> they are always the same
1142 | reordered_past += (
1143 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1144 | )
1145 | return reordered_past
1146 |
1147 |
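# A minimal generation sketch for CPTForConditionalGeneration. The checkpoint name
# "fnlp/cpt-base" and the use of BertTokenizer are assumptions (CPT uses a BERT-style
# Chinese vocabulary); substitute the checkpoint this project actually fine-tunes.
def _example_summarize(text: str) -> str:
    import torch
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("fnlp/cpt-base")            # assumed checkpoint
    model = CPTForConditionalGeneration.from_pretrained("fnlp/cpt-base")  # assumed checkpoint
    model.eval()
    enc = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        generated = model.generate(
            input_ids=enc["input_ids"],
            attention_mask=enc["attention_mask"],
            num_beams=4,
            max_length=64,
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
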
1148 | @add_start_docstrings(
1149 | """
1150 |     CPT model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE
1151 | tasks.
1152 | """,
1153 | CPT_START_DOCSTRING,
1154 | )
1155 | class CPTForSequenceClassification(CPTPretrainedModel):
1156 | def __init__(self, config: CPTConfig, cls_mode=1, **kwargs):
1157 | super().__init__(config, **kwargs)
1158 | self.model = CPTModel(config)
1159 | cls_mode = getattr(config, 'cls_mode', cls_mode)
1160 | if cls_mode == 1:
1161 | logger.info('Encoder for classification.')
1162 | cls_dim = config.d_model
1163 | elif cls_mode == 2:
1164 | logger.info('Decoder for classification.')
1165 | cls_dim = config.d_model
1166 | elif cls_mode == 3:
1167 | logger.info('Both encoder & decoder for classification.')
1168 | cls_dim = config.d_model * 2
1169 | else:
1170 | raise NotImplementedError
1171 |
1172 | self.cls_head = CPTClassificationHead(
1173 | cls_dim,
1174 | cls_dim,
1175 | config.num_labels,
1176 | config.classifier_dropout,
1177 | )
1178 | self.model._init_weights(self.cls_head.dense)
1179 | self.model._init_weights(self.cls_head.out_proj)
1180 | self.cls_mode = cls_mode
1181 | config.cls_mode = cls_mode
1182 |
1183 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
1184 | @add_code_sample_docstrings(
1185 | checkpoint=_CHECKPOINT_FOR_DOC,
1186 | output_type=Seq2SeqSequenceClassifierOutput,
1187 | config_class=_CONFIG_FOR_DOC,
1188 | )
1189 | def forward(
1190 | self,
1191 | input_ids=None,
1192 | attention_mask=None,
1193 | decoder_input_ids=None,
1194 | decoder_attention_mask=None,
1195 | head_mask=None,
1196 | decoder_head_mask=None,
1197 | encoder_outputs=None,
1198 | inputs_embeds=None,
1199 | decoder_inputs_embeds=None,
1200 | labels=None,
1201 | use_cache=None,
1202 | output_attentions=None,
1203 | output_hidden_states=None,
1204 | return_dict=None,
1205 | ):
1206 | r"""
1207 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1208 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1209 | config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1210 | """
1211 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1212 | if labels is not None:
1213 | use_cache = False
1214 |
1215 | if input_ids is None and inputs_embeds is not None:
1216 | raise NotImplementedError(
1217 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1218 | )
1219 |
1220 | outputs = self.model(
1221 | input_ids,
1222 | attention_mask=attention_mask,
1223 | decoder_input_ids=decoder_input_ids,
1224 | decoder_attention_mask=decoder_attention_mask,
1225 | head_mask=head_mask,
1226 | decoder_head_mask=decoder_head_mask,
1227 | encoder_outputs=encoder_outputs,
1228 | inputs_embeds=inputs_embeds,
1229 | decoder_inputs_embeds=decoder_inputs_embeds,
1230 | use_cache=use_cache,
1231 | output_attentions=output_attentions,
1232 | output_hidden_states=output_hidden_states,
1233 | return_dict=True,
1234 | )
1235 |
1236 | hidden_states = outputs.last_hidden_state
1237 | enc_hidden_states = outputs.encoder_last_hidden_state
1238 | enc_rep = enc_hidden_states[:, 0]
1239 |
1240 | eos_mask = input_ids.eq(self.config.eos_token_id)
1241 |
1242 | if len(torch.unique(eos_mask.sum(1))) > 1:
1243 |             raise ValueError("All examples must have the same number of <eos> tokens.")
1244 | dec_rep = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1245 | :, -1, :
1246 | ]
1247 |
1248 | if self.cls_mode == 1:
1249 | logits = self.cls_head(enc_rep)
1250 | elif self.cls_mode == 2:
1251 | logits = self.cls_head(dec_rep)
1252 | elif self.cls_mode == 3:
1253 | rep = torch.cat([enc_rep, dec_rep], dim=-1)
1254 | logits = self.cls_head(rep)
1255 | else:
1256 | raise NotImplementedError
1257 |
1258 | loss = None
1259 | if labels is not None:
1260 | loss_fct = CrossEntropyLoss()
1261 | loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1262 |
1263 | if not return_dict:
1264 | output = (logits,) + outputs[1:]
1265 | return ((loss,) + output) if loss is not None else output
1266 |
1267 | return Seq2SeqSequenceClassifierOutput(
1268 | loss=loss,
1269 | logits=logits,
1270 | past_key_values=outputs.past_key_values,
1271 | decoder_hidden_states=outputs.decoder_hidden_states,
1272 | decoder_attentions=outputs.decoder_attentions,
1273 | cross_attentions=outputs.cross_attentions,
1274 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1275 | encoder_hidden_states=outputs.encoder_hidden_states,
1276 | encoder_attentions=outputs.encoder_attentions,
1277 | )
1278 |
1279 |
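# `cls_mode` above selects the classification representation: 1 uses the encoder's
# [CLS] position, 2 uses the decoder hidden state at the last <eos> position, and 3
# concatenates the two. A minimal usage sketch; "fnlp/cpt-base" is an assumed
# checkpoint name, and token_type_ids are dropped because this head's forward does
# not accept them.
def _example_classify(texts, num_labels=3):
    import torch
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("fnlp/cpt-base")  # assumed checkpoint
    model = CPTForSequenceClassification.from_pretrained(
        "fnlp/cpt-base", num_labels=num_labels, cls_mode=1
    )
    model.eval()
    enc = tokenizer(texts, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"]).logits
    return logits.argmax(dim=-1)
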
1280 | @add_start_docstrings(
1281 | """
1282 | CPT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1283 | layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1284 | """,
1285 | CPT_START_DOCSTRING,
1286 | )
1287 | class CPTForQuestionAnswering(CPTPretrainedModel):
1288 | def __init__(self, config: CPTConfig, cls_mode=1, **kwargs):
1289 | super().__init__(config, **kwargs)
1290 | config.num_labels = 2
1291 | self.num_labels = config.num_labels
1292 |
1293 | self.model = CPTModel(config)
1294 |
1295 | cls_mode = getattr(config, 'cls_mode', cls_mode)
1296 | if cls_mode == 1:
1297 | logger.info('Encoder for classification.')
1298 | cls_dim = config.d_model
1299 | elif cls_mode == 2:
1300 | logger.info('Decoder for classification.')
1301 | cls_dim = config.d_model
1302 | elif cls_mode == 3:
1303 | logger.info('Both encoder & decoder for classification.')
1304 | cls_dim = config.d_model * 2
1305 | else:
1306 | raise NotImplementedError
1307 |
1308 | self.qa_outputs = nn.Linear(cls_dim, config.num_labels)
1309 | self.model._init_weights(self.qa_outputs)
1310 |
1311 | self.cls_mode = cls_mode
1312 | config.cls_mode = cls_mode
1313 |
1314 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING)
1315 | @add_code_sample_docstrings(
1316 | checkpoint=_CHECKPOINT_FOR_DOC,
1317 | output_type=Seq2SeqSequenceClassifierOutput,
1318 | config_class=_CONFIG_FOR_DOC,
1319 | )
1320 | def forward(
1321 | self,
1322 | input_ids=None,
1323 | attention_mask=None,
1324 | decoder_input_ids=None,
1325 | decoder_attention_mask=None,
1326 | head_mask=None,
1327 | decoder_head_mask=None,
1328 | encoder_outputs=None,
1329 | start_positions=None,
1330 | end_positions=None,
1331 | inputs_embeds=None,
1332 | decoder_inputs_embeds=None,
1333 | use_cache=None,
1334 | output_attentions=None,
1335 | output_hidden_states=None,
1336 | return_dict=None,
1337 | ):
1338 | r"""
1339 | start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1340 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
1341 |             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1342 | are not taken into account for computing the loss.
1343 | end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1344 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
1345 |             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1346 | are not taken into account for computing the loss.
1347 | """
1348 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1349 |
1350 | if input_ids is None and inputs_embeds is not None:
1351 | raise NotImplementedError(
1352 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1353 | )
1354 |
1355 | outputs = self.model(
1356 | input_ids,
1357 | attention_mask=attention_mask,
1358 | decoder_input_ids=decoder_input_ids,
1359 | decoder_attention_mask=decoder_attention_mask,
1360 | head_mask=head_mask,
1361 | decoder_head_mask=decoder_head_mask,
1362 | encoder_outputs=encoder_outputs,
1363 | inputs_embeds=inputs_embeds,
1364 | decoder_inputs_embeds=decoder_inputs_embeds,
1365 | use_cache=use_cache,
1366 | output_attentions=output_attentions,
1367 | output_hidden_states=output_hidden_states,
1368 | return_dict=True,
1369 | )
1370 |
1371 | hidden_states = outputs.last_hidden_state
1372 | enc_hidden_states = outputs.encoder_last_hidden_state
1373 |
1374 | if self.cls_mode == 1:
1375 | logits = self.qa_outputs(enc_hidden_states)
1376 | elif self.cls_mode == 2:
1377 | logits = self.qa_outputs(hidden_states)
1378 | elif self.cls_mode == 3:
1379 | rep = torch.cat([enc_hidden_states, hidden_states], dim=-1)
1380 | logits = self.qa_outputs(rep)
1381 | else:
1382 | raise NotImplementedError
1383 |
1384 | start_logits, end_logits = logits.split(1, dim=-1)
1385 | start_logits = start_logits.squeeze(-1)
1386 | end_logits = end_logits.squeeze(-1)
1387 |
1388 | total_loss = None
1389 | if start_positions is not None and end_positions is not None:
1390 |             # If we are on multi-GPU, split adds an extra dimension that we squeeze away
1391 | if len(start_positions.size()) > 1:
1392 | start_positions = start_positions.squeeze(-1)
1393 | if len(end_positions.size()) > 1:
1394 | end_positions = end_positions.squeeze(-1)
1395 |             # sometimes the start/end positions are outside our model inputs; we ignore these terms
1396 | ignored_index = start_logits.size(1)
1397 | start_positions.clamp_(0, ignored_index)
1398 | end_positions.clamp_(0, ignored_index)
1399 |
1400 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1401 | start_loss = loss_fct(start_logits, start_positions)
1402 | end_loss = loss_fct(end_logits, end_positions)
1403 | total_loss = (start_loss + end_loss) / 2
1404 |
1405 | if not return_dict:
1406 | output = (
1407 | start_logits,
1408 | end_logits,
1409 | ) + outputs[1:]
1410 | return ((total_loss,) + output) if total_loss is not None else output
1411 |
1412 | return Seq2SeqQuestionAnsweringModelOutput(
1413 | loss=total_loss,
1414 | start_logits=start_logits,
1415 | end_logits=end_logits,
1416 | past_key_values=outputs.past_key_values,
1417 | decoder_hidden_states=outputs.decoder_hidden_states,
1418 | decoder_attentions=outputs.decoder_attentions,
1419 | cross_attentions=outputs.cross_attentions,
1420 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1421 | encoder_hidden_states=outputs.encoder_hidden_states,
1422 | encoder_attentions=outputs.encoder_attentions,
1423 | )
1424 |
1425 |
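# Turning the start/end logits produced above into a single predicted span, greedily
# and without the n-best or max-answer-length filtering a real extractive-QA pipeline
# would add; purely an illustrative sketch.
def _example_extract_span(start_logits, end_logits):
    import torch
    start = int(torch.argmax(start_logits, dim=-1)[0])
    end = int(torch.argmax(end_logits, dim=-1)[0])
    # Fall back to an empty span when the greedy picks are inconsistent.
    if end < start:
        return 0, 0
    return start, end
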
1426 | class CPTForMaskedLM(CPTPretrainedModel):
1427 | _keys_to_ignore_on_load_missing = [
1428 | r"final_logits_bias",
1429 | r"encoder\.version",
1430 | r"decoder\.version",
1431 | r"lm_head\.weight",
1432 | ]
1433 | def __init__(self, config, **kwargs):
1434 | super().__init__(config)
1435 | self.model = CPTModel(config)
1436 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1437 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
1438 |
1439 | self.init_weights()
1440 |
1441 | def get_encoder(self):
1442 | return self.model.get_encoder()
1443 |
1444 | def get_decoder(self):
1445 | return self.model.get_decoder()
1446 |
1447 | def get_output_embeddings(self):
1448 | return self.lm_head
1449 |
1450 | def forward(
1451 | self,
1452 | input_ids=None,
1453 | attention_mask=None,
1454 | decoder_input_ids=None,
1455 | decoder_attention_mask=None,
1456 | head_mask=None,
1457 | decoder_head_mask=None,
1458 | encoder_outputs=None,
1459 | inputs_embeds=None,
1460 | decoder_inputs_embeds=None,
1461 | use_cache=None,
1462 | output_attentions=None,
1463 | output_hidden_states=None,
1464 | return_dict=None,
1465 | ):
1466 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1467 |
1468 | if input_ids is None and inputs_embeds is not None:
1469 | raise NotImplementedError(
1470 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1471 | )
1472 |
1473 | outputs = self.model(
1474 | input_ids,
1475 | attention_mask=attention_mask,
1476 | decoder_input_ids=decoder_input_ids,
1477 | decoder_attention_mask=decoder_attention_mask,
1478 | head_mask=head_mask,
1479 | decoder_head_mask=decoder_head_mask,
1480 | encoder_outputs=encoder_outputs,
1481 | inputs_embeds=inputs_embeds,
1482 | decoder_inputs_embeds=decoder_inputs_embeds,
1483 | use_cache=use_cache,
1484 | output_attentions=output_attentions,
1485 | output_hidden_states=output_hidden_states,
1486 | return_dict=True,
1487 | )
1488 |
1489 | hidden_states = outputs.last_hidden_state
1490 | enc_hidden_states = outputs.encoder_last_hidden_state
1491 |
1492 | dec_logits = self.lm_head(hidden_states) + self.final_logits_bias
1493 | enc_logits = self.lm_head(enc_hidden_states) + self.final_logits_bias
1494 |
1495 | if not return_dict:
1496 | logits = (enc_logits, dec_logits)
1497 | output = (logits,) + outputs[1:]
1498 | return output
1499 |
1500 | return Seq2SeqLMOutput(
1501 | loss=None,
1502 | logits=(enc_logits, dec_logits),
1503 | past_key_values=outputs.past_key_values,
1504 | decoder_hidden_states=outputs.decoder_hidden_states,
1505 | decoder_attentions=outputs.decoder_attentions,
1506 | cross_attentions=outputs.cross_attentions,
1507 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1508 | encoder_hidden_states=outputs.encoder_hidden_states,
1509 | encoder_attentions=outputs.encoder_attentions,
1510 | )
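

# CPTForMaskedLM returns no loss itself; `logits` is the pair (enc_logits, dec_logits),
# i.e. the encoder's MLM-style predictions and the decoder's autoregressive predictions
# over the shared vocabulary. A sketch of computing a masked-LM loss from the encoder
# branch, assuming `labels` marks unmasked positions with -100 (the CrossEntropyLoss
# default ignore_index):
def _example_encoder_mlm_loss(model, input_ids, attention_mask, labels):
    import torch
    out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
    enc_logits, _dec_logits = out.logits
    loss_fct = torch.nn.CrossEntropyLoss()  # ignore_index defaults to -100
    return loss_fct(enc_logits.view(-1, enc_logits.size(-1)), labels.view(-1))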
--------------------------------------------------------------------------------