├── .gitignore
├── LICENSE
├── README.md
├── models
│   ├── __init__.py
│   ├── bert_model.py
│   └── modeling_bert.py
├── modules
│   ├── metrics.py
│   └── train.py
├── processor
│   └── dataset.py
├── requirements.txt
├── resource
│   └── model.png
├── run.py
├── run_re_task.sh
├── run_twitter15.sh
└── run_twitter17.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | ckpt/
2 |
3 | data/
4 |
5 | logs/
6 |
7 | __pycache__/
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 ZJUNLP
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HVPNet
2 |
3 | Code for the NAACL2022 (Findings) paper "[Good Visual Guidance Makes A Better Extractor: Hierarchical Visual Prefix for Multimodal Entity and Relation Extraction](https://arxiv.org/pdf/2205.03521.pdf)".
4 |
5 | Model Architecture
6 | ==========
7 |
8 | ![model](resource/model.png)
9 |
10 | The overall architecture of our hierarchical modality fusion network.
11 |
12 |
13 | Requirements
14 | ==========
15 | To run the codes, you need to install the requirements:
16 | ```
17 | pip install -r requirements.txt
18 | ```
19 |
20 | Data Preprocess
21 | ==========
22 | To extract visual object images, we first use the NLTK parser to extract noun phrases from the text and then apply the [visual grounding toolkit](https://github.com/zyang-ur/onestage_grounding) to detect objects. The detailed steps are as follows:
23 |
24 | 1. Use the NLTK parser (or spaCy, TextBlob) to extract noun phrases from the text.
25 | 2. Apply the [visual grounding toolkit](https://github.com/zyang-ur/onestage_grounding) to detect objects. Taking the twitter2015 dataset as an example, the extracted objects are stored in `twitter2015_aux_images`. The object images follow the naming format `imgname_pred_yolo_crop_num.png`, where `imgname` is the name of the raw image the object comes from and `num` is the index of the object predicted by the toolkit. (Note that in `train/val/test.txt`, text and raw image have a one-to-one relationship, so `imgname` can be used as a unique identifier for the raw images.)
26 | 3. Establish the correspondence between the raw images and the objects by constructing a dictionary. Taking `twitter2015/twitter2015_train_dict.pth` as an example, the dictionary has the format `{imgname: ['imgname_pred_yolo_crop_num0.png', 'imgname_pred_yolo_crop_num1.png', ...]}`, where each key is the name of a raw image and each value is the list of its object images (see the sketch below).
27 |
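A minimal sketch of step 3, assuming the naming format above (the directory and output paths here are illustrative):

```python
import os
from collections import defaultdict
import torch

aux_dir = "data/NER_data/twitter2015_aux_images"  # hypothetical location of the object crops

img2objs = defaultdict(list)
for fname in sorted(os.listdir(aux_dir)):
    # object crops are named imgname_pred_yolo_crop_num.png
    imgname = fname.split("_pred_yolo_crop_")[0]
    img2objs[imgname].append(fname)

torch.save(dict(img2objs), "twitter2015_train_dict.pth")  # {imgname: [object-image, ...]}
```
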
28 | The detected objects and the dictionary of the correspondence between the raw images and the objects are available in our data links.
29 |
30 | Data Download
31 | ==========
32 |
33 | + Twitter2015 & Twitter2017
34 |
35 | The text data follows the CoNLL format. You can download the Twitter2015 data via this [link](https://drive.google.com/file/d/1qAWrV9IaiBadICFb7mAreXy3llao_teZ/view?usp=sharing) and the Twitter2017 data via this [link](https://drive.google.com/file/d/1ogfbn-XEYtk9GpUECq1-IwzINnhKGJqy/view?usp=sharing). Please place them in `data/NER_data`.
36 |
37 | You can also put them anywhere else and modify the path configuration in `run.py`.
38 |
39 | + MNRE
40 |
41 | The MNRE dataset comes from [MEGA](https://github.com/thecharm/MNRE), many thanks.
42 |
43 | You can download the MRE dataset with detected visual objects from [Google Drive](https://drive.google.com/file/d/1q5_5vnHJ8Hik1iLA9f5-6nstcvvntLrS/view?usp=sharing) or use the following commands:
44 | ```bash
45 | cd data
46 | wget 120.27.214.45/Data/re/multimodal/data.tar.gz
47 | tar -xzvf data.tar.gz
48 | mv data RE_data
49 | ```
50 |
51 | The expected structure of files is:
52 |
53 | ```
54 | HMNeT
55 | |-- data
56 | | |-- NER_data
57 | | | |-- twitter2015 # text data
58 | | | | |-- train.txt
59 | | | | |-- valid.txt
60 | | | | |-- test.txt
61 | | | | |-- twitter2015_train_dict.pth # {imgname: [object-image]}
62 | | | | |-- ...
63 | | | |-- twitter2015_images # raw image data
64 | | | |-- twitter2015_aux_images # object image data
65 | | | |-- twitter2017
66 | | | |-- twitter2017_images
67 | | | |-- twitter2017_aux_images
68 | | |-- RE_data
69 | | | |-- img_org # raw image data
70 | | | |-- img_vg # object image data
71 | | | |-- txt # text data
72 | | | |-- ours_rel2id.json # relation data
73 | |-- models # models
74 | | |-- bert_model.py
75 | | |-- modeling_bert.py
76 | |-- modules
77 | | |-- metrics.py # metric
78 | | |-- train.py # trainer
79 | |-- processor
80 | | |-- dataset.py # processor, dataset
81 | |-- logs # code logs
82 | |-- run.py # main
83 | |-- run_twitter15.sh # and run_twitter17.sh
84 | |-- run_re_task.sh
85 | ```
86 |
87 | Train
88 | ==========
89 |
90 | ## NER Task
91 |
92 | The data path and GPU-related configuration are in `run.py`. To train the NER model, run the script for the corresponding dataset:
93 |
94 | ```shell
95 | bash run_twitter15.sh
96 | bash run_twitter17.sh
97 | ```
98 |
99 | ## RE Task
100 |
101 | To train the RE model, run this script:
102 |
103 | ```shell
104 | bash run_re_task.sh
105 | ```
106 |
107 | Test
108 | ==========
109 | ## NER Task
110 |
111 | To test the NER model, use the trained model: set `load_path` to the checkpoint path (and `dataset_name` to `twitter15` or `twitter17`), then run the following script:
112 |
113 | ```shell
114 | python -u run.py \
115 | --dataset_name="twitter15/twitter17" \
116 | --bert_name="bert-base-uncased" \
117 | --seed=1234 \
118 | --only_test \
119 | --max_seq=80 \
120 | --use_prompt \
121 | --prompt_len=4 \
122 | --sample_ratio=1.0 \
123 | --load_path='your_ner_ckpt_path'
124 |
125 | ```
126 |
127 | ## RE Task
128 |
129 | To test the RE model, use the trained model: set `load_path` to the checkpoint path, then run the following script:
130 |
131 | ```shell
132 | python -u run.py \
133 | --dataset_name="MRE" \
134 | --bert_name="bert-base-uncased" \
135 | --seed=1234 \
136 | --only_test \
137 | --max_seq=80 \
138 | --use_prompt \
139 | --prompt_len=4 \
140 | --sample_ratio=1.0 \
141 | --load_path='your_re_ckpt_path'
142 |
143 | ```
144 |
145 | Acknowledgement
146 | ==========
147 |
148 | The acquisition of the Twitter15 and Twitter17 data refers to the code from [UMT](https://github.com/jefferyYu/UMT/), many thanks.
149 |
150 | The acquisition of the MNRE data for the multimodal relation extraction task refers to the code from [MEGA](https://github.com/thecharm/Mega), many thanks.
151 |
152 | Papers for the Project & How to Cite
153 | ==========
154 |
155 |
156 | If you use or extend our work, please cite the paper as follows:
157 |
158 | ```bibtex
159 | @inproceedings{DBLP:conf/naacl/ChenZLYDTHSC22,
160 | author = {Xiang Chen and
161 | Ningyu Zhang and
162 | Lei Li and
163 | Yunzhi Yao and
164 | Shumin Deng and
165 | Chuanqi Tan and
166 | Fei Huang and
167 | Luo Si and
168 | Huajun Chen},
169 | editor = {Marine Carpuat and
170 | Marie{-}Catherine de Marneffe and
171 | Iv{\'{a}}n Vladimir Meza Ru{\'{\i}}z},
172 | title = {Good Visual Guidance Make {A} Better Extractor: Hierarchical Visual
173 | Prefix for Multimodal Entity and Relation Extraction},
174 | booktitle = {Findings of the Association for Computational Linguistics: {NAACL}
175 | 2022, Seattle, WA, United States, July 10-15, 2022},
176 | pages = {1607--1618},
177 | publisher = {Association for Computational Linguistics},
178 | year = {2022},
179 | url = {https://doi.org/10.18653/v1/2022.findings-naacl.121},
180 | doi = {10.18653/v1/2022.findings-naacl.121},
181 | timestamp = {Tue, 23 Aug 2022 08:36:33 +0200},
182 | biburl = {https://dblp.org/rec/conf/naacl/ChenZLYDTHSC22.bib},
183 | bibsource = {dblp computer science bibliography, https://dblp.org}
184 | }
185 | ```
186 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .bert_model import *
2 | from .modeling_bert import *
--------------------------------------------------------------------------------
/models/bert_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from torchcrf import CRF
6 | from .modeling_bert import BertModel
7 | from transformers.modeling_outputs import TokenClassifierOutput
8 | from torchvision.models import resnet50
9 |
10 | class ImageModel(nn.Module):
11 | def __init__(self):
12 | super(ImageModel, self).__init__()
13 | self.resnet = resnet50(pretrained=True)
14 |
15 | def forward(self, x, aux_imgs=None):
16 | # full image prompt
17 |         prompt_guids = self.get_resnet_prompt(x)    # list of 4 maps: [bsz, C, 2, 2], C in {256, 512, 1024, 2048}
18 |
19 |         # aux_imgs: bsz x 3(nums) x 3 x 224 x 224
20 |         if aux_imgs is not None:
21 |             aux_prompt_guids = []   # goal: 3 x (4 x [bsz, C, 2, 2])
22 |             aux_imgs = aux_imgs.permute([1, 0, 2, 3, 4])  # 3(nums) x bsz x 3 x 224 x 224
23 |             for i in range(len(aux_imgs)):
24 |                 aux_prompt_guid = self.get_resnet_prompt(aux_imgs[i])   # 4 x [bsz, C, 2, 2]
25 | aux_prompt_guids.append(aux_prompt_guid)
26 | return prompt_guids, aux_prompt_guids
27 | return prompt_guids, None
28 |
29 |     def get_resnet_prompt(self, x):
30 |         """generate image prompt
31 |
32 |         Args:
33 |             x (torch.Tensor): bsz x 3 x 224 x 224
34 |
35 |         Returns:
36 |             prompt_guids (List[torch.Tensor]): 4 maps of shape bsz x C x 2 x 2, with C in {256, 512, 1024, 2048}
37 |         """
38 |         # image: bsz x 3 x 224 x 224
39 |         prompt_guids = []
40 |         for name, layer in self.resnet.named_children():
41 |             if name == 'fc' or name == 'avgpool': continue
42 |             x = layer(x)    # after layer1: (bsz, 256, 56, 56); spatial size halves at each later stage
43 |             if 'layer' in name:
44 |                 bsz, channel, ft, _ = x.size()
45 |                 kernel = ft // 2
46 |                 prompt_kv = nn.AvgPool2d(kernel_size=(kernel, kernel), stride=kernel)(x)   # (bsz, C, 2, 2)
47 |                 prompt_guids.append(prompt_kv)
48 |         return prompt_guids
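        # Shape walk-through for a 224x224 input (a sketch; it follows from resnet50's
        # stage outputs and kernel = ft // 2 above):
        #   layer1: (bsz, 256, 56, 56)  -> AvgPool(k=28, s=28) -> (bsz, 256, 2, 2)
        #   layer2: (bsz, 512, 28, 28)  -> AvgPool(k=14, s=14) -> (bsz, 512, 2, 2)
        #   layer3: (bsz, 1024, 14, 14) -> AvgPool(k=7,  s=7)  -> (bsz, 1024, 2, 2)
        #   layer4: (bsz, 2048, 7, 7)   -> AvgPool(k=3,  s=3)  -> (bsz, 2048, 2, 2)
        # 256 + 512 + 1024 + 2048 = 3840 channels over 2x2 = 4 positions, which matches
        # the (bsz, 4, 3840) view and encoder_conv's in_features=3840 downstream.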
49 |
50 |
51 | class HMNeTREModel(nn.Module):
52 | def __init__(self, num_labels, tokenizer, args):
53 | super(HMNeTREModel, self).__init__()
54 | self.bert = BertModel.from_pretrained(args.bert_name)
55 | self.bert.resize_token_embeddings(len(tokenizer))
56 | self.args = args
57 |
58 | self.dropout = nn.Dropout(0.5)
59 | self.classifier = nn.Linear(self.bert.config.hidden_size*2, num_labels)
60 |         self.head_start = tokenizer.convert_tokens_to_ids("<s>")   # head entity marker token added by the processor
61 |         self.tail_start = tokenizer.convert_tokens_to_ids("<o>")   # tail entity marker token added by the processor
62 | self.tokenizer = tokenizer
63 |
64 | if self.args.use_prompt:
65 | self.image_model = ImageModel()
66 |
67 | self.encoder_conv = nn.Sequential(
68 | nn.Linear(in_features=3840, out_features=800),
69 | nn.Tanh(),
70 | nn.Linear(in_features=800, out_features=4*2*768)
71 | )
72 |
73 | self.gates = nn.ModuleList([nn.Linear(4*768*2, 4) for i in range(12)])
74 |
75 | def forward(
76 | self,
77 | input_ids=None,
78 | attention_mask=None,
79 | token_type_ids=None,
80 | labels=None,
81 | images=None,
82 | aux_imgs=None,
83 | ):
84 |
85 | bsz = input_ids.size(0)
86 | if self.args.use_prompt:
87 | prompt_guids = self.get_visual_prompt(images, aux_imgs)
88 | prompt_guids_length = prompt_guids[0][0].shape[2]
89 | prompt_guids_mask = torch.ones((bsz, prompt_guids_length)).to(self.args.device)
90 | prompt_attention_mask = torch.cat((prompt_guids_mask, attention_mask), dim=1)
91 | else:
92 | prompt_guids = None
93 | prompt_attention_mask = attention_mask
94 |
95 | output = self.bert(
96 | input_ids=input_ids,
97 | token_type_ids=token_type_ids,
98 | attention_mask=prompt_attention_mask,
99 | past_key_values=prompt_guids,
100 | output_attentions=True,
101 | return_dict=True
102 | )
103 |
104 | last_hidden_state, pooler_output = output.last_hidden_state, output.pooler_output
105 | bsz, seq_len, hidden_size = last_hidden_state.shape
106 | entity_hidden_state = torch.Tensor(bsz, 2*hidden_size) # batch, 2*hidden
107 | for i in range(bsz):
108 | head_idx = input_ids[i].eq(self.head_start).nonzero().item()
109 | tail_idx = input_ids[i].eq(self.tail_start).nonzero().item()
110 | head_hidden = last_hidden_state[i, head_idx, :].squeeze()
111 | tail_hidden = last_hidden_state[i, tail_idx, :].squeeze()
112 | entity_hidden_state[i] = torch.cat([head_hidden, tail_hidden], dim=-1)
113 | entity_hidden_state = entity_hidden_state.to(self.args.device)
114 | logits = self.classifier(entity_hidden_state)
115 | if labels is not None:
116 | loss_fn = nn.CrossEntropyLoss()
117 | return loss_fn(logits, labels.view(-1)), logits
118 | return logits
119 |
120 | def get_visual_prompt(self, images, aux_imgs):
121 | bsz = images.size(0)
122 | # full image prompt
123 | prompt_guids, aux_prompt_guids = self.image_model(images, aux_imgs) # [bsz, 256, 2, 2], [bsz, 512, 2, 2]....
124 | prompt_guids = torch.cat(prompt_guids, dim=1).view(bsz, self.args.prompt_len, -1) # bsz, 4, 3840
125 |
126 | # aux image prompts # 3 x (4 x [bsz, 256, 2, 2])
127 | aux_prompt_guids = [torch.cat(aux_prompt_guid, dim=1).view(bsz, self.args.prompt_len, -1) for aux_prompt_guid in aux_prompt_guids] # 3 x [bsz, 4, 3840]
128 |
129 | prompt_guids = self.encoder_conv(prompt_guids) # bsz, 4, 4*2*768
130 | aux_prompt_guids = [self.encoder_conv(aux_prompt_guid) for aux_prompt_guid in aux_prompt_guids] # 3 x [bsz, 4, 4*2*768]
131 | split_prompt_guids = prompt_guids.split(768*2, dim=-1) # 4 x [bsz, 4, 768*2]
132 | split_aux_prompt_guids = [aux_prompt_guid.split(768*2, dim=-1) for aux_prompt_guid in aux_prompt_guids] # 3x [4 x [bsz, 4, 768*2]]
133 |
134 |         sum_prompt_guids = torch.stack(split_prompt_guids).sum(0).view(bsz, -1) / 4     # bsz, 4*768*2
135 |
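        # A sketch of the output layout (assuming bert-base: 12 layers, 12 heads, head
        # dim 64): result holds one (key, value) pair per layer, each of shape
        # (bsz, 12, 16, 64) -- 4 prompt positions from the full image plus 3 aux
        # images x 4 positions. BertSelfAttention consumes these via past_key_values,
        # concatenating them ahead of the text keys/values.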
136 | result = []
137 | for idx in range(12): # 12
138 | prompt_gate = F.softmax(F.leaky_relu(self.gates[idx](sum_prompt_guids)), dim=-1)
139 |
140 | key_val = torch.zeros_like(split_prompt_guids[0]).to(self.args.device) # bsz, 4, 768*2
141 | for i in range(4):
142 | key_val = key_val + torch.einsum('bg,blh->blh', prompt_gate[:, i].view(-1, 1), split_prompt_guids[i])
143 |
144 |             # use the gate to mix the aux image prompts
145 |             aux_key_vals = []   # 3 x [bsz, 4, 768*2]
146 |             for split_aux_prompt_guid in split_aux_prompt_guids:
147 |                 sum_aux_prompt_guids = torch.stack(split_aux_prompt_guid).sum(0).view(bsz, -1) / 4   # bsz, 4*768*2
148 | aux_prompt_gate = F.softmax(F.leaky_relu(self.gates[idx](sum_aux_prompt_guids)), dim=-1)
149 | aux_key_val = torch.zeros_like(split_aux_prompt_guid[0]).to(self.args.device) # bsz, 4, 768*2
150 | for i in range(4):
151 | aux_key_val = aux_key_val + torch.einsum('bg,blh->blh', aux_prompt_gate[:, i].view(-1, 1), split_aux_prompt_guid[i])
152 | aux_key_vals.append(aux_key_val)
153 | key_val = [key_val] + aux_key_vals
154 | key_val = torch.cat(key_val, dim=1)
155 | key_val = key_val.split(768, dim=-1)
156 |             key, value = key_val[0].reshape(bsz, 12, -1, 64).contiguous(), key_val[1].reshape(bsz, 12, -1, 64).contiguous()  # bsz, 12, 16, 64 (4 full-image + 3 x 4 aux prompt positions)
157 | temp_dict = (key, value)
158 | result.append(temp_dict)
159 | return result
160 |
161 |
162 | class HMNeTNERModel(nn.Module):
163 | def __init__(self, label_list, args):
164 | super(HMNeTNERModel, self).__init__()
165 | self.args = args
166 | self.prompt_dim = args.prompt_dim
167 | self.prompt_len = args.prompt_len
168 | self.bert = BertModel.from_pretrained(args.bert_name)
169 | self.bert_config = self.bert.config
170 |
171 | if args.use_prompt:
172 | self.image_model = ImageModel() # bsz, 6, 56, 56
173 | self.encoder_conv = nn.Sequential(
174 | nn.Linear(in_features=3840, out_features=800),
175 | nn.Tanh(),
176 | nn.Linear(in_features=800, out_features=4*2*768)
177 | )
178 | self.gates = nn.ModuleList([nn.Linear(4*768*2, 4) for i in range(12)])
179 |
180 | self.num_labels = len(label_list) # pad
181 | print(self.num_labels)
182 | self.crf = CRF(self.num_labels, batch_first=True)
183 | self.fc = nn.Linear(self.bert.config.hidden_size, self.num_labels)
184 | self.dropout = nn.Dropout(0.1)
185 |
186 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, images=None, aux_imgs=None):
187 | if self.args.use_prompt:
188 | prompt_guids = self.get_visual_prompt(images, aux_imgs)
189 | prompt_guids_length = prompt_guids[0][0].shape[2]
190 | # attention_mask: bsz, seq_len
191 | # prompt attention, attention mask
192 | bsz = attention_mask.size(0)
193 | prompt_guids_mask = torch.ones((bsz, prompt_guids_length)).to(self.args.device)
194 | prompt_attention_mask = torch.cat((prompt_guids_mask, attention_mask), dim=1)
195 | else:
196 | prompt_attention_mask = attention_mask
197 | prompt_guids = None
198 |
199 | bert_output = self.bert(input_ids=input_ids,
200 | attention_mask=prompt_attention_mask,
201 | token_type_ids=token_type_ids,
202 | past_key_values=prompt_guids,
203 | return_dict=True)
204 | sequence_output = bert_output['last_hidden_state'] # bsz, len, hidden
205 | sequence_output = self.dropout(sequence_output) # bsz, len, hidden
206 | emissions = self.fc(sequence_output) # bsz, len, labels
207 |
208 | logits = self.crf.decode(emissions, attention_mask.byte())
209 | loss = None
210 | if labels is not None:
211 | loss = -1 * self.crf(emissions, labels, mask=attention_mask.byte(), reduction='mean')
212 | return TokenClassifierOutput(
213 | loss=loss,
214 | logits=logits
215 | )
216 |
217 | def get_visual_prompt(self, images, aux_imgs):
218 | bsz = images.size(0)
219 | prompt_guids, aux_prompt_guids = self.image_model(images, aux_imgs) # [bsz, 256, 2, 2], [bsz, 512, 2, 2]....
220 |
221 | prompt_guids = torch.cat(prompt_guids, dim=1).view(bsz, self.args.prompt_len, -1) # bsz, 4, 3840
222 | aux_prompt_guids = [torch.cat(aux_prompt_guid, dim=1).view(bsz, self.args.prompt_len, -1) for aux_prompt_guid in aux_prompt_guids] # 3 x [bsz, 4, 3840]
223 |
224 | prompt_guids = self.encoder_conv(prompt_guids) # bsz, 4, 4*2*768
225 | aux_prompt_guids = [self.encoder_conv(aux_prompt_guid) for aux_prompt_guid in aux_prompt_guids] # 3 x [bsz, 4, 4*2*768]
226 | split_prompt_guids = prompt_guids.split(768*2, dim=-1) # 4 x [bsz, 4, 768*2]
227 | split_aux_prompt_guids = [aux_prompt_guid.split(768*2, dim=-1) for aux_prompt_guid in aux_prompt_guids] # 3x [4 x [bsz, 4, 768*2]]
228 |
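        # Same gated fusion as HMNeTREModel.get_visual_prompt: per layer, a learned
        # gate mixes the four pooled scales into one (key, value) prefix of shape
        # (bsz, 12, 16, 64); see the shape notes in the RE model above.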
229 | result = []
230 | for idx in range(12): # 12
231 |             sum_prompt_guids = torch.stack(split_prompt_guids).sum(0).view(bsz, -1) / 4     # bsz, 4*768*2
232 | prompt_gate = F.softmax(F.leaky_relu(self.gates[idx](sum_prompt_guids)), dim=-1)
233 |
234 | key_val = torch.zeros_like(split_prompt_guids[0]).to(self.args.device) # bsz, 4, 768*2
235 | for i in range(4):
236 | key_val = key_val + torch.einsum('bg,blh->blh', prompt_gate[:, i].view(-1, 1), split_prompt_guids[i])
237 |
238 | aux_key_vals = [] # 3 x [bsz, 4, 768*2]
239 | for split_aux_prompt_guid in split_aux_prompt_guids:
240 |                 sum_aux_prompt_guids = torch.stack(split_aux_prompt_guid).sum(0).view(bsz, -1) / 4   # bsz, 4*768*2
241 | aux_prompt_gate = F.softmax(F.leaky_relu(self.gates[idx](sum_aux_prompt_guids)), dim=-1)
242 | aux_key_val = torch.zeros_like(split_aux_prompt_guid[0]).to(self.args.device) # bsz, 4, 768*2
243 | for i in range(4):
244 | aux_key_val = aux_key_val + torch.einsum('bg,blh->blh', aux_prompt_gate[:, i].view(-1, 1), split_aux_prompt_guid[i])
245 | aux_key_vals.append(aux_key_val)
246 | key_val = [key_val] + aux_key_vals
247 | key_val = torch.cat(key_val, dim=1)
248 | key_val = key_val.split(768, dim=-1)
249 |             key, value = key_val[0].reshape(bsz, 12, -1, 64).contiguous(), key_val[1].reshape(bsz, 12, -1, 64).contiguous()  # bsz, 12, 16, 64 (4 full-image + 3 x 4 aux prompt positions)
250 | temp_dict = (key, value)
251 | result.append(temp_dict)
252 | return result
253 |
--------------------------------------------------------------------------------
/models/modeling_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch BERT model. """
17 | import math
18 | import os
19 | import warnings
20 | from dataclasses import dataclass
21 | from typing import Optional, Tuple
22 |
23 | import torch
24 | import torch.utils.checkpoint
25 | from packaging import version
26 | from torch import nn
27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28 |
29 | from transformers.activations import ACT2FN
30 | from transformers.file_utils import (
31 | ModelOutput,
32 | add_code_sample_docstrings,
33 | add_start_docstrings,
34 | add_start_docstrings_to_model_forward,
35 | replace_return_docstrings,
36 | )
37 | from transformers.modeling_outputs import (
38 | BaseModelOutputWithPastAndCrossAttentions,
39 | BaseModelOutputWithPoolingAndCrossAttentions,
40 | CausalLMOutputWithCrossAttentions,
41 | MaskedLMOutput,
42 | MultipleChoiceModelOutput,
43 | NextSentencePredictorOutput,
44 | QuestionAnsweringModelOutput,
45 | SequenceClassifierOutput,
46 | TokenClassifierOutput,
47 | )
48 | from transformers.modeling_utils import (
49 | PreTrainedModel,
50 | apply_chunking_to_forward,
51 | find_pruneable_heads_and_indices,
52 | prune_linear_layer,
53 | )
54 | from transformers.utils import logging
55 | from transformers import BertConfig
56 |
57 |
58 | logger = logging.get_logger(__name__)
59 |
60 | _CHECKPOINT_FOR_DOC = "bert-base-uncased"
61 | _CONFIG_FOR_DOC = "BertConfig"
62 | _TOKENIZER_FOR_DOC = "BertTokenizer"
63 |
64 | BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
65 | "bert-base-uncased",
66 | "bert-large-uncased",
67 | "bert-base-cased",
68 | "bert-large-cased",
69 | "bert-base-multilingual-uncased",
70 | "bert-base-multilingual-cased",
71 | "bert-base-chinese",
72 | "bert-base-german-cased",
73 | "bert-large-uncased-whole-word-masking",
74 | "bert-large-cased-whole-word-masking",
75 | "bert-large-uncased-whole-word-masking-finetuned-squad",
76 | "bert-large-cased-whole-word-masking-finetuned-squad",
77 | "bert-base-cased-finetuned-mrpc",
78 | "bert-base-german-dbmdz-cased",
79 | "bert-base-german-dbmdz-uncased",
80 | "cl-tohoku/bert-base-japanese",
81 | "cl-tohoku/bert-base-japanese-whole-word-masking",
82 | "cl-tohoku/bert-base-japanese-char",
83 | "cl-tohoku/bert-base-japanese-char-whole-word-masking",
84 | "TurkuNLP/bert-base-finnish-cased-v1",
85 | "TurkuNLP/bert-base-finnish-uncased-v1",
86 | "wietsedv/bert-base-dutch-cased",
87 | # See all BERT models at https://huggingface.co/models?filter=bert
88 | ]
89 |
90 |
91 | def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
92 | """Load tf checkpoints in a pytorch model."""
93 | try:
94 | import re
95 |
96 | import numpy as np
97 | import tensorflow as tf
98 | except ImportError:
99 | logger.error(
100 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
101 | "https://www.tensorflow.org/install/ for installation instructions."
102 | )
103 | raise
104 | tf_path = os.path.abspath(tf_checkpoint_path)
105 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
106 | # Load weights from TF model
107 | init_vars = tf.train.list_variables(tf_path)
108 | names = []
109 | arrays = []
110 | for name, shape in init_vars:
111 | logger.info(f"Loading TF weight {name} with shape {shape}")
112 | array = tf.train.load_variable(tf_path, name)
113 | names.append(name)
114 | arrays.append(array)
115 |
116 | for name, array in zip(names, arrays):
117 | name = name.split("/")
118 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
119 | # which are not required for using pretrained model
120 | if any(
121 | n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
122 | for n in name
123 | ):
124 | logger.info(f"Skipping {'/'.join(name)}")
125 | continue
126 | pointer = model
127 | for m_name in name:
128 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
129 | scope_names = re.split(r"_(\d+)", m_name)
130 | else:
131 | scope_names = [m_name]
132 | if scope_names[0] == "kernel" or scope_names[0] == "gamma":
133 | pointer = getattr(pointer, "weight")
134 | elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
135 | pointer = getattr(pointer, "bias")
136 | elif scope_names[0] == "output_weights":
137 | pointer = getattr(pointer, "weight")
138 | elif scope_names[0] == "squad":
139 | pointer = getattr(pointer, "classifier")
140 | else:
141 | try:
142 | pointer = getattr(pointer, scope_names[0])
143 | except AttributeError:
144 | logger.info(f"Skipping {'/'.join(name)}")
145 | continue
146 | if len(scope_names) >= 2:
147 | num = int(scope_names[1])
148 | pointer = pointer[num]
149 | if m_name[-11:] == "_embeddings":
150 | pointer = getattr(pointer, "weight")
151 | elif m_name == "kernel":
152 | array = np.transpose(array)
153 | try:
154 | assert (
155 | pointer.shape == array.shape
156 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
157 | except AssertionError as e:
158 | e.args += (pointer.shape, array.shape)
159 | raise
160 | logger.info(f"Initialize PyTorch weight {name}")
161 | pointer.data = torch.from_numpy(array)
162 | return model
163 |
164 |
165 | class BertEmbeddings(nn.Module):
166 | """Construct the embeddings from word, position and token_type embeddings."""
167 |
168 | def __init__(self, config):
169 | super().__init__()
170 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
171 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
172 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
173 |
174 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
175 | # any TensorFlow checkpoint file
176 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
177 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
178 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized
179 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
180 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
181 | if version.parse(torch.__version__) > version.parse("1.6.0"):
182 | self.register_buffer(
183 | "token_type_ids",
184 | torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
185 | persistent=False,
186 | )
187 |
188 | def forward(
189 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
190 | ):
191 | if input_ids is not None:
192 | input_shape = input_ids.size()
193 | else:
194 | input_shape = inputs_embeds.size()[:-1]
195 |
196 | seq_length = input_shape[1]
197 |
198 | if position_ids is None:
199 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
200 |
201 | # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
202 | # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
203 | # issue #5664
204 | if token_type_ids is None:
205 | if hasattr(self, "token_type_ids"):
206 | buffered_token_type_ids = self.token_type_ids[:, :seq_length]
207 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
208 | token_type_ids = buffered_token_type_ids_expanded
209 | else:
210 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
211 |
212 | if inputs_embeds is None:
213 | inputs_embeds = self.word_embeddings(input_ids)
214 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
215 |
216 | embeddings = inputs_embeds + token_type_embeddings
217 | if self.position_embedding_type == "absolute":
218 | position_embeddings = self.position_embeddings(position_ids)
219 | embeddings += position_embeddings
220 | embeddings = self.LayerNorm(embeddings)
221 | embeddings = self.dropout(embeddings)
222 | return embeddings
223 |
224 |
225 | class BertSelfAttention(nn.Module):
226 | def __init__(self, config):
227 | super().__init__()
228 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
229 | raise ValueError(
230 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
231 | f"heads ({config.num_attention_heads})"
232 | )
233 |
234 | self.num_attention_heads = config.num_attention_heads # 12
235 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64
236 | self.all_head_size = self.num_attention_heads * self.attention_head_size # 768
237 |
238 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
239 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
240 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
241 |
242 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
243 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
244 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
245 | self.max_position_embeddings = config.max_position_embeddings
246 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
247 |
248 | self.is_decoder = config.is_decoder
249 |
250 | def transpose_for_scores(self, x):
251 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
252 | x = x.view(*new_x_shape)
253 | return x.permute(0, 2, 1, 3)
254 |
255 | def forward(
256 | self,
257 | hidden_states,
258 | attention_mask=None,
259 | head_mask=None,
260 | encoder_hidden_states=None,
261 | encoder_attention_mask=None,
262 | past_key_value=None,
263 | output_attentions=False,
264 | ### add:
265 | current_layer=0
266 | ):
267 | mixed_query_layer = self.query(hidden_states)
268 |
269 | # If this is instantiated as a cross-attention module, the keys
270 | # and values come from an encoder; the attention mask needs to be
271 | # such that the encoder's padding tokens are not attended to.
272 | is_cross_attention = encoder_hidden_states is not None
273 |
274 | if is_cross_attention and past_key_value is not None:
275 | # reuse k,v, cross_attentions
276 | key_layer = past_key_value[0]
277 | value_layer = past_key_value[1]
278 | attention_mask = encoder_attention_mask
279 | elif is_cross_attention:
280 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
281 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
282 | attention_mask = encoder_attention_mask
283 |         elif past_key_value is not None:  # taken when a visual prefix is passed via past_key_values
284 | key_layer = self.transpose_for_scores(self.key(hidden_states))
285 | value_layer = self.transpose_for_scores(self.value(hidden_states))
286 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
287 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
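            # With a visual prefix of length P, key/value become (bsz, heads, P + seq_len,
            # head_dim) while the query stays (bsz, heads, seq_len, head_dim); the extended
            # prompt_attention_mask built in bert_model.py covers the extra P positions.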
288 | else:
289 | key_layer = self.transpose_for_scores(self.key(hidden_states))
290 | value_layer = self.transpose_for_scores(self.value(hidden_states))
291 |
292 | query_layer = self.transpose_for_scores(mixed_query_layer)
293 | if self.is_decoder:
294 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
295 | # Further calls to cross_attention layer can then reuse all cross-attention
296 | # key/value_states (first "if" case)
297 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
298 | # all previous decoder key/value_states. Further calls to uni-directional self-attention
299 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
300 | # if encoder bi-directional self-attention `past_key_value` is always `None`
301 | past_key_value = (key_layer, value_layer)
302 |
303 |         # Take the dot product between "query" and "key" to get the raw attention scores.
304 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
305 |
306 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
307 | seq_length = hidden_states.size()[1]
308 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
309 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
310 | distance = position_ids_l - position_ids_r
311 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
312 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
313 |
314 | if self.position_embedding_type == "relative_key":
315 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
316 | attention_scores = attention_scores + relative_position_scores
317 | elif self.position_embedding_type == "relative_key_query":
318 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
319 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
320 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
321 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
322 | if attention_mask is not None:
323 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
324 | attention_scores = attention_scores + attention_mask
325 | # Normalize the attention scores to probabilities.
326 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
327 | # This is actually dropping out entire tokens to attend to, which might
328 | # seem a bit unusual, but is taken from the original Transformer paper.
329 | attention_probs = self.dropout(attention_probs)
330 |
331 | # Mask heads if we want to
332 | if head_mask is not None:
333 | attention_probs = attention_probs * head_mask
334 | context_layer = torch.matmul(attention_probs, value_layer)
335 |
336 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
337 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
338 | context_layer = context_layer.view(*new_context_layer_shape)
339 |
340 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
341 | if self.is_decoder:
342 | outputs = outputs + (past_key_value,)
343 | return outputs
344 |
345 |
346 | class BertSelfOutput(nn.Module):
347 | def __init__(self, config):
348 | super().__init__()
349 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
350 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
351 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
352 |
353 | def forward(self, hidden_states, input_tensor):
354 | hidden_states = self.dense(hidden_states)
355 | hidden_states = self.dropout(hidden_states)
356 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
357 | return hidden_states
358 |
359 |
360 | class BertAttention(nn.Module):
361 | def __init__(self, config):
362 | super().__init__()
363 | self.self = BertSelfAttention(config)
364 | self.output = BertSelfOutput(config)
365 | self.pruned_heads = set()
366 |
367 | def prune_heads(self, heads):
368 | if len(heads) == 0:
369 | return
370 | heads, index = find_pruneable_heads_and_indices(
371 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
372 | )
373 |
374 | # Prune linear layers
375 | self.self.query = prune_linear_layer(self.self.query, index)
376 | self.self.key = prune_linear_layer(self.self.key, index)
377 | self.self.value = prune_linear_layer(self.self.value, index)
378 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
379 |
380 | # Update hyper params and store pruned heads
381 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
382 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
383 | self.pruned_heads = self.pruned_heads.union(heads)
384 |
385 | def forward(
386 | self,
387 | hidden_states,
388 | attention_mask=None,
389 | head_mask=None,
390 | encoder_hidden_states=None,
391 | encoder_attention_mask=None,
392 | past_key_value=None,
393 | output_attentions=False,
394 | # add:
395 | current_layer=0
396 | ):
397 | self_outputs = self.self(
398 | hidden_states,
399 | attention_mask,
400 | head_mask,
401 | encoder_hidden_states,
402 | encoder_attention_mask,
403 | past_key_value,
404 | output_attentions,
405 | ### add:
406 | current_layer=current_layer
407 | )
408 | attention_output = self.output(self_outputs[0], hidden_states)
409 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
410 | return outputs
411 |
412 |
413 | class BertIntermediate(nn.Module):
414 | def __init__(self, config):
415 | super().__init__()
416 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
417 | if isinstance(config.hidden_act, str):
418 | self.intermediate_act_fn = ACT2FN[config.hidden_act]
419 | else:
420 | self.intermediate_act_fn = config.hidden_act
421 |
422 | def forward(self, hidden_states):
423 | hidden_states = self.dense(hidden_states)
424 | hidden_states = self.intermediate_act_fn(hidden_states)
425 | return hidden_states
426 |
427 |
428 | class BertOutput(nn.Module): # FFN
429 | def __init__(self, config):
430 | super().__init__()
431 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
432 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
433 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
434 |
435 | def forward(self, hidden_states, input_tensor):
436 | hidden_states = self.dense(hidden_states)
437 | hidden_states = self.dropout(hidden_states)
438 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
439 | return hidden_states
440 |
441 |
442 | class BertLayer(nn.Module):
443 | def __init__(self, config):
444 | super().__init__()
445 | self.chunk_size_feed_forward = config.chunk_size_feed_forward
446 | self.seq_len_dim = 1
447 | self.attention = BertAttention(config)
448 | self.is_decoder = config.is_decoder
449 | self.add_cross_attention = config.add_cross_attention
450 | if self.add_cross_attention:
451 | assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
452 | self.crossattention = BertAttention(config)
453 | self.intermediate = BertIntermediate(config)
454 | self.output = BertOutput(config)
455 |
456 | def forward(
457 | self,
458 | hidden_states,
459 | attention_mask=None,
460 | head_mask=None,
461 | encoder_hidden_states=None,
462 | encoder_attention_mask=None,
463 | past_key_value=None,
464 | output_attentions=False,
465 | ### add:
466 | current_layer=0
467 | ):
468 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
469 | # self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
470 | self_attention_outputs = self.attention(
471 | hidden_states,
472 | attention_mask,
473 | head_mask,
474 | output_attentions=output_attentions,
475 | # past_key_value=self_attn_past_key_value,
476 | past_key_value=past_key_value,
477 | ### add:
478 | current_layer=current_layer
479 | )
480 | attention_output = self_attention_outputs[0]
481 |
482 | # if decoder, the last output is tuple of self-attn cache
483 | if self.is_decoder:
484 | outputs = self_attention_outputs[1:-1]
485 | present_key_value = self_attention_outputs[-1]
486 | else:
487 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
488 |
489 | cross_attn_present_key_value = None
490 | if self.is_decoder and encoder_hidden_states is not None:
491 | assert hasattr(
492 | self, "crossattention"
493 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
494 |
495 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
496 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
497 | cross_attention_outputs = self.crossattention(
498 | attention_output,
499 | attention_mask,
500 | head_mask,
501 | encoder_hidden_states,
502 | encoder_attention_mask,
503 | cross_attn_past_key_value,
504 | output_attentions,
505 | )
506 | attention_output = cross_attention_outputs[0]
507 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
508 |
509 | # add cross-attn cache to positions 3,4 of present_key_value tuple
510 | cross_attn_present_key_value = cross_attention_outputs[-1]
511 | present_key_value = present_key_value + cross_attn_present_key_value
512 |
513 | layer_output = apply_chunking_to_forward(
514 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
515 | )
516 | outputs = (layer_output,) + outputs
517 |
518 | # if decoder, return the attn key/values as the last output
519 | if self.is_decoder:
520 | outputs = outputs + (present_key_value,)
521 |
522 | return outputs
523 |
524 | def feed_forward_chunk(self, attention_output):
525 | intermediate_output = self.intermediate(attention_output)
526 | layer_output = self.output(intermediate_output, attention_output)
527 | return layer_output
528 |
529 |
530 | class BertEncoder(nn.Module):
531 | def __init__(self, config):
532 | super().__init__()
533 | self.config = config
534 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
535 | self.gradient_checkpointing = False
536 |
537 | def forward(
538 | self,
539 | hidden_states,
540 | attention_mask=None,
541 | head_mask=None,
542 | encoder_hidden_states=None,
543 | encoder_attention_mask=None,
544 | past_key_values=None,
545 | use_cache=None,
546 | output_attentions=False,
547 | output_hidden_states=False,
548 | return_dict=True,
549 | ):
550 | all_hidden_states = () if output_hidden_states else None
551 | all_self_attentions = () if output_attentions else None
552 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
553 |
554 | next_decoder_cache = () if use_cache else None
555 | for i, layer_module in enumerate(self.layer):
556 | if output_hidden_states:
557 | all_hidden_states = all_hidden_states + (hidden_states,)
558 |
559 | layer_head_mask = head_mask[i] if head_mask is not None else None
560 | past_key_value = past_key_values[i] if past_key_values is not None else None
561 |
562 | if self.gradient_checkpointing and self.training:
563 |
564 | if use_cache:
565 | logger.warning(
566 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
567 | )
568 | use_cache = False
569 |
570 | def create_custom_forward(module):
571 | def custom_forward(*inputs):
572 | return module(*inputs, past_key_value, output_attentions)
573 |
574 | return custom_forward
575 |
576 | layer_outputs = torch.utils.checkpoint.checkpoint(
577 | create_custom_forward(layer_module),
578 | hidden_states,
579 | attention_mask,
580 | layer_head_mask,
581 | encoder_hidden_states,
582 | encoder_attention_mask,
583 | )
584 | else:
585 | layer_outputs = layer_module(
586 | hidden_states,
587 | attention_mask,
588 | layer_head_mask,
589 | encoder_hidden_states,
590 | encoder_attention_mask,
591 | past_key_value,
592 | output_attentions,
593 | ### add:
594 | current_layer=i
595 | )
596 |
597 | hidden_states = layer_outputs[0]
598 | if use_cache:
599 | next_decoder_cache += (layer_outputs[-1],)
600 | if output_attentions:
601 | all_self_attentions = all_self_attentions + (layer_outputs[1],)
602 | if self.config.add_cross_attention:
603 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
604 |
605 | if output_hidden_states:
606 | all_hidden_states = all_hidden_states + (hidden_states,)
607 |
608 | if not return_dict:
609 | return tuple(
610 | v
611 | for v in [
612 | hidden_states,
613 | next_decoder_cache,
614 | all_hidden_states,
615 | all_self_attentions,
616 | all_cross_attentions,
617 | ]
618 | if v is not None
619 | )
620 | return BaseModelOutputWithPastAndCrossAttentions(
621 | last_hidden_state=hidden_states,
622 | past_key_values=next_decoder_cache,
623 | hidden_states=all_hidden_states,
624 | attentions=all_self_attentions,
625 | cross_attentions=all_cross_attentions,
626 | )
627 |
628 |
629 | class BertPooler(nn.Module):
630 | def __init__(self, config):
631 | super().__init__()
632 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
633 | self.activation = nn.Tanh()
634 |
635 | def forward(self, hidden_states):
636 | # We "pool" the model by simply taking the hidden state corresponding
637 | # to the first token.
638 | first_token_tensor = hidden_states[:, 0]
639 | pooled_output = self.dense(first_token_tensor)
640 | pooled_output = self.activation(pooled_output)
641 | return pooled_output
642 |
643 |
644 | class BertPredictionHeadTransform(nn.Module):
645 | def __init__(self, config):
646 | super().__init__()
647 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
648 | if isinstance(config.hidden_act, str):
649 | self.transform_act_fn = ACT2FN[config.hidden_act]
650 | else:
651 | self.transform_act_fn = config.hidden_act
652 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
653 |
654 | def forward(self, hidden_states):
655 | hidden_states = self.dense(hidden_states)
656 | hidden_states = self.transform_act_fn(hidden_states)
657 | hidden_states = self.LayerNorm(hidden_states)
658 | return hidden_states
659 |
660 |
661 | class BertLMPredictionHead(nn.Module):
662 | def __init__(self, config):
663 | super().__init__()
664 | self.transform = BertPredictionHeadTransform(config)
665 |
666 | # The output weights are the same as the input embeddings, but there is
667 | # an output-only bias for each token.
668 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
669 |
670 | self.bias = nn.Parameter(torch.zeros(config.vocab_size))
671 |
672 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
673 | self.decoder.bias = self.bias
674 |
675 | def forward(self, hidden_states):
676 | hidden_states = self.transform(hidden_states)
677 | hidden_states = self.decoder(hidden_states)
678 | return hidden_states
679 |
680 |
681 | class BertOnlyMLMHead(nn.Module):
682 | def __init__(self, config):
683 | super().__init__()
684 | self.predictions = BertLMPredictionHead(config)
685 |
686 | def forward(self, sequence_output):
687 | prediction_scores = self.predictions(sequence_output)
688 | return prediction_scores
689 |
690 |
691 | class BertOnlyNSPHead(nn.Module):
692 | def __init__(self, config):
693 | super().__init__()
694 | self.seq_relationship = nn.Linear(config.hidden_size, 2)
695 |
696 | def forward(self, pooled_output):
697 | seq_relationship_score = self.seq_relationship(pooled_output)
698 | return seq_relationship_score
699 |
700 |
701 | class BertPreTrainingHeads(nn.Module):
702 | def __init__(self, config):
703 | super().__init__()
704 | self.predictions = BertLMPredictionHead(config)
705 | self.seq_relationship = nn.Linear(config.hidden_size, 2)
706 |
707 | def forward(self, sequence_output, pooled_output):
708 | prediction_scores = self.predictions(sequence_output)
709 | seq_relationship_score = self.seq_relationship(pooled_output)
710 | return prediction_scores, seq_relationship_score
711 |
712 |
713 | class BertPreTrainedModel(PreTrainedModel):
714 | """
715 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
716 | models.
717 | """
718 |
719 | config_class = BertConfig
720 | load_tf_weights = load_tf_weights_in_bert
721 | base_model_prefix = "bert"
722 | supports_gradient_checkpointing = True
723 | _keys_to_ignore_on_load_missing = [r"position_ids"]
724 |
725 | def _init_weights(self, module):
726 | """Initialize the weights"""
727 | if isinstance(module, nn.Linear):
728 | # Slightly different from the TF version which uses truncated_normal for initialization
729 | # cf https://github.com/pytorch/pytorch/pull/5617
730 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
731 | if module.bias is not None:
732 | module.bias.data.zero_()
733 | elif isinstance(module, nn.Embedding):
734 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
735 | if module.padding_idx is not None:
736 | module.weight.data[module.padding_idx].zero_()
737 | elif isinstance(module, nn.LayerNorm):
738 | module.bias.data.zero_()
739 | module.weight.data.fill_(1.0)
740 |
741 | def _set_gradient_checkpointing(self, module, value=False):
742 | if isinstance(module, BertEncoder):
743 | module.gradient_checkpointing = value
744 |
745 |
746 | @dataclass
747 | class BertForPreTrainingOutput(ModelOutput):
748 | """
749 | Output type of :class:`~transformers.BertForPreTraining`.
750 |
751 | Args:
752 | loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
753 | Total loss as the sum of the masked language modeling loss and the next sequence prediction
754 | (classification) loss.
755 | prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
756 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
757 | seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
758 | Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
759 | before SoftMax).
760 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
761 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
762 | of shape :obj:`(batch_size, sequence_length, hidden_size)`.
763 |
764 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
765 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
766 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
767 | sequence_length, sequence_length)`.
768 |
769 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
770 | heads.
771 | """
772 |
773 | loss: Optional[torch.FloatTensor] = None
774 | prediction_logits: torch.FloatTensor = None
775 | seq_relationship_logits: torch.FloatTensor = None
776 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
777 | attentions: Optional[Tuple[torch.FloatTensor]] = None
778 |
779 |
780 | BERT_START_DOCSTRING = r"""
781 |
782 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
783 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
784 | pruning heads etc.)
785 |
786 |     This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
787 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
788 | general usage and behavior.
789 |
790 | Parameters:
791 | config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
792 | Initializing with a config file does not load the weights associated with the model, only the
793 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
794 | weights.
795 | """
796 |
797 | BERT_INPUTS_DOCSTRING = r"""
798 | Args:
799 | input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
800 | Indices of input sequence tokens in the vocabulary.
801 |
802 | Indices can be obtained using :class:`~transformers.BertTokenizer`. See
803 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
804 | details.
805 |
806 | `What are input IDs? <../glossary.html#input-ids>`__
807 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
808 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
809 |
810 | - 1 for tokens that are **not masked**,
811 | - 0 for tokens that are **masked**.
812 |
813 | `What are attention masks? <../glossary.html#attention-mask>`__
814 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
815 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
816 | 1]``:
817 |
818 | - 0 corresponds to a `sentence A` token,
819 | - 1 corresponds to a `sentence B` token.
820 |
821 | `What are token type IDs? <../glossary.html#token-type-ids>`_
822 | position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
823 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
824 | config.max_position_embeddings - 1]``.
825 |
826 | `What are position IDs? <../glossary.html#position-ids>`_
827 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
828 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
829 |
830 | - 1 indicates the head is **not masked**,
831 | - 0 indicates the head is **masked**.
832 |
833 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
834 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
835 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
836 | vectors than the model's internal embedding lookup matrix.
837 | output_attentions (:obj:`bool`, `optional`):
838 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
839 | tensors for more detail.
840 | output_hidden_states (:obj:`bool`, `optional`):
841 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
842 | more detail.
843 | return_dict (:obj:`bool`, `optional`):
844 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
845 | """
846 |
847 |
848 | @add_start_docstrings(
849 | "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
850 | BERT_START_DOCSTRING,
851 | )
852 | class BertModel(BertPreTrainedModel):
853 | """
854 |
855 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
856 | cross-attention is added between the self-attention layers, following the architecture described in `Attention is
857 |     all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
858 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
859 |
860 |     To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
861 |     set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both :obj:`is_decoder`
862 | argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
863 | input to the forward pass.
864 | """
865 |
866 | def __init__(self, config, add_pooling_layer=True):
867 | super().__init__(config)
868 | self.config = config
869 |
870 | self.embeddings = BertEmbeddings(config)
871 | self.encoder = BertEncoder(config)
872 |
873 | self.pooler = BertPooler(config) if add_pooling_layer else None
874 |
875 | self.init_weights()
876 |
877 | def get_input_embeddings(self):
878 | return self.embeddings.word_embeddings
879 |
880 | def set_input_embeddings(self, value):
881 | self.embeddings.word_embeddings = value
882 |
883 | def _prune_heads(self, heads_to_prune):
884 | """
885 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
886 | class PreTrainedModel
887 | """
888 | for layer, heads in heads_to_prune.items():
889 | self.encoder.layer[layer].attention.prune_heads(heads)
890 |
891 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
892 | @add_code_sample_docstrings(
893 | tokenizer_class=_TOKENIZER_FOR_DOC,
894 | checkpoint=_CHECKPOINT_FOR_DOC,
895 | output_type=BaseModelOutputWithPoolingAndCrossAttentions,
896 | config_class=_CONFIG_FOR_DOC,
897 | )
898 | def forward(
899 | self,
900 | input_ids=None,
901 | attention_mask=None,
902 | token_type_ids=None,
903 | position_ids=None,
904 | head_mask=None,
905 | inputs_embeds=None,
906 | encoder_hidden_states=None,
907 | encoder_attention_mask=None,
908 | past_key_values=None,
909 | use_cache=None,
910 | output_attentions=None,
911 | output_hidden_states=None,
912 | return_dict=None,
913 | ):
914 | r"""
915 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
916 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
917 | the model is configured as a decoder.
918 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
919 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
920 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
921 |
922 | - 1 for tokens that are **not masked**,
923 | - 0 for tokens that are **masked**.
924 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
925 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
926 |
927 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
928 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
929 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
930 | use_cache (:obj:`bool`, `optional`):
931 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
932 | decoding (see :obj:`past_key_values`).
933 | """
934 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
935 | output_hidden_states = (
936 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
937 | )
938 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
939 |
940 | if self.config.is_decoder:
941 | use_cache = use_cache if use_cache is not None else self.config.use_cache
942 | else:
943 | use_cache = False
944 |
945 | if input_ids is not None and inputs_embeds is not None:
946 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
947 | elif input_ids is not None:
948 | input_shape = input_ids.size()
949 | elif inputs_embeds is not None:
950 | input_shape = inputs_embeds.size()[:-1]
951 | else:
952 | raise ValueError("You have to specify either input_ids or inputs_embeds")
953 |
954 | batch_size, seq_length = input_shape
955 | device = input_ids.device if input_ids is not None else inputs_embeds.device
956 |
957 |         # past_key_values_length
958 |         # original: past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
959 |         # modified: `past_key_values` here carries the visual prefix, so this stays 0
960 |         past_key_values_length = 0 # and text position ids start at 0, unshifted by the prefix
961 |         ##
962 |
963 |
964 | if attention_mask is None:
965 |             attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
966 | if token_type_ids is None:
967 | if hasattr(self.embeddings, "token_type_ids"):
968 | buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
969 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
970 | token_type_ids = buffered_token_type_ids_expanded
971 | else:
972 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
973 |
974 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
975 | # ourselves in which case we just need to make it broadcastable to all heads.
976 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
977 | # If a 2D or 3D attention mask is provided for the cross-attention
978 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
979 | if self.config.is_decoder and encoder_hidden_states is not None:
980 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
981 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
982 | if encoder_attention_mask is None:
983 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
984 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
985 | else:
986 | encoder_extended_attention_mask = None
987 |
988 | # Prepare head mask if needed
989 | # 1.0 in head_mask indicate we keep the head
990 | # attention_probs has shape bsz x n_heads x N x N
991 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
992 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
993 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) # [None]*12
994 |
995 | embedding_output = self.embeddings(
996 | input_ids=input_ids,
997 | position_ids=position_ids,
998 | token_type_ids=token_type_ids,
999 | inputs_embeds=inputs_embeds,
1000 | past_key_values_length=past_key_values_length,
1001 | )
1002 | encoder_outputs = self.encoder(
1003 | embedding_output,
1004 | attention_mask=extended_attention_mask,
1005 | head_mask=head_mask,
1006 | encoder_hidden_states=encoder_hidden_states,
1007 | encoder_attention_mask=encoder_extended_attention_mask,
1008 | past_key_values=past_key_values,
1009 | use_cache=use_cache,
1010 | output_attentions=output_attentions,
1011 | output_hidden_states=output_hidden_states,
1012 | return_dict=return_dict,
1013 | )
1014 | sequence_output = encoder_outputs[0]
1015 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1016 |
1017 | if not return_dict:
1018 | return (sequence_output, pooled_output) + encoder_outputs[1:]
1019 |
1020 | return BaseModelOutputWithPoolingAndCrossAttentions(
1021 | last_hidden_state=sequence_output,
1022 | pooler_output=pooled_output,
1023 | past_key_values=encoder_outputs.past_key_values,
1024 | hidden_states=encoder_outputs.hidden_states,
1025 | attentions=encoder_outputs.attentions,
1026 | cross_attentions=encoder_outputs.cross_attentions,
1027 | )
1028 |
1029 |
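# --- illustrative usage sketch (added for clarity; not part of the upstream file) ---
# How the BertModel above is driven as an encoder and as a decoder. The
# checkpoint name and the BertConfig/BertTokenizer imports come from the
# `transformers` package and are assumptions of this sketch.
#
# >>> from transformers import BertConfig, BertTokenizer
# >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# >>> inputs = tokenizer("a short example", return_tensors="pt")
#
# >>> # encoder (default): self-attention only
# >>> encoder = BertModel.from_pretrained("bert-base-uncased")
# >>> encoder(**inputs).last_hidden_state.shape   # (1, seq_len, hidden_size)
#
# >>> # decoder: is_decoder adds causal masking; add_cross_attention lets the
# >>> # forward pass consume encoder_hidden_states from another encoder
# >>> cfg = BertConfig.from_pretrained(
# ...     "bert-base-uncased", is_decoder=True, add_cross_attention=True)
# >>> decoder = BertModel.from_pretrained("bert-base-uncased", config=cfg)
# -------------------------------------------------------------------------------------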
1030 | @add_start_docstrings(
1031 | """
1032 | Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1033 | sentence prediction (classification)` head.
1034 | """,
1035 | BERT_START_DOCSTRING,
1036 | )
1037 | class BertForPreTraining(BertPreTrainedModel):
1038 | def __init__(self, config):
1039 | super().__init__(config)
1040 |
1041 | self.bert = BertModel(config)
1042 | self.cls = BertPreTrainingHeads(config)
1043 |
1044 | self.init_weights()
1045 |
1046 | def get_output_embeddings(self):
1047 | return self.cls.predictions.decoder
1048 |
1049 | def set_output_embeddings(self, new_embeddings):
1050 | self.cls.predictions.decoder = new_embeddings
1051 |
1052 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1053 | @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1054 | def forward(
1055 | self,
1056 | input_ids=None,
1057 | attention_mask=None,
1058 | token_type_ids=None,
1059 | position_ids=None,
1060 | head_mask=None,
1061 | inputs_embeds=None,
1062 | labels=None,
1063 | next_sentence_label=None,
1064 | output_attentions=None,
1065 | output_hidden_states=None,
1066 | return_dict=None,
1067 | ):
1068 | r"""
1069 | labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
1070 | Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1071 | config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1072 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1073 | next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
1074 | Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1075 | (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
1076 |
1077 | - 0 indicates sequence B is a continuation of sequence A,
1078 | - 1 indicates sequence B is a random sequence.
1079 | kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1080 | Used to hide legacy arguments that have been deprecated.
1081 |
1082 | Returns:
1083 |
1084 | Example::
1085 |
1086 | >>> from transformers import BertTokenizer, BertForPreTraining
1087 | >>> import torch
1088 |
1089 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1090 | >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
1091 |
1092 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1093 | >>> outputs = model(**inputs)
1094 |
1095 | >>> prediction_logits = outputs.prediction_logits
1096 | >>> seq_relationship_logits = outputs.seq_relationship_logits
1097 | """
1098 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1099 |
1100 | outputs = self.bert(
1101 | input_ids,
1102 | attention_mask=attention_mask,
1103 | token_type_ids=token_type_ids,
1104 | position_ids=position_ids,
1105 | head_mask=head_mask,
1106 | inputs_embeds=inputs_embeds,
1107 | output_attentions=output_attentions,
1108 | output_hidden_states=output_hidden_states,
1109 | return_dict=return_dict,
1110 | )
1111 |
1112 | sequence_output, pooled_output = outputs[:2]
1113 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1114 |
1115 | total_loss = None
1116 | if labels is not None and next_sentence_label is not None:
1117 | loss_fct = CrossEntropyLoss()
1118 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1119 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1120 | total_loss = masked_lm_loss + next_sentence_loss
1121 |
1122 | if not return_dict:
1123 | output = (prediction_scores, seq_relationship_score) + outputs[2:]
1124 | return ((total_loss,) + output) if total_loss is not None else output
1125 |
1126 | return BertForPreTrainingOutput(
1127 | loss=total_loss,
1128 | prediction_logits=prediction_scores,
1129 | seq_relationship_logits=seq_relationship_score,
1130 | hidden_states=outputs.hidden_states,
1131 | attentions=outputs.attentions,
1132 | )
1133 |
1134 |
1135 | @add_start_docstrings(
1136 | """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
1137 | )
1138 | class BertLMHeadModel(BertPreTrainedModel):
1139 |
1140 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
1141 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1142 |
1143 | def __init__(self, config):
1144 | super().__init__(config)
1145 |
1146 | if not config.is_decoder:
1147 |             logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`.")
1148 |
1149 | self.bert = BertModel(config, add_pooling_layer=False)
1150 | self.cls = BertOnlyMLMHead(config)
1151 |
1152 | self.init_weights()
1153 |
1154 | def get_output_embeddings(self):
1155 | return self.cls.predictions.decoder
1156 |
1157 | def set_output_embeddings(self, new_embeddings):
1158 | self.cls.predictions.decoder = new_embeddings
1159 |
1160 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1161 | @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1162 | def forward(
1163 | self,
1164 | input_ids=None,
1165 | attention_mask=None,
1166 | token_type_ids=None,
1167 | position_ids=None,
1168 | head_mask=None,
1169 | inputs_embeds=None,
1170 | encoder_hidden_states=None,
1171 | encoder_attention_mask=None,
1172 | labels=None,
1173 | past_key_values=None,
1174 | use_cache=None,
1175 | output_attentions=None,
1176 | output_hidden_states=None,
1177 | return_dict=None,
1178 | ):
1179 | r"""
1180 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1181 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1182 | the model is configured as a decoder.
1183 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1184 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1185 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1186 |
1187 | - 1 for tokens that are **not masked**,
1188 | - 0 for tokens that are **masked**.
1189 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1190 | Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1191 | ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1192 |             ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1193 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1194 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1195 |
1196 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1197 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1198 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1199 | use_cache (:obj:`bool`, `optional`):
1200 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1201 | decoding (see :obj:`past_key_values`).
1202 |
1203 | Returns:
1204 |
1205 | Example::
1206 |
1207 | >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1208 | >>> import torch
1209 |
1210 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1211 | >>> config = BertConfig.from_pretrained("bert-base-cased")
1212 | >>> config.is_decoder = True
1213 | >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1214 |
1215 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1216 | >>> outputs = model(**inputs)
1217 |
1218 | >>> prediction_logits = outputs.logits
1219 | """
1220 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1221 | if labels is not None:
1222 | use_cache = False
1223 |
1224 | outputs = self.bert(
1225 | input_ids,
1226 | attention_mask=attention_mask,
1227 | token_type_ids=token_type_ids,
1228 | position_ids=position_ids,
1229 | head_mask=head_mask,
1230 | inputs_embeds=inputs_embeds,
1231 | encoder_hidden_states=encoder_hidden_states,
1232 | encoder_attention_mask=encoder_attention_mask,
1233 | past_key_values=past_key_values,
1234 | use_cache=use_cache,
1235 | output_attentions=output_attentions,
1236 | output_hidden_states=output_hidden_states,
1237 | return_dict=return_dict,
1238 | )
1239 |
1240 | sequence_output = outputs[0]
1241 | prediction_scores = self.cls(sequence_output)
1242 |
1243 | lm_loss = None
1244 | if labels is not None:
1245 | # we are doing next-token prediction; shift prediction scores and input ids by one
1246 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1247 | labels = labels[:, 1:].contiguous()
1248 | loss_fct = CrossEntropyLoss()
1249 | lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1250 |
1251 | if not return_dict:
1252 | output = (prediction_scores,) + outputs[2:]
1253 | return ((lm_loss,) + output) if lm_loss is not None else output
1254 |
1255 | return CausalLMOutputWithCrossAttentions(
1256 | loss=lm_loss,
1257 | logits=prediction_scores,
1258 | past_key_values=outputs.past_key_values,
1259 | hidden_states=outputs.hidden_states,
1260 | attentions=outputs.attentions,
1261 | cross_attentions=outputs.cross_attentions,
1262 | )
1263 |
1264 | def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1265 | input_shape = input_ids.shape
1266 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1267 | if attention_mask is None:
1268 | attention_mask = input_ids.new_ones(input_shape)
1269 |
1270 | # cut decoder_input_ids if past is used
1271 | if past is not None:
1272 | input_ids = input_ids[:, -1:]
1273 |
1274 | return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
1275 |
1276 | def _reorder_cache(self, past, beam_idx):
1277 | reordered_past = ()
1278 | for layer_past in past:
1279 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1280 | return reordered_past
1281 |
1282 |
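# --- illustrative note (added for clarity; not part of the upstream file) -----------
# How the two generation hooks above cooperate, sketched with placeholder names:
# `model` is a BertLMHeadModel and `cache` a past_key_values tuple. Once a cache
# exists, only the newest token id is fed back in, and beam search reorders the
# cached states along the batch axis with `beam_idx`.
#
# >>> ids = torch.tensor([[101, 2023, 2003]])
# >>> model.prepare_inputs_for_generation(ids, past=cache)["input_ids"]
# tensor([[2003]])
# -------------------------------------------------------------------------------------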
1283 | @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
1284 | class BertForMaskedLM(BertPreTrainedModel):
1285 |
1286 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
1287 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1288 |
1289 | def __init__(self, config):
1290 | super().__init__(config)
1291 |
1292 | if config.is_decoder:
1293 | logger.warning(
1294 | "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
1295 | "bi-directional self-attention."
1296 | )
1297 |
1298 | self.bert = BertModel(config, add_pooling_layer=False)
1299 | self.cls = BertOnlyMLMHead(config)
1300 |
1301 | self.init_weights()
1302 |
1303 | def get_output_embeddings(self):
1304 | return self.cls.predictions.decoder
1305 |
1306 | def set_output_embeddings(self, new_embeddings):
1307 | self.cls.predictions.decoder = new_embeddings
1308 |
1309 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1310 | @add_code_sample_docstrings(
1311 | tokenizer_class=_TOKENIZER_FOR_DOC,
1312 | checkpoint=_CHECKPOINT_FOR_DOC,
1313 | output_type=MaskedLMOutput,
1314 | config_class=_CONFIG_FOR_DOC,
1315 | )
1316 | def forward(
1317 | self,
1318 | input_ids=None,
1319 | attention_mask=None,
1320 | token_type_ids=None,
1321 | position_ids=None,
1322 | head_mask=None,
1323 | inputs_embeds=None,
1324 | encoder_hidden_states=None,
1325 | encoder_attention_mask=None,
1326 | labels=None,
1327 | output_attentions=None,
1328 | output_hidden_states=None,
1329 | return_dict=None,
1330 | ):
1331 | r"""
1332 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1333 | Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1334 | config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1335 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1336 | """
1337 |
1338 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1339 |
1340 | outputs = self.bert(
1341 | input_ids,
1342 | attention_mask=attention_mask,
1343 | token_type_ids=token_type_ids,
1344 | position_ids=position_ids,
1345 | head_mask=head_mask,
1346 | inputs_embeds=inputs_embeds,
1347 | encoder_hidden_states=encoder_hidden_states,
1348 | encoder_attention_mask=encoder_attention_mask,
1349 | output_attentions=output_attentions,
1350 | output_hidden_states=output_hidden_states,
1351 | return_dict=return_dict,
1352 | )
1353 |
1354 | sequence_output = outputs[0]
1355 | prediction_scores = self.cls(sequence_output)
1356 |
1357 | masked_lm_loss = None
1358 | if labels is not None:
1359 | loss_fct = CrossEntropyLoss() # -100 index = padding token
1360 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1361 |
1362 | if not return_dict:
1363 | output = (prediction_scores,) + outputs[2:]
1364 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1365 |
1366 | return MaskedLMOutput(
1367 | loss=masked_lm_loss,
1368 | logits=prediction_scores,
1369 | hidden_states=outputs.hidden_states,
1370 | attentions=outputs.attentions,
1371 | )
1372 |
1373 | def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1374 | input_shape = input_ids.shape
1375 | effective_batch_size = input_shape[0]
1376 |
1377 | # add a dummy token
1378 | assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1379 | attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1380 | dummy_token = torch.full(
1381 | (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1382 | )
1383 | input_ids = torch.cat([input_ids, dummy_token], dim=1)
1384 |
1385 | return {"input_ids": input_ids, "attention_mask": attention_mask}
1386 |
1387 |
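# --- illustrative usage sketch (added for clarity; not part of the upstream file) ---
# Masked-LM loss in one round trip; the checkpoint name is an assumption. Label
# positions set to -100 are ignored by CrossEntropyLoss, so only the [MASK]
# slot contributes to the loss here.
#
# >>> from transformers import BertTokenizer
# >>> tok = BertTokenizer.from_pretrained("bert-base-uncased")
# >>> mlm = BertForMaskedLM.from_pretrained("bert-base-uncased")
# >>> enc = tok("The capital of France is [MASK].", return_tensors="pt")
# >>> labels = tok("The capital of France is Paris.", return_tensors="pt")["input_ids"]
# >>> labels[enc["input_ids"] != tok.mask_token_id] = -100   # score [MASK] only
# >>> out = mlm(**enc, labels=labels)
# >>> out.loss, out.logits.shape   # scalar loss, (1, seq_len, vocab_size)
# -------------------------------------------------------------------------------------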
1388 | @add_start_docstrings(
1389 | """Bert Model with a `next sentence prediction (classification)` head on top. """,
1390 | BERT_START_DOCSTRING,
1391 | )
1392 | class BertForNextSentencePrediction(BertPreTrainedModel):
1393 | def __init__(self, config):
1394 | super().__init__(config)
1395 |
1396 | self.bert = BertModel(config)
1397 | self.cls = BertOnlyNSPHead(config)
1398 |
1399 | self.init_weights()
1400 |
1401 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1402 | @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1403 | def forward(
1404 | self,
1405 | input_ids=None,
1406 | attention_mask=None,
1407 | token_type_ids=None,
1408 | position_ids=None,
1409 | head_mask=None,
1410 | inputs_embeds=None,
1411 | labels=None,
1412 | output_attentions=None,
1413 | output_hidden_states=None,
1414 | return_dict=None,
1415 | **kwargs,
1416 | ):
1417 | r"""
1418 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1419 | Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1420 | (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
1421 |
1422 | - 0 indicates sequence B is a continuation of sequence A,
1423 | - 1 indicates sequence B is a random sequence.
1424 |
1425 | Returns:
1426 |
1427 | Example::
1428 |
1429 | >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1430 | >>> import torch
1431 |
1432 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1433 | >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
1434 |
1435 | >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1436 | >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1437 | >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
1438 |
1439 | >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1440 | >>> logits = outputs.logits
1441 | >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1442 | """
1443 |
1444 | if "next_sentence_label" in kwargs:
1445 | warnings.warn(
1446 | "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
1447 | FutureWarning,
1448 | )
1449 | labels = kwargs.pop("next_sentence_label")
1450 |
1451 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1452 |
1453 | outputs = self.bert(
1454 | input_ids,
1455 | attention_mask=attention_mask,
1456 | token_type_ids=token_type_ids,
1457 | position_ids=position_ids,
1458 | head_mask=head_mask,
1459 | inputs_embeds=inputs_embeds,
1460 | output_attentions=output_attentions,
1461 | output_hidden_states=output_hidden_states,
1462 | return_dict=return_dict,
1463 | )
1464 |
1465 | pooled_output = outputs[1]
1466 |
1467 | seq_relationship_scores = self.cls(pooled_output)
1468 |
1469 | next_sentence_loss = None
1470 | if labels is not None:
1471 | loss_fct = CrossEntropyLoss()
1472 | next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1473 |
1474 | if not return_dict:
1475 | output = (seq_relationship_scores,) + outputs[2:]
1476 | return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1477 |
1478 | return NextSentencePredictorOutput(
1479 | loss=next_sentence_loss,
1480 | logits=seq_relationship_scores,
1481 | hidden_states=outputs.hidden_states,
1482 | attentions=outputs.attentions,
1483 | )
1484 |
1485 |
1486 | @add_start_docstrings(
1487 | """
1488 | Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1489 | output) e.g. for GLUE tasks.
1490 | """,
1491 | BERT_START_DOCSTRING,
1492 | )
1493 | class BertForSequenceClassification(BertPreTrainedModel):
1494 | def __init__(self, config):
1495 | super().__init__(config)
1496 | self.num_labels = config.num_labels
1497 | self.config = config
1498 |
1499 | self.bert = BertModel(config)
1500 | classifier_dropout = (
1501 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1502 | )
1503 | self.dropout = nn.Dropout(classifier_dropout)
1504 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1505 |
1506 | self.init_weights()
1507 |
1508 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1509 | @add_code_sample_docstrings(
1510 | tokenizer_class=_TOKENIZER_FOR_DOC,
1511 | checkpoint=_CHECKPOINT_FOR_DOC,
1512 | output_type=SequenceClassifierOutput,
1513 | config_class=_CONFIG_FOR_DOC,
1514 | )
1515 | def forward(
1516 | self,
1517 | input_ids=None,
1518 | attention_mask=None,
1519 | token_type_ids=None,
1520 | position_ids=None,
1521 | head_mask=None,
1522 | inputs_embeds=None,
1523 | labels=None,
1524 | output_attentions=None,
1525 | output_hidden_states=None,
1526 | return_dict=None,
1527 | ):
1528 | r"""
1529 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1530 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1531 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1532 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1533 | """
1534 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1535 |
1536 | outputs = self.bert(
1537 | input_ids,
1538 | attention_mask=attention_mask,
1539 | token_type_ids=token_type_ids,
1540 | position_ids=position_ids,
1541 | head_mask=head_mask,
1542 | inputs_embeds=inputs_embeds,
1543 | output_attentions=output_attentions,
1544 | output_hidden_states=output_hidden_states,
1545 | return_dict=return_dict,
1546 | )
1547 |
1548 | pooled_output = outputs[1]
1549 |
1550 | pooled_output = self.dropout(pooled_output)
1551 | logits = self.classifier(pooled_output)
1552 |
1553 | loss = None
1554 | if labels is not None:
1555 | if self.config.problem_type is None:
1556 | if self.num_labels == 1:
1557 | self.config.problem_type = "regression"
1558 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1559 | self.config.problem_type = "single_label_classification"
1560 | else:
1561 | self.config.problem_type = "multi_label_classification"
1562 |
1563 | if self.config.problem_type == "regression":
1564 | loss_fct = MSELoss()
1565 | if self.num_labels == 1:
1566 | loss = loss_fct(logits.squeeze(), labels.squeeze())
1567 | else:
1568 | loss = loss_fct(logits, labels)
1569 | elif self.config.problem_type == "single_label_classification":
1570 | loss_fct = CrossEntropyLoss()
1571 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1572 | elif self.config.problem_type == "multi_label_classification":
1573 | loss_fct = BCEWithLogitsLoss()
1574 | loss = loss_fct(logits, labels)
1575 | if not return_dict:
1576 | output = (logits,) + outputs[2:]
1577 | return ((loss,) + output) if loss is not None else output
1578 |
1579 | return SequenceClassifierOutput(
1580 | loss=loss,
1581 | logits=logits,
1582 | hidden_states=outputs.hidden_states,
1583 | attentions=outputs.attentions,
1584 | )
1585 |
1586 |
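# --- illustrative note (added for clarity; not part of the upstream file) -----------
# How the `problem_type` branch above selects a loss, sketched with assumed
# shapes (batch_size=2, num_labels=3). Integer labels pick single-label
# cross-entropy, float label matrices pick the multi-label BCE path, and
# num_labels == 1 always means MSE regression.
#
# >>> labels = torch.tensor([0, 2])                        # long  -> CrossEntropyLoss
# >>> labels = torch.tensor([[1., 0., 1.], [0., 1., 0.]])  # float -> BCEWithLogitsLoss
# >>> labels = torch.tensor([[0.7], [1.3]])                # num_labels == 1 -> MSELoss
# -------------------------------------------------------------------------------------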
1587 | @add_start_docstrings(
1588 | """
1589 | Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1590 | softmax) e.g. for RocStories/SWAG tasks.
1591 | """,
1592 | BERT_START_DOCSTRING,
1593 | )
1594 | class BertForMultipleChoice(BertPreTrainedModel):
1595 | def __init__(self, config):
1596 | super().__init__(config)
1597 |
1598 | self.bert = BertModel(config)
1599 | classifier_dropout = (
1600 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1601 | )
1602 | self.dropout = nn.Dropout(classifier_dropout)
1603 | self.classifier = nn.Linear(config.hidden_size, 1)
1604 |
1605 | self.init_weights()
1606 |
1607 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1608 | @add_code_sample_docstrings(
1609 | tokenizer_class=_TOKENIZER_FOR_DOC,
1610 | checkpoint=_CHECKPOINT_FOR_DOC,
1611 | output_type=MultipleChoiceModelOutput,
1612 | config_class=_CONFIG_FOR_DOC,
1613 | )
1614 | def forward(
1615 | self,
1616 | input_ids=None,
1617 | attention_mask=None,
1618 | token_type_ids=None,
1619 | position_ids=None,
1620 | head_mask=None,
1621 | inputs_embeds=None,
1622 | labels=None,
1623 | output_attentions=None,
1624 | output_hidden_states=None,
1625 | return_dict=None,
1626 | ):
1627 | r"""
1628 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1629 | Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1630 | num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
1631 | :obj:`input_ids` above)
1632 | """
1633 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1634 | num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1635 |
1636 | input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1637 | attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1638 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1639 | position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1640 | inputs_embeds = (
1641 | inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1642 | if inputs_embeds is not None
1643 | else None
1644 | )
1645 |
1646 | outputs = self.bert(
1647 | input_ids,
1648 | attention_mask=attention_mask,
1649 | token_type_ids=token_type_ids,
1650 | position_ids=position_ids,
1651 | head_mask=head_mask,
1652 | inputs_embeds=inputs_embeds,
1653 | output_attentions=output_attentions,
1654 | output_hidden_states=output_hidden_states,
1655 | return_dict=return_dict,
1656 | )
1657 |
1658 | pooled_output = outputs[1]
1659 |
1660 | pooled_output = self.dropout(pooled_output)
1661 | logits = self.classifier(pooled_output)
1662 | reshaped_logits = logits.view(-1, num_choices)
1663 |
1664 | loss = None
1665 | if labels is not None:
1666 | loss_fct = CrossEntropyLoss()
1667 | loss = loss_fct(reshaped_logits, labels)
1668 |
1669 | if not return_dict:
1670 | output = (reshaped_logits,) + outputs[2:]
1671 | return ((loss,) + output) if loss is not None else output
1672 |
1673 | return MultipleChoiceModelOutput(
1674 | loss=loss,
1675 | logits=reshaped_logits,
1676 | hidden_states=outputs.hidden_states,
1677 | attentions=outputs.attentions,
1678 | )
1679 |
1680 |
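# --- illustrative note (added for clarity; not part of the upstream file) -----------
# The flattening in the multiple-choice forward above, with assumed sizes:
# inputs arrive as (batch, num_choices, seq_len), are folded so BERT sees one
# choice per row, and the per-choice scores are unfolded back so each row of
# `reshaped_logits` is scored against the gold choice index.
#
# >>> input_ids.shape                                  # (2, 4, 32): 2 questions, 4 choices
# >>> input_ids.view(-1, input_ids.size(-1)).shape     # (8, 32) fed to BERT
# >>> logits.view(-1, num_choices).shape               # (2, 4) for CrossEntropyLoss
# -------------------------------------------------------------------------------------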
1681 | @add_start_docstrings(
1682 | """
1683 | Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1684 | Named-Entity-Recognition (NER) tasks.
1685 | """,
1686 | BERT_START_DOCSTRING,
1687 | )
1688 | class BertForTokenClassification(BertPreTrainedModel):
1689 |
1690 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
1691 |
1692 | def __init__(self, config):
1693 | super().__init__(config)
1694 | self.num_labels = config.num_labels
1695 |
1696 | self.bert = BertModel(config, add_pooling_layer=False)
1697 | # classifier_dropout = (
1698 | # config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1699 | # )
1700 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
1701 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1702 |
1703 | self.init_weights()
1704 |
1705 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1706 | @add_code_sample_docstrings(
1707 | tokenizer_class=_TOKENIZER_FOR_DOC,
1708 | checkpoint=_CHECKPOINT_FOR_DOC,
1709 | output_type=TokenClassifierOutput,
1710 | config_class=_CONFIG_FOR_DOC,
1711 | )
1712 | def forward(
1713 | self,
1714 | input_ids=None,
1715 | attention_mask=None,
1716 | token_type_ids=None,
1717 | position_ids=None,
1718 | head_mask=None,
1719 | inputs_embeds=None,
1720 | labels=None,
1721 | output_attentions=None,
1722 | output_hidden_states=None,
1723 |         # added for HVPNet: visual prefix as past key/values, plus the prefix-extended mask
1724 | past_key_values=None,
1725 | prompt_attention_mask=None,
1726 | #
1727 | return_dict=None,
1728 | ):
1729 | r"""
1730 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1731 | Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1732 | 1]``.
1733 | """
1734 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1735 |
1736 | outputs = self.bert(
1737 | input_ids,
1738 | attention_mask=prompt_attention_mask,
1739 | token_type_ids=token_type_ids,
1740 | position_ids=position_ids,
1741 | head_mask=head_mask,
1742 | inputs_embeds=inputs_embeds,
1743 |             # added for HVPNet: hand the visual prefix to self.bert as past key/values
1744 | past_key_values=past_key_values,
1745 | #
1746 | output_attentions=output_attentions,
1747 | output_hidden_states=output_hidden_states,
1748 | return_dict=return_dict,
1749 | )
1750 |
1751 | sequence_output = outputs[0]
1752 |
1753 | sequence_output = self.dropout(sequence_output)
1754 | logits = self.classifier(sequence_output)
1755 |
1756 | loss = None
1757 | if labels is not None:
1758 | loss_fct = CrossEntropyLoss()
1759 | # Only keep active parts of the loss
1760 |             if attention_mask is not None:  # loss masking uses the text-only mask, not the prompt-extended one
1761 | active_loss = attention_mask.view(-1) == 1
1762 | active_logits = logits.view(-1, self.num_labels)
1763 | active_labels = torch.where(
1764 | active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1765 | )
1766 | loss = loss_fct(active_logits, active_labels)
1767 | else:
1768 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1769 |
1770 | if not return_dict:
1771 | output = (logits,) + outputs[2:]
1772 | return ((loss,) + output) if loss is not None else output
1773 |
1774 | return TokenClassifierOutput(
1775 | loss=loss,
1776 | logits=logits,
1777 | hidden_states=outputs.hidden_states,
1778 | attentions=outputs.attentions,
1779 | )
1780 |
1781 |
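# --- illustrative note (added for clarity; not part of the upstream file) -----------
# Shape convention assumed by the modified token-classification forward above:
# with a visual prefix of length P injected via `past_key_values`,
# `prompt_attention_mask` covers prefix + text, shape (batch, P + seq_len),
# while the plain `attention_mask`, shape (batch, seq_len), is kept for masking
# the token loss. `bsz` and `prefix_len` below are placeholders.
#
# >>> prompt_attention_mask = torch.cat(
# ...     [torch.ones(bsz, prefix_len), attention_mask], dim=1)
# -------------------------------------------------------------------------------------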
1782 | @add_start_docstrings(
1783 | """
1784 | Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1785 | layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1786 | """,
1787 | BERT_START_DOCSTRING,
1788 | )
1789 | class BertForQuestionAnswering(BertPreTrainedModel):
1790 |
1791 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
1792 |
1793 | def __init__(self, config):
1794 | super().__init__(config)
1795 | self.num_labels = config.num_labels
1796 |
1797 | self.bert = BertModel(config, add_pooling_layer=False)
1798 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1799 |
1800 | self.init_weights()
1801 |
1802 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1803 | @add_code_sample_docstrings(
1804 | tokenizer_class=_TOKENIZER_FOR_DOC,
1805 | checkpoint=_CHECKPOINT_FOR_DOC,
1806 | output_type=QuestionAnsweringModelOutput,
1807 | config_class=_CONFIG_FOR_DOC,
1808 | )
1809 | def forward(
1810 | self,
1811 | input_ids=None,
1812 | attention_mask=None,
1813 | token_type_ids=None,
1814 | position_ids=None,
1815 | head_mask=None,
1816 | inputs_embeds=None,
1817 | start_positions=None,
1818 | end_positions=None,
1819 | output_attentions=None,
1820 | output_hidden_states=None,
1821 | return_dict=None,
1822 | ):
1823 | r"""
1824 | start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1825 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
1826 | Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1827 | sequence are not taken into account for computing the loss.
1828 | end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1829 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
1830 | Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1831 | sequence are not taken into account for computing the loss.
1832 | """
1833 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1834 |
1835 | outputs = self.bert(
1836 | input_ids,
1837 | attention_mask=attention_mask,
1838 | token_type_ids=token_type_ids,
1839 | position_ids=position_ids,
1840 | head_mask=head_mask,
1841 | inputs_embeds=inputs_embeds,
1842 | output_attentions=output_attentions,
1843 | output_hidden_states=output_hidden_states,
1844 | return_dict=return_dict,
1845 | )
1846 |
1847 | sequence_output = outputs[0]
1848 |
1849 | logits = self.qa_outputs(sequence_output)
1850 | start_logits, end_logits = logits.split(1, dim=-1)
1851 | start_logits = start_logits.squeeze(-1).contiguous()
1852 | end_logits = end_logits.squeeze(-1).contiguous()
1853 |
1854 | total_loss = None
1855 | if start_positions is not None and end_positions is not None:
1856 | # If we are on multi-GPU, split add a dimension
1857 | if len(start_positions.size()) > 1:
1858 | start_positions = start_positions.squeeze(-1)
1859 | if len(end_positions.size()) > 1:
1860 | end_positions = end_positions.squeeze(-1)
1861 | # sometimes the start/end positions are outside our model inputs, we ignore these terms
1862 | ignored_index = start_logits.size(1)
1863 | start_positions = start_positions.clamp(0, ignored_index)
1864 | end_positions = end_positions.clamp(0, ignored_index)
1865 |
1866 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1867 | start_loss = loss_fct(start_logits, start_positions)
1868 | end_loss = loss_fct(end_logits, end_positions)
1869 | total_loss = (start_loss + end_loss) / 2
1870 |
1871 | if not return_dict:
1872 | output = (start_logits, end_logits) + outputs[2:]
1873 | return ((total_loss,) + output) if total_loss is not None else output
1874 |
1875 | return QuestionAnsweringModelOutput(
1876 | loss=total_loss,
1877 | start_logits=start_logits,
1878 | end_logits=end_logits,
1879 | hidden_states=outputs.hidden_states,
1880 | attentions=outputs.attentions,
1881 | )
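# --- illustrative note (added for clarity; not part of the upstream file) -----------
# Turning the QA head's outputs into a span, sketched with placeholder names
# (`out` a QuestionAnsweringModelOutput, `input_ids`/`tokenizer` from the
# encoding step): a greedy decode just takes the argmax of each logit vector.
#
# >>> start = out.start_logits.argmax(-1)   # (batch,) span start indices
# >>> end = out.end_logits.argmax(-1)       # (batch,) span end indices
# >>> tokenizer.decode(input_ids[0, start[0]:end[0] + 1])
# -------------------------------------------------------------------------------------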
--------------------------------------------------------------------------------
/modules/metrics.py:
--------------------------------------------------------------------------------
1 | def eval_result(true_labels, pred_result, rel2id, logger, use_name=False):
2 | correct = 0
3 | total = len(true_labels)
4 | correct_positive = 0
5 | pred_positive = 0
6 | gold_positive = 0
7 |
8 | neg = -1
9 | for name in ['NA', 'na', 'no_relation', 'Other', 'Others', 'none', 'None']:
10 | if name in rel2id:
11 | if use_name:
12 | neg = name
13 | else:
14 | neg = rel2id[name]
15 | break
16 | for i in range(total):
17 |         # `use_name` only changes how the negative label `neg` is stored above;
18 |         # gold labels are compared as given either way, so the former identical
19 |         # if/else branches reduce to a single assignment
20 |         golden = true_labels[i]
21 |
22 | if golden == pred_result[i]:
23 | correct += 1
24 | if golden != neg:
25 | correct_positive += 1
26 | if golden != neg:
27 | gold_positive += 1
28 | if pred_result[i] != neg:
29 | pred_positive += 1
30 | acc = float(correct) / float(total)
31 | try:
32 | micro_p = float(correct_positive) / float(pred_positive)
33 |     except ZeroDivisionError:  # no positive predictions
34 | micro_p = 0
35 | try:
36 | micro_r = float(correct_positive) / float(gold_positive)
37 |     except ZeroDivisionError:  # no positive gold labels
38 | micro_r = 0
39 | try:
40 | micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r)
41 |     except ZeroDivisionError:  # micro_p + micro_r == 0
42 | micro_f1 = 0
43 |
44 | result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1}
45 | logger.info('Evaluation result: {}.'.format(result))
46 | return result
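# --- illustrative usage sketch (added for clarity; not part of the original file) ---
# A tiny run of eval_result with integer labels; logging.getLogger stands in
# for the project logger. With neg = rel2id['None'] = 0, only non-'None'
# relations count toward micro precision/recall: 2 of 3 predicted positives and
# 2 of 3 gold positives are correct, so micro P = R = F1 = 2/3 while acc = 2/4.
#
# >>> import logging
# >>> rel2id = {'None': 0, 'per:employer': 1, 'org:founded_by': 2}
# >>> eval_result([1, 0, 2, 2], [1, 2, 2, 0], rel2id, logging.getLogger(__name__))
# {'acc': 0.5, 'micro_p': 0.666..., 'micro_r': 0.666..., 'micro_f1': 0.666...}
# -------------------------------------------------------------------------------------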
--------------------------------------------------------------------------------
/modules/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim
3 | from tqdm import tqdm
4 | import random
5 | from sklearn.metrics import classification_report as sk_classification_report
6 | from seqeval.metrics import classification_report
7 | from transformers.optimization import get_linear_schedule_with_warmup
8 |
9 | from .metrics import eval_result
10 |
11 | class BaseTrainer(object):
12 | def train(self):
13 | raise NotImplementedError()
14 |
15 | def evaluate(self):
16 | raise NotImplementedError()
17 |
18 | def test(self):
19 | raise NotImplementedError()
20 |
21 | class RETrainer(BaseTrainer):
22 | def __init__(self, train_data=None, dev_data=None, test_data=None, model=None, processor=None, args=None, logger=None, writer=None) -> None:
23 | self.train_data = train_data
24 | self.dev_data = dev_data
25 | self.test_data = test_data
26 | self.model = model
27 | self.processor = processor
28 | self.re_dict = processor.get_relation_dict()
29 | self.logger = logger
30 | self.writer = writer
31 | self.refresh_step = 2
32 | self.best_dev_metric = 0
33 | self.best_test_metric = 0
34 | self.best_dev_epoch = None
35 | self.best_test_epoch = None
36 | self.optimizer = None
37 | if self.train_data is not None:
38 | self.train_num_steps = len(self.train_data) * args.num_epochs
39 | self.step = 0
40 | self.args = args
41 | if self.args.use_prompt:
42 | self.before_multimodal_train()
43 | else:
44 | self.before_train()
45 |
46 | def train(self):
47 | self.step = 0
48 | self.model.train()
49 | self.logger.info("***** Running training *****")
50 | self.logger.info(" Num instance = %d", len(self.train_data)*self.args.batch_size)
51 | self.logger.info(" Num epoch = %d", self.args.num_epochs)
52 | self.logger.info(" Batch size = %d", self.args.batch_size)
53 | self.logger.info(" Learning rate = {}".format(self.args.lr))
54 | self.logger.info(" Evaluate begin = %d", self.args.eval_begin_epoch)
55 |
56 | if self.args.load_path is not None: # load model from load_path
57 | self.logger.info("Loading model from {}".format(self.args.load_path))
58 | self.model.load_state_dict(torch.load(self.args.load_path))
59 | self.logger.info("Load model successful!")
60 |
61 | with tqdm(total=self.train_num_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True, initial=self.step) as pbar:
62 | self.pbar = pbar
63 | avg_loss = 0
64 | for epoch in range(1, self.args.num_epochs+1):
65 | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.args.num_epochs))
66 | for batch in self.train_data:
67 | self.step += 1
68 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)
69 | (loss, logits), labels = self._step(batch, mode="train")
70 | avg_loss += loss.detach().cpu().item()
71 |
72 | loss.backward()
73 | self.optimizer.step()
74 | self.scheduler.step()
75 | self.optimizer.zero_grad()
76 |
77 | if self.step % self.refresh_step == 0:
78 | avg_loss = float(avg_loss) / self.refresh_step
79 | print_output = "loss:{:<6.5f}".format(avg_loss)
80 | pbar.update(self.refresh_step)
81 | pbar.set_postfix_str(print_output)
82 | if self.writer:
83 |                         self.writer.add_scalar(tag='train_loss', scalar_value=avg_loss, global_step=self.step) # tensorboardX
84 | avg_loss = 0
85 |
86 | if epoch >= self.args.eval_begin_epoch:
87 |                     self.evaluate(epoch) # run evaluation on the dev set
88 |
89 | pbar.close()
90 | self.pbar = None
91 | self.logger.info("Get best dev performance at epoch {}, best dev f1 score is {}".format(self.best_dev_epoch, self.best_dev_metric))
92 | self.logger.info("Get best test performance at epoch {}, best test f1 score is {}".format(self.best_test_epoch, self.best_test_metric))
93 |
94 | def evaluate(self, epoch):
95 | self.model.eval()
96 | self.logger.info("***** Running evaluate *****")
97 | self.logger.info(" Num instance = %d", len(self.dev_data)*self.args.batch_size)
98 | self.logger.info(" Batch size = %d", self.args.batch_size)
99 | step = 0
100 | true_labels, pred_labels = [], []
101 | with torch.no_grad():
102 | with tqdm(total=len(self.dev_data), leave=False, dynamic_ncols=True) as pbar:
103 | pbar.set_description_str(desc="Dev")
104 | total_loss = 0
105 | for batch in self.dev_data:
106 | step += 1
107 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch) # to cpu/cuda device
108 | (loss, logits), labels = self._step(batch, mode="dev") # logits: batch, 3
109 | total_loss += loss.detach().cpu().item()
110 |
111 | preds = logits.argmax(-1)
112 | true_labels.extend(labels.view(-1).detach().cpu().tolist())
113 | pred_labels.extend(preds.view(-1).detach().cpu().tolist())
114 | pbar.update()
115 | # evaluate done
116 | pbar.close()
117 | sk_result = sk_classification_report(y_true=true_labels, y_pred=pred_labels, labels=list(self.re_dict.values())[1:], target_names=list(self.re_dict.keys())[1:], digits=4)
118 | self.logger.info("%s\n", sk_result)
119 | result = eval_result(true_labels, pred_labels, self.re_dict, self.logger)
120 | acc, micro_f1 = round(result['acc']*100, 4), round(result['micro_f1']*100, 4)
121 | if self.writer:
122 |             self.writer.add_scalar(tag='dev_acc', scalar_value=acc, global_step=epoch) # tensorboardX
123 |             self.writer.add_scalar(tag='dev_f1', scalar_value=micro_f1, global_step=epoch) # tensorboardX
124 |             self.writer.add_scalar(tag='dev_loss', scalar_value=total_loss/len(self.dev_data), global_step=epoch) # tensorboardX
125 |
126 | self.logger.info("Epoch {}/{}, best dev f1: {}, best epoch: {}, current dev f1 score: {}, acc: {}."\
127 | .format(epoch, self.args.num_epochs, self.best_dev_metric, self.best_dev_epoch, micro_f1, acc))
128 | if micro_f1 >= self.best_dev_metric: # this epoch get best performance
129 | self.logger.info("Get better performance at epoch {}".format(epoch))
130 | self.best_dev_epoch = epoch
131 | self.best_dev_metric = micro_f1 # update best metric(f1 score)
132 | if self.args.save_path is not None:
133 | torch.save(self.model.state_dict(), self.args.save_path+"/best_model.pth")
134 | self.logger.info("Save best model at {}".format(self.args.save_path))
135 |
136 |
137 | self.model.train()
138 |
139 | def test(self):
140 | self.model.eval()
141 | self.logger.info("\n***** Running testing *****")
142 | self.logger.info(" Num instance = %d", len(self.test_data)*self.args.batch_size)
143 | self.logger.info(" Batch size = %d", self.args.batch_size)
144 |
145 | if self.args.load_path is not None: # load model from load_path
146 | self.logger.info("Loading model from {}".format(self.args.load_path))
147 | self.model.load_state_dict(torch.load(self.args.load_path))
148 | self.logger.info("Load model successful!")
149 | true_labels, pred_labels = [], []
150 | with torch.no_grad():
151 | with tqdm(total=len(self.test_data), leave=False, dynamic_ncols=True) as pbar:
152 | pbar.set_description_str(desc="Testing")
153 | total_loss = 0
154 | for batch in self.test_data:
155 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch) # to cpu/cuda device
156 | (loss, logits), labels = self._step(batch, mode="dev") # logits: batch, 3
157 | total_loss += loss.detach().cpu().item()
158 |
159 | preds = logits.argmax(-1)
160 | true_labels.extend(labels.view(-1).detach().cpu().tolist())
161 | pred_labels.extend(preds.view(-1).detach().cpu().tolist())
162 |
163 | pbar.update()
164 | # evaluate done
165 | pbar.close()
166 | sk_result = sk_classification_report(y_true=true_labels, y_pred=pred_labels, labels=list(self.re_dict.values())[1:], target_names=list(self.re_dict.keys())[1:], digits=4)
167 | self.logger.info("%s\n", sk_result)
168 | result = eval_result(true_labels, pred_labels, self.re_dict, self.logger)
169 | acc, micro_f1 = round(result['acc']*100, 4), round(result['micro_f1']*100, 4)
170 | if self.writer:
171 |             self.writer.add_scalar(tag='test_acc', scalar_value=acc) # tensorboardX
172 |             self.writer.add_scalar(tag='test_f1', scalar_value=micro_f1) # tensorboardX
173 |             self.writer.add_scalar(tag='test_loss', scalar_value=total_loss/len(self.test_data)) # tensorboardX
174 | total_loss = 0
175 | self.logger.info("Test f1 score: {}, acc: {}.".format(micro_f1, acc))
176 |
177 | self.model.train()
178 |
179 | def _step(self, batch, mode="train"):
180 | if mode != "predict":
181 | if self.args.use_prompt:
182 | input_ids, token_type_ids, attention_mask, labels, images, aux_imgs = batch
183 | else:
184 | images, aux_imgs = None, None
185 |                 input_ids, token_type_ids, attention_mask, labels = batch
186 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels, images=images, aux_imgs=aux_imgs)
187 | return outputs, labels
188 |
189 | def before_train(self):
190 | no_decay = ['bias', 'LayerNorm.weight']
191 | optimizer_grouped_parameters = [
192 | {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
193 | {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
194 | ]
195 | self.optimizer = optim.AdamW(optimizer_grouped_parameters, lr=self.args.lr)
196 | self.scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer,
197 | num_warmup_steps=self.args.warmup_ratio*self.train_num_steps,
198 | num_training_steps=self.train_num_steps)
199 | self.model.to(self.args.device)
200 |
201 |
202 | def before_multimodal_train(self):
203 | optimizer_grouped_parameters = []
204 | params = {'lr':self.args.lr, 'weight_decay':1e-2}
205 | params['params'] = []
206 | for name, param in self.model.named_parameters():
207 | if 'bert' in name:
208 | params['params'].append(param)
209 | optimizer_grouped_parameters.append(params)
210 |
211 | params = {'lr':self.args.lr, 'weight_decay':1e-2}
212 | params['params'] = []
213 | for name, param in self.model.named_parameters():
214 | if 'encoder_conv' in name or 'gates' in name:
215 | params['params'].append(param)
216 | optimizer_grouped_parameters.append(params)
217 |
218 | # freeze resnet
219 | for name, param in self.model.named_parameters():
220 | if 'image_model' in name:
221 |                 param.requires_grad = False # fixed typo: `require_grad` would silently set an unused attribute
222 | self.optimizer = optim.AdamW(optimizer_grouped_parameters, lr=self.args.lr)
223 | self.scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer,
224 | num_warmup_steps=self.args.warmup_ratio*self.train_num_steps,
225 | num_training_steps=self.train_num_steps)
226 | self.model.to(self.args.device)
227 |
228 |
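# --- illustrative check (added for clarity; not part of the original file) ----------
# A sanity check after before_multimodal_train, assuming `trainer` is an
# RETrainer instance: every image-tower ('image_model') parameter should be
# frozen, and only the 'bert' / 'encoder_conv' / 'gates' groups should appear
# in the optimizer.
#
# >>> all(not p.requires_grad for n, p in trainer.model.named_parameters()
# ...     if 'image_model' in n)
# True
# >>> sum(p.numel() for g in trainer.optimizer.param_groups for p in g['params'])
# -------------------------------------------------------------------------------------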
229 | class NERTrainer(BaseTrainer):
230 | def __init__(self, train_data=None, dev_data=None, test_data=None, model=None, processor=None, label_map=None, args=None, logger=None, writer=None) -> None:
231 | self.train_data = train_data
232 | self.dev_data = dev_data
233 | self.test_data = test_data
234 | self.model = model
235 | self.processor = processor
236 | self.logger = logger
237 | self.label_map = label_map
238 | self.writer = writer
239 | self.refresh_step = 2
240 | self.best_dev_metric = 0
241 | self.best_test_metric = 0
242 | self.best_train_metric = 0
243 | self.best_dev_epoch = None
244 | self.best_test_epoch = None
245 | self.best_train_epoch = None
246 | self.optimizer = None
247 | if self.train_data is not None:
248 | self.train_num_steps = len(self.train_data) * args.num_epochs
249 | self.step = 0
250 | self.args = args
251 |
252 | def train(self):
253 | if self.args.use_prompt:
254 | self.multiModal_before_train()
255 | else:
256 | self.bert_before_train()
257 |
258 | self.step = 0
259 | self.model.train()
260 | self.logger.info("***** Running training *****")
261 | self.logger.info(" Num instance = %d", len(self.train_data)*self.args.batch_size)
262 | self.logger.info(" Num epoch = %d", self.args.num_epochs)
263 | self.logger.info(" Batch size = %d", self.args.batch_size)
264 | self.logger.info(" Learning rate = {}".format(self.args.lr))
265 | self.logger.info(" Evaluate begin = %d", self.args.eval_begin_epoch)
266 |
267 | if self.args.load_path is not None: # load model from load_path
268 | self.logger.info("Loading model from {}".format(self.args.load_path))
269 | self.model.load_state_dict(torch.load(self.args.load_path))
270 | self.logger.info("Load model successful!")
271 |
272 | with tqdm(total=self.train_num_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True, initial=self.step) as pbar:
273 | self.pbar = pbar
274 | avg_loss = 0
275 | for epoch in range(1, self.args.num_epochs+1):
276 | y_true, y_pred = [], []
277 | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.args.num_epochs))
278 | for batch in self.train_data:
279 | self.step += 1
280 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)
281 | attention_mask, labels, logits, loss = self._step(batch, mode="train")
282 | avg_loss += loss.detach().cpu().item()
283 |
284 | loss.backward()
285 | self.optimizer.step()
286 | self.scheduler.step()
287 | self.optimizer.zero_grad()
288 |
289 |                     if isinstance(logits, torch.Tensor): # CRF decoding returns lists; plain logits are tensors
290 | logits = logits.argmax(-1).detach().cpu().numpy() # batch, seq, 1
291 | label_ids = labels.to('cpu').numpy()
292 | input_mask = attention_mask.to('cpu').numpy()
293 | label_map = {idx:label for label, idx in self.label_map.items()}
294 |
295 | for row, mask_line in enumerate(input_mask):
296 | true_label = []
297 | true_predict = []
298 | for column, mask in enumerate(mask_line):
299 | if column == 0:
300 | continue
301 | if mask:
302 | if label_map[label_ids[row][column]] != "X" and label_map[label_ids[row][column]] != "[SEP]":
303 | true_label.append(label_map[label_ids[row][column]])
304 | true_predict.append(label_map[logits[row][column]])
305 | else:
306 | break
307 | y_true.append(true_label)
308 | y_pred.append(true_predict)
309 |
310 | if self.step % self.refresh_step == 0:
311 | avg_loss = float(avg_loss) / self.refresh_step
312 | print_output = "loss:{:<6.5f}".format(avg_loss)
313 | pbar.update(self.refresh_step)
314 | pbar.set_postfix_str(print_output)
315 | if self.writer:
316 |                             self.writer.add_scalar(tag='train_loss', scalar_value=avg_loss, global_step=self.step)    # tensorboardX
317 | avg_loss = 0
318 | results = classification_report(y_true, y_pred, digits=4)
319 | self.logger.info("***** Train Eval results *****")
320 | self.logger.info("\n%s", results)
321 |                 f1_score = float(results.split('\n')[-4].split()[-2])  # micro avg row; whitespace-robust parsing
322 | if self.writer:
323 |                     self.writer.add_scalar(tag='train_f1', scalar_value=f1_score, global_step=epoch)    # tensorboardX
324 | self.logger.info("Epoch {}/{}, best train f1: {}, best epoch: {}, current train f1 score: {}."\
325 | .format(epoch, self.args.num_epochs, self.best_train_metric, self.best_train_epoch, f1_score))
326 | if f1_score > self.best_train_metric:
327 | self.best_train_metric = f1_score
328 | self.best_train_epoch = epoch
329 |
330 | if epoch >= self.args.eval_begin_epoch:
331 |                     self.evaluate(epoch)   # run evaluation on the dev set
332 |
333 | torch.cuda.empty_cache()
334 |
335 | pbar.close()
336 | self.pbar = None
337 | self.logger.info("Get best dev performance at epoch {}, best dev f1 score is {}".format(self.best_dev_epoch, self.best_dev_metric))
338 | self.logger.info("Get best test performance at epoch {}, best test f1 score is {}".format(self.best_test_epoch, self.best_test_metric))
339 |
340 | def evaluate(self, epoch):
341 | self.model.eval()
342 | self.logger.info("***** Running evaluate *****")
343 | self.logger.info(" Num instance = %d", len(self.dev_data)*self.args.batch_size)
344 | self.logger.info(" Batch size = %d", self.args.batch_size)
345 |
346 | y_true, y_pred = [], []
347 | step = 0
348 | with torch.no_grad():
349 | with tqdm(total=len(self.dev_data), leave=False, dynamic_ncols=True) as pbar:
350 | pbar.set_description_str(desc="Dev")
351 | total_loss = 0
352 | for batch in self.dev_data:
353 | step += 1
354 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch) # to cpu/cuda device
355 | attention_mask, labels, logits, loss = self._step(batch, mode="dev") # logits: batch, seq, num_labels
356 | total_loss += loss.detach().cpu().item()
357 |
358 | if isinstance(logits, torch.Tensor):
359 | logits = logits.argmax(-1).detach().cpu().numpy() # batch, seq, 1
360 | label_ids = labels.detach().cpu().numpy()
361 | input_mask = attention_mask.detach().cpu().numpy()
362 | label_map = {idx:label for label, idx in self.label_map.items()}
363 | for row, mask_line in enumerate(input_mask):
364 | true_label = []
365 | true_predict = []
366 | for column, mask in enumerate(mask_line):
367 | if column == 0:
368 | continue
369 | if mask:
370 | if label_map[label_ids[row][column]] != "X" and label_map[label_ids[row][column]] != "[SEP]":
371 | true_label.append(label_map[label_ids[row][column]])
372 | true_predict.append(label_map[logits[row][column]])
373 | else:
374 | break
375 | y_true.append(true_label)
376 | y_pred.append(true_predict)
377 |
378 | pbar.update()
379 | pbar.close()
380 | results = classification_report(y_true, y_pred, digits=4)
381 | self.logger.info("***** Dev Eval results *****")
382 | self.logger.info("\n%s", results)
383 |                 f1_score = float(results.split('\n')[-4].split()[-2])  # micro avg row; whitespace-robust parsing
384 | if self.writer:
385 |                     self.writer.add_scalar(tag='dev_f1', scalar_value=f1_score, global_step=epoch)    # tensorboardX
386 |                     self.writer.add_scalar(tag='dev_loss', scalar_value=total_loss/step, global_step=epoch)    # tensorboardX
387 |
388 | self.logger.info("Epoch {}/{}, best dev f1: {}, best epoch: {}, current dev f1 score: {}."\
389 | .format(epoch, self.args.num_epochs, self.best_dev_metric, self.best_dev_epoch, f1_score))
390 | if f1_score >= self.best_dev_metric: # this epoch get best performance
391 | self.logger.info("Get better performance at epoch {}".format(epoch))
392 | self.best_dev_epoch = epoch
393 | self.best_dev_metric = f1_score # update best metric(f1 score)
394 | if self.args.save_path is not None:
395 | torch.save(self.model.state_dict(), self.args.save_path+"/best_model.pth")
396 | self.logger.info("Save best model at {}".format(self.args.save_path))
397 |
398 | self.model.train()
399 |
400 | def test(self):
401 | self.model.eval()
402 | self.logger.info("\n***** Running testing *****")
403 | self.logger.info(" Num instance = %d", len(self.test_data)*self.args.batch_size)
404 | self.logger.info(" Batch size = %d", self.args.batch_size)
405 |
406 | if self.args.load_path is not None: # load model from load_path
407 | self.logger.info("Loading model from {}".format(self.args.load_path))
408 | self.model.load_state_dict(torch.load(self.args.load_path))
409 | self.logger.info("Load model successful!")
410 | y_true, y_pred = [], []
411 | with torch.no_grad():
412 | with tqdm(total=len(self.test_data), leave=False, dynamic_ncols=True) as pbar:
413 | pbar.set_description_str(desc="Testing")
414 | total_loss = 0
415 | for batch in self.test_data:
416 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch) # to cpu/cuda device
417 | attention_mask, labels, logits, loss = self._step(batch, mode="dev") # logits: batch, seq, num_labels
418 | total_loss += loss.detach().cpu().item()
419 |
420 | if isinstance(logits, torch.Tensor):
421 |                         logits = logits.argmax(-1).detach().cpu().numpy()  # batch, seq, 1
422 | label_ids = labels.detach().cpu().numpy()
423 | input_mask = attention_mask.detach().cpu().numpy()
424 | label_map = {idx:label for label, idx in self.label_map.items()}
425 | for row, mask_line in enumerate(input_mask):
426 | true_label = []
427 | true_predict = []
428 | for column, mask in enumerate(mask_line):
429 | if column == 0:
430 | continue
431 | if mask:
432 | if label_map[label_ids[row][column]] != "X" and label_map[label_ids[row][column]] != "[SEP]":
433 | true_label.append(label_map[label_ids[row][column]])
434 | true_predict.append(label_map[logits[row][column]])
435 | else:
436 | break
437 | y_true.append(true_label)
438 | y_pred.append(true_predict)
439 | pbar.update()
440 | # evaluate done
441 | pbar.close()
442 |
443 | results = classification_report(y_true, y_pred, digits=4)
444 |
445 | self.logger.info("***** Test Eval results *****")
446 | self.logger.info("\n%s", results)
447 |             f1_score = float(results.split('\n')[-4].split()[-2])  # micro avg row; whitespace-robust parsing
448 | if self.writer:
449 |                 self.writer.add_scalar(tag='test_f1', scalar_value=f1_score)    # tensorboardX
450 |                 self.writer.add_scalar(tag='test_loss', scalar_value=total_loss/len(self.test_data))    # tensorboardX
451 | total_loss = 0
452 | self.logger.info("Test f1 score: {}.".format(f1_score))
453 |
454 | self.model.train()
455 |
456 | def _step(self, batch, mode="train"):
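457 |         # unpack in the same order the dataset returns tuples; prompt mode adds images and aux_imgs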
457 | if self.args.use_prompt:
458 | input_ids, token_type_ids, attention_mask, labels, images, aux_imgs = batch
459 | else:
460 | images, aux_imgs = None, None
461 | input_ids, token_type_ids, attention_mask, labels = batch
462 | output = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels, images=images, aux_imgs=aux_imgs)
463 | logits, loss = output.logits, output.loss
464 | return attention_mask, labels, logits, loss
465 |
466 |
467 |
468 | def bert_before_train(self):
469 | self.optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr)
470 |
471 | self.model.to(self.args.device)
472 | self.scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer,
473 | num_warmup_steps=self.args.warmup_ratio*self.train_num_steps,
474 | num_training_steps=self.train_num_steps)
475 |
476 | def multiModal_before_train(self):
477 | # bert lr
478 | parameters = []
479 | params = {'lr':self.args.lr, 'weight_decay':1e-2}
480 | params['params'] = []
481 | for name, param in self.model.named_parameters():
482 | if 'bert' in name:
483 | params['params'].append(param)
484 | parameters.append(params)
485 |
486 | # prompt lr
487 | params = {'lr':self.args.lr, 'weight_decay':1e-2}
488 | params['params'] = []
489 | for name, param in self.model.named_parameters():
490 | if 'encoder_conv' in name or 'gates' in name:
491 | params['params'].append(param)
492 | parameters.append(params)
493 |
494 | # crf lr
495 | params = {'lr':5e-2, 'weight_decay':1e-2}
496 | params['params'] = []
497 | for name, param in self.model.named_parameters():
498 | if 'crf' in name or name.startswith('fc'):
499 | params['params'].append(param)
500 | parameters.append(params)
501 |
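502 |         # per-group lr/weight_decay set above override the AdamW defaults; CRF and fc use a larger 5e-2 lr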
502 | self.optimizer = optim.AdamW(parameters)
503 |
504 | for name, par in self.model.named_parameters(): # freeze resnet
505 | if 'image_model' in name: par.requires_grad = False
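506 |         # the ResNet backbone is also excluded from the parameter groups above, so it receives no updates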
506 |
507 | self.scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer,
508 | num_warmup_steps=self.args.warmup_ratio*self.train_num_steps,
509 | num_training_steps=self.train_num_steps)
510 | self.model.to(self.args.device)
511 |
--------------------------------------------------------------------------------
/processor/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import torch
4 | import json
5 | import ast
6 | from PIL import Image
7 | from torch.utils.data import Dataset, DataLoader
8 | from transformers import BertTokenizer
9 | from torchvision import transforms
10 | import logging
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class MMREProcessor(object):
15 | def __init__(self, data_path, bert_name):
16 | self.data_path = data_path
17 | self.re_path = data_path['re_path']
18 | self.tokenizer = BertTokenizer.from_pretrained(bert_name, do_lower_case=True)
19 |         self.tokenizer.add_special_tokens({'additional_special_tokens':['<s>', '</s>', '<o>', '</o>']})  # entity markers for head/tail spans
20 |
21 | def load_from_file(self, mode="train", sample_ratio=1.0):
22 | """
23 | Args:
24 | mode (str, optional): dataset mode. Defaults to "train".
25 |             sample_ratio (float, optional): sampling ratio for low-resource settings. Defaults to 1.0.
26 | """
27 | load_file = self.data_path[mode]
28 | logger.info("Loading data from {}".format(load_file))
29 | with open(load_file, "r", encoding="utf-8") as f:
30 | lines = f.readlines()
31 | words, relations, heads, tails, imgids, dataid = [], [], [], [], [], []
32 | for i, line in enumerate(lines):
33 | line = ast.literal_eval(line) # str to dict
34 | words.append(line['token'])
35 | relations.append(line['relation'])
36 | heads.append(line['h']) # {name, pos}
37 | tails.append(line['t'])
38 | imgids.append(line['img_id'])
39 | dataid.append(i)
40 |
41 |         assert len(words) == len(relations) == len(heads) == len(tails) == len(imgids)
42 |
43 | # aux image
44 | aux_path = self.data_path[mode+"_auximgs"]
45 | aux_imgs = torch.load(aux_path)
46 |
47 | # sample
48 | if sample_ratio != 1.0:
49 | sample_indexes = random.choices(list(range(len(words))), k=int(len(words)*sample_ratio))
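50 |             # note: random.choices samples with replacement, so the low-resource subset may contain duplicates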
50 | sample_words = [words[idx] for idx in sample_indexes]
51 | sample_relations = [relations[idx] for idx in sample_indexes]
52 | sample_heads = [heads[idx] for idx in sample_indexes]
53 | sample_tails = [tails[idx] for idx in sample_indexes]
54 | sample_imgids = [imgids[idx] for idx in sample_indexes]
55 | sample_dataid = [dataid[idx] for idx in sample_indexes]
56 | assert len(sample_words) == len(sample_relations) == len(sample_imgids), "{}, {}, {}".format(len(sample_words), len(sample_relations), len(sample_imgids))
57 | return {'words':sample_words, 'relations':sample_relations, 'heads':sample_heads, 'tails':sample_tails, \
58 | 'imgids':sample_imgids, 'dataid': sample_dataid, 'aux_imgs':aux_imgs}
59 |
60 | return {'words':words, 'relations':relations, 'heads':heads, 'tails':tails, 'imgids': imgids, 'dataid': dataid, 'aux_imgs':aux_imgs}
61 |
62 |
63 | def get_relation_dict(self):
64 | with open(self.re_path, 'r', encoding="utf-8") as f:
65 | line = f.readlines()[0]
66 | re_dict = json.loads(line)
67 | return re_dict
68 |
69 | class MMPNERProcessor(object):
70 | def __init__(self, data_path, bert_name) -> None:
71 | self.data_path = data_path
72 | self.tokenizer = BertTokenizer.from_pretrained(bert_name, do_lower_case=True)
73 |
74 | def load_from_file(self, mode="train", sample_ratio=1.0):
75 | """
76 | Args:
77 | mode (str, optional): dataset mode. Defaults to "train".
78 |             sample_ratio (float, optional): sampling ratio for low-resource settings. Defaults to 1.0.
79 | """
80 | load_file = self.data_path[mode]
81 | logger.info("Loading data from {}".format(load_file))
82 | with open(load_file, "r", encoding="utf-8") as f:
83 | lines = f.readlines()
84 | raw_words, raw_targets = [], []
85 | raw_word, raw_target = [], []
86 | imgs = []
87 | for line in lines:
88 | if line.startswith("IMGID:"):
89 | img_id = line.strip().split('IMGID:')[1] + '.jpg'
90 | imgs.append(img_id)
91 | continue
92 | if line != "\n":
93 | raw_word.append(line.split('\t')[0])
94 | label = line.split('\t')[1][:-1]
95 | if 'OTHER' in label:
96 | label = label[:2] + 'MISC'
97 | raw_target.append(label)
98 | else:
99 | raw_words.append(raw_word)
100 | raw_targets.append(raw_target)
101 | raw_word, raw_target = [], []
102 |
103 | assert len(raw_words) == len(raw_targets) == len(imgs), "{}, {}, {}".format(len(raw_words), len(raw_targets), len(imgs))
104 | # load aux image
105 | aux_path = self.data_path[mode+"_auximgs"]
106 | aux_imgs = torch.load(aux_path)
107 |
108 | # sample data, only for low-resource
109 | if sample_ratio != 1.0:
110 | sample_indexes = random.choices(list(range(len(raw_words))), k=int(len(raw_words)*sample_ratio))
111 | sample_raw_words = [raw_words[idx] for idx in sample_indexes]
112 | sample_raw_targets = [raw_targets[idx] for idx in sample_indexes]
113 | sample_imgs = [imgs[idx] for idx in sample_indexes]
114 | assert len(sample_raw_words) == len(sample_raw_targets) == len(sample_imgs), "{}, {}, {}".format(len(sample_raw_words), len(sample_raw_targets), len(sample_imgs))
115 | return {"words": sample_raw_words, "targets": sample_raw_targets, "imgs": sample_imgs, "aux_imgs":aux_imgs}
116 |
117 | return {"words": raw_words, "targets": raw_targets, "imgs": imgs, "aux_imgs":aux_imgs}
118 |
119 | def get_label_mapping(self):
120 | LABEL_LIST = ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]
121 | label_mapping = {label:idx for idx, label in enumerate(LABEL_LIST, 1)}
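122 |         # label ids start at 1; 0 is reserved for the padding label below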
122 | label_mapping["PAD"] = 0
123 | return label_mapping
124 |
125 | class MMREDataset(Dataset):
126 | def __init__(self, processor, transform, img_path=None, aux_img_path=None, max_seq=40, sample_ratio=1.0, mode="train") -> None:
127 | self.processor = processor
128 | self.transform = transform
129 | self.max_seq = max_seq
130 | self.img_path = img_path[mode] if img_path is not None else img_path
131 | self.aux_img_path = aux_img_path[mode] if aux_img_path is not None else aux_img_path
132 | self.mode = mode
133 |
134 | self.data_dict = self.processor.load_from_file(mode, sample_ratio)
135 | self.re_dict = self.processor.get_relation_dict()
136 | self.tokenizer = self.processor.tokenizer
137 |
138 | def __len__(self):
139 | return len(self.data_dict['words'])
140 |
141 | def __getitem__(self, idx):
142 | word_list, relation, head_d, tail_d, imgid = self.data_dict['words'][idx], self.data_dict['relations'][idx], self.data_dict['heads'][idx], self.data_dict['tails'][idx], self.data_dict['imgids'][idx]
143 | item_id = self.data_dict['dataid'][idx]
144 |         # [CLS] ... <s> head </s> ... <o> tail </o> ... [SEP]
145 | head_pos, tail_pos = head_d['pos'], tail_d['pos']
146 | # insert
147 | extend_word_list = []
148 | for i in range(len(word_list)):
149 |             if i == head_pos[0]:
150 |                 extend_word_list.append('<s>')
151 |             if i == head_pos[1]:
152 |                 extend_word_list.append('</s>')
153 |             if i == tail_pos[0]:
154 |                 extend_word_list.append('<o>')
155 |             if i == tail_pos[1]:
156 |                 extend_word_list.append('</o>')
157 | extend_word_list.append(word_list[i])
158 | extend_word_list = " ".join(extend_word_list)
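159 |         # the <s>/</s>/<o>/</o> markers were registered as special tokens, so the tokenizer keeps them intact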
159 | encode_dict = self.tokenizer.encode_plus(text=extend_word_list, max_length=self.max_seq, truncation=True, padding='max_length')
160 | input_ids, token_type_ids, attention_mask = encode_dict['input_ids'], encode_dict['token_type_ids'], encode_dict['attention_mask']
161 | input_ids, token_type_ids, attention_mask = torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask)
162 |
163 | re_label = self.re_dict[relation] # label to id
164 |
165 | # image process
166 | if self.img_path is not None:
167 | try:
168 | img_path = os.path.join(self.img_path, imgid)
169 | image = Image.open(img_path).convert('RGB')
170 | image = self.transform(image)
171 |             except Exception:  # fall back to the placeholder image if loading fails
172 | img_path = os.path.join(self.img_path, 'inf.png')
173 | image = Image.open(img_path).convert('RGB')
174 | image = self.transform(image)
175 | if self.aux_img_path is not None:
176 | # process aux image
177 | aux_imgs = []
178 | aux_img_paths = []
179 | imgid = imgid.split(".")[0]
180 | if item_id in self.data_dict['aux_imgs']:
181 | aux_img_paths = self.data_dict['aux_imgs'][item_id]
182 | aux_img_paths = [os.path.join(self.aux_img_path, path) for path in aux_img_paths]
183 |                 # keep at most 3 aux images; discard the rest
184 | for i in range(min(3, len(aux_img_paths))):
185 | aux_img = Image.open(aux_img_paths[i]).convert('RGB')
186 | aux_img = self.transform(aux_img)
187 | aux_imgs.append(aux_img)
188 |
189 |                 # pad with zero tensors if fewer than 3 aux images
190 | for i in range(3-len(aux_img_paths)):
191 | aux_imgs.append(torch.zeros((3, 224, 224)))
192 |
193 | aux_imgs = torch.stack(aux_imgs, dim=0)
194 | assert len(aux_imgs) == 3
195 | return input_ids, token_type_ids, attention_mask, torch.tensor(re_label), image, aux_imgs
196 | return input_ids, token_type_ids, attention_mask, torch.tensor(re_label)
197 |
198 | class MMPNERDataset(Dataset):
199 | def __init__(self, processor, transform, img_path=None, aux_img_path=None, max_seq=40, sample_ratio=1, mode='train', ignore_idx=0) -> None:
200 | self.processor = processor
201 | self.transform = transform
202 | self.data_dict = processor.load_from_file(mode, sample_ratio)
203 | self.tokenizer = processor.tokenizer
204 | self.label_mapping = processor.get_label_mapping()
205 | self.max_seq = max_seq
206 | self.ignore_idx = ignore_idx
207 | self.img_path = img_path
208 | self.aux_img_path = aux_img_path[mode] if aux_img_path is not None else None
209 | self.mode = mode
210 | self.sample_ratio = sample_ratio
211 |
212 | def __len__(self):
213 | return len(self.data_dict['words'])
214 |
215 | def __getitem__(self, idx):
216 | word_list, label_list, img = self.data_dict['words'][idx], self.data_dict['targets'][idx], self.data_dict['imgs'][idx]
217 | tokens, labels = [], []
218 | for i, word in enumerate(word_list):
219 | token = self.tokenizer.tokenize(word)
220 | tokens.extend(token)
221 | label = label_list[i]
222 | for m in range(len(token)):
223 | if m == 0:
224 | labels.append(self.label_mapping[label])
225 | else:
226 | labels.append(self.label_mapping["X"])
227 | if len(tokens) >= self.max_seq - 1:
228 | tokens = tokens[0:(self.max_seq - 2)]
229 | labels = labels[0:(self.max_seq - 2)]
230 |
231 | encode_dict = self.tokenizer.encode_plus(tokens, max_length=self.max_seq, truncation=True, padding='max_length')
232 | input_ids, token_type_ids, attention_mask = encode_dict['input_ids'], encode_dict['token_type_ids'], encode_dict['attention_mask']
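233 |         # align labels with the WordPiece sequence: [CLS] + labels + [SEP], padded with ignore_idx to max_seq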
233 | labels = [self.label_mapping["[CLS]"]] + labels + [self.label_mapping["[SEP]"]] + [self.ignore_idx]*(self.max_seq-len(labels)-2)
234 |
235 | if self.img_path is not None:
236 | # image process
237 | try:
238 | img_path = os.path.join(self.img_path, img)
239 | image = Image.open(img_path).convert('RGB')
240 | image = self.transform(image)
241 |             except Exception:  # fall back to the placeholder image if loading fails
242 | img_path = os.path.join(self.img_path, 'inf.png')
243 | image = Image.open(img_path).convert('RGB')
244 | image = self.transform(image)
245 |
246 | if self.aux_img_path is not None:
247 | aux_imgs = []
248 | aux_img_paths = []
249 | if img in self.data_dict['aux_imgs']:
250 | aux_img_paths = self.data_dict['aux_imgs'][img]
251 | aux_img_paths = [os.path.join(self.aux_img_path, path) for path in aux_img_paths]
252 | for i in range(min(3, len(aux_img_paths))):
253 | aux_img = Image.open(aux_img_paths[i]).convert('RGB')
254 | aux_img = self.transform(aux_img)
255 | aux_imgs.append(aux_img)
256 |
257 | for i in range(3-len(aux_img_paths)):
258 | aux_imgs.append(torch.zeros((3, 224, 224)))
259 |
260 | aux_imgs = torch.stack(aux_imgs, dim=0)
261 | assert len(aux_imgs) == 3
262 | return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), torch.tensor(labels), image, aux_imgs
263 |
264 | assert len(input_ids) == len(token_type_ids) == len(attention_mask) == len(labels)
265 | return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), torch.tensor(labels)
266 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.11.3
2 | tensorboardX==2.4
3 | TorchCRF==1.1.0
4 | wandb==0.12.1
5 | torchvision==0.8.2
6 | torch==1.7.1
7 | scikit-learn==1.0
8 | seqeval==1.2.2
--------------------------------------------------------------------------------
/resource/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/HVPNeT/410bd0c54325376cd308bd5ee8558d61ee3d9553/resource/model.png
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | import sys
5 | sys.path.append("..")
6 |
7 | import torch
8 | import numpy as np
9 | import random
10 | from torchvision import transforms
11 | from torch.utils.data import DataLoader
12 | from models.bert_model import HMNeTREModel, HMNeTNERModel
13 | from processor.dataset import MMREProcessor, MMPNERProcessor, MMREDataset, MMPNERDataset
14 | from modules.train import RETrainer, NERTrainer
15 |
16 | import warnings
17 | warnings.filterwarnings("ignore", category=UserWarning)
18 | # from tensorboardX import SummaryWriter
19 |
20 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
21 | datefmt = '%m/%d/%Y %H:%M:%S',
22 | level = logging.INFO)
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | MODEL_CLASSES = {
27 | 'MRE': HMNeTREModel,
28 | 'twitter15': HMNeTNERModel,
29 | 'twitter17': HMNeTNERModel
30 | }
31 |
32 | TRAINER_CLASSES = {
33 | 'MRE': RETrainer,
34 | 'twitter15': NERTrainer,
35 | 'twitter17': NERTrainer
36 | }
37 | DATA_PROCESS = {
38 | 'MRE': (MMREProcessor, MMREDataset),
39 | 'twitter15': (MMPNERProcessor, MMPNERDataset),
40 | 'twitter17': (MMPNERProcessor, MMPNERDataset)
41 | }
42 |
43 | DATA_PATH = {
44 | 'MRE': {
45 | # text data
46 | 'train': 'data/RE_data/txt/ours_train.txt',
47 | 'dev': 'data/RE_data/txt/ours_val.txt',
48 | 'test': 'data/RE_data/txt/ours_test.txt',
49 | # {data_id : object_crop_img_path}
50 | 'train_auximgs': 'data/RE_data/txt/mre_train_dict.pth',
51 | 'dev_auximgs': 'data/RE_data/txt/mre_dev_dict.pth',
52 | 'test_auximgs': 'data/RE_data/txt/mre_test_dict.pth',
53 | # relation json data
54 | 're_path': 'data/RE_data/ours_rel2id.json'
55 | },
56 |
57 | 'twitter15': {
58 | # text data
59 | 'train': 'data/NER_data/twitter2015/train.txt',
60 | 'dev': 'data/NER_data/twitter2015/valid.txt',
61 | 'test': 'data/NER_data/twitter2015/test.txt',
62 | # {data_id : object_crop_img_path}
63 | 'train_auximgs': 'data/NER_data/twitter2015/twitter2015_train_dict.pth',
64 | 'dev_auximgs': 'data/NER_data/twitter2015/twitter2015_val_dict.pth',
65 | 'test_auximgs': 'data/NER_data/twitter2015/twitter2015_test_dict.pth'
66 | },
67 |
68 | 'twitter17': {
69 | # text data
70 | 'train': 'data/NER_data/twitter2017/train.txt',
71 | 'dev': 'data/NER_data/twitter2017/valid.txt',
72 | 'test': 'data/NER_data/twitter2017/test.txt',
73 | # {data_id : object_crop_img_path}
74 | 'train_auximgs': 'data/NER_data/twitter2017/twitter2017_train_dict.pth',
75 | 'dev_auximgs': 'data/NER_data/twitter2017/twitter2017_val_dict.pth',
76 | 'test_auximgs': 'data/NER_data/twitter2017/twitter2017_test_dict.pth'
77 | },
78 |
79 | }
80 |
81 | # image data
82 | IMG_PATH = {
83 | 'MRE': {'train': 'data/RE_data/img_org/train/',
84 | 'dev': 'data/RE_data/img_org/val/',
85 | 'test': 'data/RE_data/img_org/test'},
86 | 'twitter15': 'data/NER_data/twitter2015_images',
87 | 'twitter17': 'data/NER_data/twitter2017_images',
88 | }
89 |
90 | # auxiliary images
91 | AUX_PATH = {
92 | 'MRE':{
93 | 'train': 'data/RE_data/img_vg/train/crops',
94 | 'dev': 'data/RE_data/img_vg/val/crops',
95 | 'test': 'data/RE_data/img_vg/test/crops'
96 | },
97 | 'twitter15': {
98 | 'train': 'data/NER_data/twitter2015_aux_images/train/crops',
99 | 'dev': 'data/NER_data/twitter2015_aux_images/val/crops',
100 | 'test': 'data/NER_data/twitter2015_aux_images/test/crops',
101 | },
102 |
103 | 'twitter17': {
104 | 'train': 'data/NER_data/twitter2017_aux_images/train/crops',
105 | 'dev': 'data/NER_data/twitter2017_aux_images/val/crops',
106 | 'test': 'data/NER_data/twitter2017_aux_images/test/crops',
107 | }
108 | }
109 |
110 | def set_seed(seed=2021):
111 | """set random seed"""
112 | torch.manual_seed(seed)
113 | torch.cuda.manual_seed_all(seed)
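114 |     # deterministic cuDNN kernels trade some speed for reproducibility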
114 | torch.backends.cudnn.deterministic = True
115 | np.random.seed(seed)
116 | random.seed(seed)
117 |
118 | def main():
119 | parser = argparse.ArgumentParser()
120 |     parser.add_argument('--dataset_name', default='twitter15', type=str, help="The name of the dataset.")
121 | parser.add_argument('--bert_name', default='bert-base-uncased', type=str, help="Pretrained language model path")
122 | parser.add_argument('--num_epochs', default=30, type=int, help="num training epochs")
123 | parser.add_argument('--device', default='cuda', type=str, help="cuda or cpu")
124 | parser.add_argument('--batch_size', default=32, type=int, help="batch size")
125 | parser.add_argument('--lr', default=1e-5, type=float, help="learning rate")
126 | parser.add_argument('--warmup_ratio', default=0.01, type=float)
127 |     parser.add_argument('--eval_begin_epoch', default=16, type=int, help="epoch at which to start evaluating")
128 | parser.add_argument('--seed', default=1, type=int, help="random seed, default is 1")
129 | parser.add_argument('--prompt_len', default=10, type=int, help="prompt length")
130 | parser.add_argument('--prompt_dim', default=800, type=int, help="mid dimension of prompt project layer")
131 | parser.add_argument('--load_path', default=None, type=str, help="Load model from load_path")
132 | parser.add_argument('--save_path', default=None, type=str, help="save model at save_path")
133 |     parser.add_argument('--write_path', default=None, type=str, help="when testing, predictions will be written to write_path")
134 |     parser.add_argument('--notes', default="", type=str, help="notes appended to the save path / log directory name.")
135 | parser.add_argument('--use_prompt', action='store_true')
136 | parser.add_argument('--do_train', action='store_true')
137 | parser.add_argument('--only_test', action='store_true')
138 | parser.add_argument('--max_seq', default=128, type=int)
139 | parser.add_argument('--ignore_idx', default=-100, type=int)
140 | parser.add_argument('--sample_ratio', default=1.0, type=float, help="only for low resource.")
141 |
142 | args = parser.parse_args()
143 |
144 | data_path, img_path, aux_path = DATA_PATH[args.dataset_name], IMG_PATH[args.dataset_name], AUX_PATH[args.dataset_name]
145 | model_class, Trainer = MODEL_CLASSES[args.dataset_name], TRAINER_CLASSES[args.dataset_name]
146 | data_process, dataset_class = DATA_PROCESS[args.dataset_name]
147 |
148 | transform = transforms.Compose([
149 | transforms.Resize(256),
150 | transforms.CenterCrop(224),
151 | transforms.ToTensor(),
152 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
153 | std=[0.229, 0.224, 0.225])])
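154 |     # ImageNet mean/std, matching the pretrained ResNet backbone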
154 |
155 | set_seed(args.seed) # set seed, default is 1
156 | if args.save_path is not None: # make save_path dir
157 | # args.save_path = os.path.join(args.save_path, args.dataset_name+"_"+str(args.batch_size)+"_"+str(args.lr)+"_"+args.notes)
158 | if not os.path.exists(args.save_path):
159 | os.makedirs(args.save_path, exist_ok=True)
160 | print(args)
161 | logdir = "logs/" + args.dataset_name+ "_"+str(args.batch_size) + "_" + str(args.lr) + args.notes
162 | # writer = SummaryWriter(logdir=logdir)
163 |     writer = None
164 |
165 | if not args.use_prompt:
166 | img_path, aux_path = None, None
167 |
168 | processor = data_process(data_path, args.bert_name)
169 | train_dataset = dataset_class(processor, transform, img_path, aux_path, args.max_seq, sample_ratio=args.sample_ratio, mode='train')
170 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
171 |
172 | dev_dataset = dataset_class(processor, transform, img_path, aux_path, args.max_seq, mode='dev')
173 | dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
174 |
175 | test_dataset = dataset_class(processor, transform, img_path, aux_path, args.max_seq, mode='test')
176 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
177 |
178 | if args.dataset_name == 'MRE': # RE task
179 | re_dict = processor.get_relation_dict()
180 | num_labels = len(re_dict)
181 | tokenizer = processor.tokenizer
182 | model = HMNeTREModel(num_labels, tokenizer, args=args)
183 |
184 | trainer = Trainer(train_data=train_dataloader, dev_data=dev_dataloader, test_data=test_dataloader, model=model, processor=processor, args=args, logger=logger, writer=writer)
185 | else: # NER task
186 | label_mapping = processor.get_label_mapping()
187 | label_list = list(label_mapping.keys())
188 | model = HMNeTNERModel(label_list, args)
189 |
190 | trainer = Trainer(train_data=train_dataloader, dev_data=dev_dataloader, test_data=test_dataloader, model=model, label_map=label_mapping, args=args, logger=logger, writer=writer)
191 |
192 | if args.do_train:
193 | # train
194 | trainer.train()
195 | # test best model
196 | args.load_path = os.path.join(args.save_path, 'best_model.pth')
197 | trainer.test()
198 |
199 | if args.only_test:
200 | # only do test
201 | trainer.test()
202 |
203 | torch.cuda.empty_cache()
204 | # writer.close()
205 |
206 |
207 | if __name__ == "__main__":
208 | main()
--------------------------------------------------------------------------------
/run_re_task.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | DATASET_NAME="MRE"
4 | BERT_NAME="bert-base-uncased"
5 |
6 | CUDA_VISIBLE_DEVICES=2 python -u run.py \
7 | --dataset_name=${DATASET_NAME} \
8 | --bert_name=${BERT_NAME} \
9 | --num_epochs=15 \
10 | --batch_size=16 \
11 | --lr=3e-5 \
12 | --warmup_ratio=0.06 \
13 | --eval_begin_epoch=1 \
14 | --seed=1234 \
15 | --do_train \
16 | --max_seq=80 \
17 | --use_prompt \
18 | --prompt_len=4 \
19 | --sample_ratio=1.0 \
20 | --save_path='ckpt/re/'
--------------------------------------------------------------------------------
/run_twitter15.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Key variables and arguments:
3 | #   batch_size: batch size (recommendation: 8 / 16)
4 | #   lr: learning rate (recommendation: 3e-5 / 5e-5)
5 | #   seed: random seed, default is 1234
6 | #   BERT_NAME: pre-trained text model name (bert-*)
7 | #   max_seq: max sequence length
8 | #   sample_ratio: sampling ratio for low-resource (few-shot) runs, default is 1.0
9 | #   save_path: path where the best model checkpoint is saved
10 |
11 | DATASET_NAME="twitter15"
12 | BERT_NAME="bert-base-uncased"
13 |
14 | lr=3e-5
15 |
16 | CUDA_VISIBLE_DEVICES=0 python -u run.py \
17 | --dataset_name=${DATASET_NAME} \
18 | --bert_name=${BERT_NAME} \
19 | --num_epochs=30 \
20 | --batch_size=8 \
21 | --lr=$lr \
22 | --warmup_ratio=0.01 \
23 | --eval_begin_epoch=3 \
24 | --seed=1234 \
25 | --do_train \
26 | --ignore_idx=0 \
27 | --max_seq=80 \
28 | --use_prompt \
29 | --prompt_len=4 \
30 | --sample_ratio=1.0 \
31 | --save_path=your_ckpt_path
32 |
--------------------------------------------------------------------------------
/run_twitter17.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Key variables and arguments:
3 | #   batch_size: batch size (recommendation: 8 / 16)
4 | #   lr: learning rate (recommendation: 3e-5 / 5e-5)
5 | #   seed: random seed, default is 1234
6 | #   BERT_NAME: pre-trained text model name (bert-*)
7 | #   max_seq: max sequence length
8 | #   sample_ratio: sampling ratio for low-resource (few-shot) runs, default is 1.0
9 | #   save_path: path where the best model checkpoint is saved
10 |
11 | DATASET_NAME="twitter17"
12 | BERT_NAME="bert-base-uncased"
13 |
14 | CUDA_VISIBLE_DEVICES=0 python -u run.py \
15 | --dataset_name=${DATASET_NAME} \
16 | --bert_name=${BERT_NAME} \
17 | --num_epochs=30 \
18 | --batch_size=8 \
19 | --lr=3e-5 \
20 | --warmup_ratio=0.01 \
21 | --eval_begin_epoch=3 \
22 | --seed=1234 \
23 | --do_train \
24 | --ignore_idx=0 \
25 | --max_seq=128 \
26 | --use_prompt \
27 | --prompt_len=4 \
28 | --sample_ratio=1.0 \
29 | --save_path=your_ckpt_path
--------------------------------------------------------------------------------