├── .gitignore ├── LICENSE ├── README.md ├── models ├── __init__.py ├── bert_model.py └── modeling_bert.py ├── modules ├── metrics.py └── train.py ├── processor └── dataset.py ├── requirements.txt ├── resource └── model.png ├── run.py ├── run_re_task.sh ├── run_twitter15.sh └── run_twitter17.sh /.gitignore: -------------------------------------------------------------------------------- 1 | ckpt/ 2 | 3 | data/ 4 | 5 | logs/ 6 | 7 | __pycache__/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ZJUNLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HVPNet 2 | 3 | Code for the NAACL2022 (Findings) paper "[Good Visual Guidance Makes A Better Extractor: Hierarchical Visual Prefix for Multimodal Entity and Relation Extraction](https://arxiv.org/pdf/2205.03521.pdf)". 4 | 5 | Model Architecture 6 | ========== 7 |
8 | <div align=center> <img src="resource/model.png" alt="HVPNet model architecture" /> </div> 9 | 
10 | The overall architecture of our hierarchical modality fusion network. 11 | 12 | 13 | Requirements 14 | ========== 15 | To run the code, you need to install the requirements: 16 | ``` 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | Data Preprocessing 21 | ========== 22 | To extract visual object images, we first use the NLTK parser to extract noun phrases from the text and then apply the [visual grounding toolkit](https://github.com/zyang-ur/onestage_grounding) to detect objects. The detailed steps are as follows: 23 | 24 | 1. Use the NLTK parser (or spaCy, TextBlob) to extract noun phrases from the text. 25 | 2. Apply the [visual grounding toolkit](https://github.com/zyang-ur/onestage_grounding) to detect objects. Taking the twitter2015 dataset as an example, the extracted objects are stored in `twitter2015_aux_images`. Object images follow the naming format `imgname_pred_yolo_crop_num.png`, where `imgname` is the name of the raw image the object was cropped from and `num` is the index of the object predicted by the toolkit. (Note that in `train/val/test.txt`, text and raw image have a one-to-one relationship, so `imgname` can be used as a unique identifier for the raw images.) 26 | 3. Establish the correspondence between the raw images and the objects. We construct a dictionary to record this correspondence. Taking `twitter2015/twitter2015_train_dict.pth` as an example, the dictionary has the format `{imgname: ['imgname_pred_yolo_crop_num0.png', 'imgname_pred_yolo_crop_num1.png', ...]}`, where each key is the name of a raw image and each value is a list of its object images. (A minimal loading sketch is given just before the citation section below.) 27 | 28 | The detected objects and the dictionaries recording the correspondence between raw images and objects are available via our data links. 29 | 30 | Data Download 31 | ========== 32 | 33 | + Twitter2015 & Twitter2017 34 | 35 | The text data follows the CoNLL format. You can download the Twitter2015 data via this [link](https://drive.google.com/file/d/1qAWrV9IaiBadICFb7mAreXy3llao_teZ/view?usp=sharing) and the Twitter2017 data via this [link](https://drive.google.com/file/d/1ogfbn-XEYtk9GpUECq1-IwzINnhKGJqy/view?usp=sharing). Please place them in `data/NER_data`. 36 | 37 | You can also put them anywhere else and modify the path configuration in `run.py`. 38 | 39 | + MNRE 40 | 41 | The MNRE dataset comes from [MEGA](https://github.com/thecharm/MNRE), many thanks. 42 | 43 | You can download the MRE dataset with detected visual objects from [Google Drive](https://drive.google.com/file/d/1q5_5vnHJ8Hik1iLA9f5-6nstcvvntLrS/view?usp=sharing) or use the following commands: 44 | ```bash 45 | cd data 46 | wget 120.27.214.45/Data/re/multimodal/data.tar.gz 47 | tar -xzvf data.tar.gz 48 | mv data RE_data 49 | ``` 50 | 51 | The expected structure of files is: 52 | 53 | ``` 54 | HMNeT 55 | |-- data 56 | | |-- NER_data 57 | | | |-- twitter2015 # text data 58 | | | | |-- train.txt 59 | | | | |-- valid.txt 60 | | | | |-- test.txt 61 | | | | |-- twitter2015_train_dict.pth # {imgname: [object-image]} 62 | | | | |-- ... 
63 | | | |-- twitter2015_images # raw image data 64 | | | |-- twitter2015_aux_images # object image data 65 | | | |-- twitter2017 66 | | | |-- twitter2017_images 67 | | | |-- twitter2017_aux_images 68 | | |-- RE_data 69 | | | |-- img_org # raw image data 70 | | | |-- img_vg # object image data 71 | | | |-- txt # text data 72 | | | |-- ours_rel2id.json # relation data 73 | |-- models # models 74 | | |-- bert_model.py 75 | | |-- modeling_bert.py 76 | |-- modules 77 | | |-- metrics.py # metric 78 | | |-- train.py # trainer 79 | |-- processor 80 | | |-- dataset.py # processor, dataset 81 | |-- logs # code logs 82 | |-- run.py # main 83 | |-- run_twitter15.sh # NER training scripts (likewise run_twitter17.sh) 84 | |-- run_re_task.sh 85 | ``` 86 | 87 | Train 88 | ========== 89 | 90 | ## NER Task 91 | 92 | The data path and GPU-related configuration are in `run.py`. To train the NER model, run the corresponding script: 93 | 94 | ```shell 95 | bash run_twitter15.sh 96 | bash run_twitter17.sh 97 | ``` 98 | 99 | ## RE Task 100 | 101 | To train the RE model, run this script: 102 | 103 | ```shell 104 | bash run_re_task.sh 105 | ``` 106 | 107 | Test 108 | ========== 109 | ## NER Task 110 | 111 | To test the NER model, set `load_path` to the path of a trained checkpoint, then run the following script: 112 | 113 | ```shell 114 | python -u run.py \ 115 | --dataset_name="twitter15/twitter17" \ 116 | --bert_name="bert-base-uncased" \ 117 | --seed=1234 \ 118 | --only_test \ 119 | --max_seq=80 \ 120 | --use_prompt \ 121 | --prompt_len=4 \ 122 | --sample_ratio=1.0 \ 123 | --load_path='your_ner_ckpt_path' 124 | 125 | ``` 126 | 127 | ## RE Task 128 | 129 | To test the RE model, set `load_path` to the path of a trained checkpoint, then run the following script: 130 | 131 | ```shell 132 | python -u run.py \ 133 | --dataset_name="MRE" \ 134 | --bert_name="bert-base-uncased" \ 135 | --seed=1234 \ 136 | --only_test \ 137 | --max_seq=80 \ 138 | --use_prompt \ 139 | --prompt_len=4 \ 140 | --sample_ratio=1.0 \ 141 | --load_path='your_re_ckpt_path' 142 | 143 | ``` 144 | 145 | Acknowledgement 146 | ========== 147 | 148 | The acquisition of the Twitter15 and Twitter17 data refers to the code from [UMT](https://github.com/jefferyYu/UMT/), many thanks. 149 | 150 | The acquisition of the MNRE data for the multimodal relation extraction task refers to the code from [MEGA](https://github.com/thecharm/Mega), many thanks. 
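As a quick sanity check of the downloaded data, the correspondence dictionary described in the Data Preprocessing section can be inspected as follows. This is a minimal sketch, assuming PyTorch is installed and the data is placed under `data/NER_data` as shown above; the variable names are only illustrative.

```python
import torch

# {imgname: ['imgname_pred_yolo_crop_num0.png', ...]}: raw image name -> its object crops
img2objs = torch.load('data/NER_data/twitter2015/twitter2015_train_dict.pth')

imgname, objs = next(iter(img2objs.items()))
print(imgname)  # a raw image in twitter2015_images
print(objs)     # its detected object images in twitter2015_aux_images
```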
151 | 152 | Papers for the Project & How to Cite 153 | ========== 154 | 155 | 156 | If you use or extend our work, please cite the paper as follows: 157 | 158 | ```bibtex 159 | @inproceedings{DBLP:conf/naacl/ChenZLYDTHSC22, 160 | author = {Xiang Chen and 161 | Ningyu Zhang and 162 | Lei Li and 163 | Yunzhi Yao and 164 | Shumin Deng and 165 | Chuanqi Tan and 166 | Fei Huang and 167 | Luo Si and 168 | Huajun Chen}, 169 | editor = {Marine Carpuat and 170 | Marie{-}Catherine de Marneffe and 171 | Iv{\'{a}}n Vladimir Meza Ru{\'{\i}}z}, 172 | title = {Good Visual Guidance Make {A} Better Extractor: Hierarchical Visual 173 | Prefix for Multimodal Entity and Relation Extraction}, 174 | booktitle = {Findings of the Association for Computational Linguistics: {NAACL} 175 | 2022, Seattle, WA, United States, July 10-15, 2022}, 176 | pages = {1607--1618}, 177 | publisher = {Association for Computational Linguistics}, 178 | year = {2022}, 179 | url = {https://doi.org/10.18653/v1/2022.findings-naacl.121}, 180 | doi = {10.18653/v1/2022.findings-naacl.121}, 181 | timestamp = {Tue, 23 Aug 2022 08:36:33 +0200}, 182 | biburl = {https://dblp.org/rec/conf/naacl/ChenZLYDTHSC22.bib}, 183 | bibsource = {dblp computer science bibliography, https://dblp.org} 184 | } 185 | ``` 186 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert_model import * 2 | from .modeling_bert import * -------------------------------------------------------------------------------- /models/bert_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from torchcrf import CRF 6 | from .modeling_bert import BertModel 7 | from transformers.modeling_outputs import TokenClassifierOutput 8 | from torchvision.models import resnet50 9 | 10 | class ImageModel(nn.Module): 11 | def __init__(self): 12 | super(ImageModel, self).__init__() 13 | self.resnet = resnet50(pretrained=True) 14 | 15 | def forward(self, x, aux_imgs=None): 16 | # full image prompt 17 | prompt_guids = self.get_resnet_prompt(x) # 4 x [bsz, C, 2, 2] 18 | 19 | # aux_imgs: bsz x 3(nums) x 3 x 224 x 224 20 | if aux_imgs is not None: 21 | aux_prompt_guids = [] # goal: 3 x (4 x [bsz, C, 2, 2]) 22 | aux_imgs = aux_imgs.permute([1, 0, 2, 3, 4]) # 3(nums) x bsz x 3 x 224 x 224 23 | for i in range(len(aux_imgs)): 24 | aux_prompt_guid = self.get_resnet_prompt(aux_imgs[i]) # 4 x [bsz, C, 2, 2] 25 | aux_prompt_guids.append(aux_prompt_guid) 26 | return prompt_guids, aux_prompt_guids 27 | return prompt_guids, None 28 | 29 | def get_resnet_prompt(self, x): 30 | """generate image prompt 31 | 32 | Args: 33 | x (torch.Tensor): bsz x 3 x 224 x 224 34 | 35 | Returns: 36 | prompt_guids (List[torch.Tensor]): 4 x [bsz, C, 2, 2], where C is 256/512/1024/2048 for the four ResNet stages 37 | """ 38 | # image: bsz x 3 x 224 x 224 39 | prompt_guids = [] 40 | for name, layer in self.resnet.named_children(): 41 | if name == 'fc' or name == 'avgpool': continue 42 | x = layer(x) # e.g. after layer1: (bsz, 256, 56, 56) 43 | if 'layer' in name: 44 | bsz, channel, ft, _ = x.size() 45 | kernel = ft // 2 46 | prompt_kv = nn.AvgPool2d(kernel_size=(kernel, kernel), stride=kernel)(x) # (bsz, C, 2, 2) 47 | prompt_guids.append(prompt_kv) # one entry per ResNet stage 48 | return prompt_guids 49 | 50 | 51 | class HMNeTREModel(nn.Module): 52 | def __init__(self, num_labels, tokenizer, args): 53 | 
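# HMNeTREModel: relation extraction with a hierarchical visual prefix.
# ResNet features of the full image and the object crops are projected into
# per-layer (key, value) prompts that are fed to BERT as `past_key_values`;
# the relation is classified from the concatenated hidden states at the
# head/tail entity markers.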
super(HMNeTREModel, self).__init__() 54 | self.bert = BertModel.from_pretrained(args.bert_name) 55 | self.bert.resize_token_embeddings(len(tokenizer)) 56 | self.args = args 57 | 58 | self.dropout = nn.Dropout(0.5) 59 | self.classifier = nn.Linear(self.bert.config.hidden_size*2, num_labels) 60 | self.head_start = tokenizer.convert_tokens_to_ids("<s>") # head-entity start marker (assumes '<s>'/'<o>' were added to the tokenizer) 61 | self.tail_start = tokenizer.convert_tokens_to_ids("<o>") # tail-entity start marker 62 | self.tokenizer = tokenizer 63 | 64 | if self.args.use_prompt: 65 | self.image_model = ImageModel() 66 | 67 | self.encoder_conv = nn.Sequential( 68 | nn.Linear(in_features=3840, out_features=800), 69 | nn.Tanh(), 70 | nn.Linear(in_features=800, out_features=4*2*768) 71 | ) 72 | 73 | self.gates = nn.ModuleList([nn.Linear(4*768*2, 4) for i in range(12)]) 74 | 75 | def forward( 76 | self, 77 | input_ids=None, 78 | attention_mask=None, 79 | token_type_ids=None, 80 | labels=None, 81 | images=None, 82 | aux_imgs=None, 83 | ): 84 | 85 | bsz = input_ids.size(0) 86 | if self.args.use_prompt: 87 | prompt_guids = self.get_visual_prompt(images, aux_imgs) 88 | prompt_guids_length = prompt_guids[0][0].shape[2] 89 | prompt_guids_mask = torch.ones((bsz, prompt_guids_length)).to(self.args.device) 90 | prompt_attention_mask = torch.cat((prompt_guids_mask, attention_mask), dim=1) 91 | else: 92 | prompt_guids = None 93 | prompt_attention_mask = attention_mask 94 | 95 | output = self.bert( 96 | input_ids=input_ids, 97 | token_type_ids=token_type_ids, 98 | attention_mask=prompt_attention_mask, 99 | past_key_values=prompt_guids, 100 | output_attentions=True, 101 | return_dict=True 102 | ) 103 | 104 | last_hidden_state, pooler_output = output.last_hidden_state, output.pooler_output 105 | bsz, seq_len, hidden_size = last_hidden_state.shape 106 | entity_hidden_state = torch.Tensor(bsz, 2*hidden_size) # batch, 2*hidden 107 | for i in range(bsz): 108 | head_idx = input_ids[i].eq(self.head_start).nonzero().item() 109 | tail_idx = input_ids[i].eq(self.tail_start).nonzero().item() 110 | head_hidden = last_hidden_state[i, head_idx, :].squeeze() 111 | tail_hidden = last_hidden_state[i, tail_idx, :].squeeze() 112 | entity_hidden_state[i] = torch.cat([head_hidden, tail_hidden], dim=-1) 113 | entity_hidden_state = entity_hidden_state.to(self.args.device) 114 | logits = self.classifier(entity_hidden_state) 115 | if labels is not None: 116 | loss_fn = nn.CrossEntropyLoss() 117 | return loss_fn(logits, labels.view(-1)), logits 118 | return logits 119 | 120 | def get_visual_prompt(self, images, aux_imgs): 121 | bsz = images.size(0) 122 | # full image prompt 123 | prompt_guids, aux_prompt_guids = self.image_model(images, aux_imgs) # [bsz, 256, 2, 2], [bsz, 512, 2, 2].... 
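# Shape note: each ResNet stage output is average-pooled to 2x2 and the four
# stages are concatenated along the channel dim (256+512+1024+2048 = 3840
# channels -> bsz x 3840 x 2 x 2); the 15360 values per sample are then
# reshaped into prompt_len=4 prompt tokens of 3840 features each and projected
# by encoder_conv into per-layer key/value prompts.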
124 | prompt_guids = torch.cat(prompt_guids, dim=1).view(bsz, self.args.prompt_len, -1) # bsz, 4, 3840 125 | 126 | # aux image prompts # 3 x (4 x [bsz, 256, 2, 2]) 127 | aux_prompt_guids = [torch.cat(aux_prompt_guid, dim=1).view(bsz, self.args.prompt_len, -1) for aux_prompt_guid in aux_prompt_guids] # 3 x [bsz, 4, 3840] 128 | 129 | prompt_guids = self.encoder_conv(prompt_guids) # bsz, 4, 4*2*768 130 | aux_prompt_guids = [self.encoder_conv(aux_prompt_guid) for aux_prompt_guid in aux_prompt_guids] # 3 x [bsz, 4, 4*2*768] 131 | split_prompt_guids = prompt_guids.split(768*2, dim=-1) # 4 x [bsz, 4, 768*2] 132 | split_aux_prompt_guids = [aux_prompt_guid.split(768*2, dim=-1) for aux_prompt_guid in aux_prompt_guids] # 3x [4 x [bsz, 4, 768*2]] 133 | 134 | sum_prompt_guids = torch.stack(split_prompt_guids).sum(0).view(bsz, -1) / 4 # bsz, 4, 768*2 135 | 136 | result = [] 137 | for idx in range(12): # 12 138 | prompt_gate = F.softmax(F.leaky_relu(self.gates[idx](sum_prompt_guids)), dim=-1) 139 | 140 | key_val = torch.zeros_like(split_prompt_guids[0]).to(self.args.device) # bsz, 4, 768*2 141 | for i in range(4): 142 | key_val = key_val + torch.einsum('bg,blh->blh', prompt_gate[:, i].view(-1, 1), split_prompt_guids[i]) 143 | 144 | # use gate mix aux image prompts 145 | aux_key_vals = [] # 3 x [bsz, 4, 768*2] 146 | for split_aux_prompt_guid in split_aux_prompt_guids: 147 | sum_aux_prompt_guids = torch.stack(split_aux_prompt_guid).sum(0).view(bsz, -1) / 4 # bsz, 4, 768*2 148 | aux_prompt_gate = F.softmax(F.leaky_relu(self.gates[idx](sum_aux_prompt_guids)), dim=-1) 149 | aux_key_val = torch.zeros_like(split_aux_prompt_guid[0]).to(self.args.device) # bsz, 4, 768*2 150 | for i in range(4): 151 | aux_key_val = aux_key_val + torch.einsum('bg,blh->blh', aux_prompt_gate[:, i].view(-1, 1), split_aux_prompt_guid[i]) 152 | aux_key_vals.append(aux_key_val) 153 | key_val = [key_val] + aux_key_vals 154 | key_val = torch.cat(key_val, dim=1) 155 | key_val = key_val.split(768, dim=-1) 156 | key, value = key_val[0].reshape(bsz, 12, -1, 64).contiguous(), key_val[1].reshape(bsz, 12, -1, 64).contiguous() # bsz, 12, 4, 64 157 | temp_dict = (key, value) 158 | result.append(temp_dict) 159 | return result 160 | 161 | 162 | class HMNeTNERModel(nn.Module): 163 | def __init__(self, label_list, args): 164 | super(HMNeTNERModel, self).__init__() 165 | self.args = args 166 | self.prompt_dim = args.prompt_dim 167 | self.prompt_len = args.prompt_len 168 | self.bert = BertModel.from_pretrained(args.bert_name) 169 | self.bert_config = self.bert.config 170 | 171 | if args.use_prompt: 172 | self.image_model = ImageModel() # bsz, 6, 56, 56 173 | self.encoder_conv = nn.Sequential( 174 | nn.Linear(in_features=3840, out_features=800), 175 | nn.Tanh(), 176 | nn.Linear(in_features=800, out_features=4*2*768) 177 | ) 178 | self.gates = nn.ModuleList([nn.Linear(4*768*2, 4) for i in range(12)]) 179 | 180 | self.num_labels = len(label_list) # pad 181 | print(self.num_labels) 182 | self.crf = CRF(self.num_labels, batch_first=True) 183 | self.fc = nn.Linear(self.bert.config.hidden_size, self.num_labels) 184 | self.dropout = nn.Dropout(0.1) 185 | 186 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, images=None, aux_imgs=None): 187 | if self.args.use_prompt: 188 | prompt_guids = self.get_visual_prompt(images, aux_imgs) 189 | prompt_guids_length = prompt_guids[0][0].shape[2] 190 | # attention_mask: bsz, seq_len 191 | # prompt attention, attention mask 192 | bsz = attention_mask.size(0) 193 | prompt_guids_mask = 
torch.ones((bsz, prompt_guids_length)).to(self.args.device) 194 | prompt_attention_mask = torch.cat((prompt_guids_mask, attention_mask), dim=1) 195 | else: 196 | prompt_attention_mask = attention_mask 197 | prompt_guids = None 198 | 199 | bert_output = self.bert(input_ids=input_ids, 200 | attention_mask=prompt_attention_mask, 201 | token_type_ids=token_type_ids, 202 | past_key_values=prompt_guids, 203 | return_dict=True) 204 | sequence_output = bert_output['last_hidden_state'] # bsz, len, hidden 205 | sequence_output = self.dropout(sequence_output) # bsz, len, hidden 206 | emissions = self.fc(sequence_output) # bsz, len, labels 207 | 208 | logits = self.crf.decode(emissions, attention_mask.byte()) 209 | loss = None 210 | if labels is not None: 211 | loss = -1 * self.crf(emissions, labels, mask=attention_mask.byte(), reduction='mean') 212 | return TokenClassifierOutput( 213 | loss=loss, 214 | logits=logits 215 | ) 216 | 217 | def get_visual_prompt(self, images, aux_imgs): 218 | bsz = images.size(0) 219 | prompt_guids, aux_prompt_guids = self.image_model(images, aux_imgs) # [bsz, 256, 2, 2], [bsz, 512, 2, 2].... 220 | 221 | prompt_guids = torch.cat(prompt_guids, dim=1).view(bsz, self.args.prompt_len, -1) # bsz, 4, 3840 222 | aux_prompt_guids = [torch.cat(aux_prompt_guid, dim=1).view(bsz, self.args.prompt_len, -1) for aux_prompt_guid in aux_prompt_guids] # 3 x [bsz, 4, 3840] 223 | 224 | prompt_guids = self.encoder_conv(prompt_guids) # bsz, 4, 4*2*768 225 | aux_prompt_guids = [self.encoder_conv(aux_prompt_guid) for aux_prompt_guid in aux_prompt_guids] # 3 x [bsz, 4, 4*2*768] 226 | split_prompt_guids = prompt_guids.split(768*2, dim=-1) # 4 x [bsz, 4, 768*2] 227 | split_aux_prompt_guids = [aux_prompt_guid.split(768*2, dim=-1) for aux_prompt_guid in aux_prompt_guids] # 3x [4 x [bsz, 4, 768*2]] 228 | 229 | result = [] 230 | for idx in range(12): # 12 231 | sum_prompt_guids = torch.stack(split_prompt_guids).sum(0).view(bsz, -1) / 4 # bsz, 4, 768*2 232 | prompt_gate = F.softmax(F.leaky_relu(self.gates[idx](sum_prompt_guids)), dim=-1) 233 | 234 | key_val = torch.zeros_like(split_prompt_guids[0]).to(self.args.device) # bsz, 4, 768*2 235 | for i in range(4): 236 | key_val = key_val + torch.einsum('bg,blh->blh', prompt_gate[:, i].view(-1, 1), split_prompt_guids[i]) 237 | 238 | aux_key_vals = [] # 3 x [bsz, 4, 768*2] 239 | for split_aux_prompt_guid in split_aux_prompt_guids: 240 | sum_aux_prompt_guids = torch.stack(split_aux_prompt_guid).sum(0).view(bsz, -1) / 4 # bsz, 4, 768*2 241 | aux_prompt_gate = F.softmax(F.leaky_relu(self.gates[idx](sum_aux_prompt_guids)), dim=-1) 242 | aux_key_val = torch.zeros_like(split_aux_prompt_guid[0]).to(self.args.device) # bsz, 4, 768*2 243 | for i in range(4): 244 | aux_key_val = aux_key_val + torch.einsum('bg,blh->blh', aux_prompt_gate[:, i].view(-1, 1), split_aux_prompt_guid[i]) 245 | aux_key_vals.append(aux_key_val) 246 | key_val = [key_val] + aux_key_vals 247 | key_val = torch.cat(key_val, dim=1) 248 | key_val = key_val.split(768, dim=-1) 249 | key, value = key_val[0].reshape(bsz, 12, -1, 64).contiguous(), key_val[1].reshape(bsz, 12, -1, 64).contiguous() # bsz, 12, 4, 64 250 | temp_dict = (key, value) 251 | result.append(temp_dict) 252 | return result 253 | -------------------------------------------------------------------------------- /models/modeling_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model. """ 17 | import math 18 | import os 19 | import warnings 20 | from dataclasses import dataclass 21 | from typing import Optional, Tuple 22 | 23 | import torch 24 | import torch.utils.checkpoint 25 | from packaging import version 26 | from torch import nn 27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 28 | 29 | from transformers.activations import ACT2FN 30 | from transformers.file_utils import ( 31 | ModelOutput, 32 | add_code_sample_docstrings, 33 | add_start_docstrings, 34 | add_start_docstrings_to_model_forward, 35 | replace_return_docstrings, 36 | ) 37 | from transformers.modeling_outputs import ( 38 | BaseModelOutputWithPastAndCrossAttentions, 39 | BaseModelOutputWithPoolingAndCrossAttentions, 40 | CausalLMOutputWithCrossAttentions, 41 | MaskedLMOutput, 42 | MultipleChoiceModelOutput, 43 | NextSentencePredictorOutput, 44 | QuestionAnsweringModelOutput, 45 | SequenceClassifierOutput, 46 | TokenClassifierOutput, 47 | ) 48 | from transformers.modeling_utils import ( 49 | PreTrainedModel, 50 | apply_chunking_to_forward, 51 | find_pruneable_heads_and_indices, 52 | prune_linear_layer, 53 | ) 54 | from transformers.utils import logging 55 | from transformers import BertConfig 56 | 57 | 58 | logger = logging.get_logger(__name__) 59 | 60 | _CHECKPOINT_FOR_DOC = "bert-base-uncased" 61 | _CONFIG_FOR_DOC = "BertConfig" 62 | _TOKENIZER_FOR_DOC = "BertTokenizer" 63 | 64 | BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ 65 | "bert-base-uncased", 66 | "bert-large-uncased", 67 | "bert-base-cased", 68 | "bert-large-cased", 69 | "bert-base-multilingual-uncased", 70 | "bert-base-multilingual-cased", 71 | "bert-base-chinese", 72 | "bert-base-german-cased", 73 | "bert-large-uncased-whole-word-masking", 74 | "bert-large-cased-whole-word-masking", 75 | "bert-large-uncased-whole-word-masking-finetuned-squad", 76 | "bert-large-cased-whole-word-masking-finetuned-squad", 77 | "bert-base-cased-finetuned-mrpc", 78 | "bert-base-german-dbmdz-cased", 79 | "bert-base-german-dbmdz-uncased", 80 | "cl-tohoku/bert-base-japanese", 81 | "cl-tohoku/bert-base-japanese-whole-word-masking", 82 | "cl-tohoku/bert-base-japanese-char", 83 | "cl-tohoku/bert-base-japanese-char-whole-word-masking", 84 | "TurkuNLP/bert-base-finnish-cased-v1", 85 | "TurkuNLP/bert-base-finnish-uncased-v1", 86 | "wietsedv/bert-base-dutch-cased", 87 | # See all BERT models at https://huggingface.co/models?filter=bert 88 | ] 89 | 90 | 91 | def load_tf_weights_in_bert(model, config, tf_checkpoint_path): 92 | """Load tf checkpoints in a pytorch model.""" 93 | try: 94 | import re 95 | 96 | import numpy as np 97 | import tensorflow as tf 98 | except ImportError: 99 | logger.error( 100 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 101 | "https://www.tensorflow.org/install/ for installation instructions." 
102 | ) 103 | raise 104 | tf_path = os.path.abspath(tf_checkpoint_path) 105 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 106 | # Load weights from TF model 107 | init_vars = tf.train.list_variables(tf_path) 108 | names = [] 109 | arrays = [] 110 | for name, shape in init_vars: 111 | logger.info(f"Loading TF weight {name} with shape {shape}") 112 | array = tf.train.load_variable(tf_path, name) 113 | names.append(name) 114 | arrays.append(array) 115 | 116 | for name, array in zip(names, arrays): 117 | name = name.split("/") 118 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 119 | # which are not required for using pretrained model 120 | if any( 121 | n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] 122 | for n in name 123 | ): 124 | logger.info(f"Skipping {'/'.join(name)}") 125 | continue 126 | pointer = model 127 | for m_name in name: 128 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name): 129 | scope_names = re.split(r"_(\d+)", m_name) 130 | else: 131 | scope_names = [m_name] 132 | if scope_names[0] == "kernel" or scope_names[0] == "gamma": 133 | pointer = getattr(pointer, "weight") 134 | elif scope_names[0] == "output_bias" or scope_names[0] == "beta": 135 | pointer = getattr(pointer, "bias") 136 | elif scope_names[0] == "output_weights": 137 | pointer = getattr(pointer, "weight") 138 | elif scope_names[0] == "squad": 139 | pointer = getattr(pointer, "classifier") 140 | else: 141 | try: 142 | pointer = getattr(pointer, scope_names[0]) 143 | except AttributeError: 144 | logger.info(f"Skipping {'/'.join(name)}") 145 | continue 146 | if len(scope_names) >= 2: 147 | num = int(scope_names[1]) 148 | pointer = pointer[num] 149 | if m_name[-11:] == "_embeddings": 150 | pointer = getattr(pointer, "weight") 151 | elif m_name == "kernel": 152 | array = np.transpose(array) 153 | try: 154 | assert ( 155 | pointer.shape == array.shape 156 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 157 | except AssertionError as e: 158 | e.args += (pointer.shape, array.shape) 159 | raise 160 | logger.info(f"Initialize PyTorch weight {name}") 161 | pointer.data = torch.from_numpy(array) 162 | return model 163 | 164 | 165 | class BertEmbeddings(nn.Module): 166 | """Construct the embeddings from word, position and token_type embeddings.""" 167 | 168 | def __init__(self, config): 169 | super().__init__() 170 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 171 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 172 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 173 | 174 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 175 | # any TensorFlow checkpoint file 176 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 177 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 178 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 179 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 180 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 181 | if version.parse(torch.__version__) > version.parse("1.6.0"): 182 | self.register_buffer( 183 | "token_type_ids", 184 | torch.zeros(self.position_ids.size(), dtype=torch.long, 
device=self.position_ids.device), 185 | persistent=False, 186 | ) 187 | 188 | def forward( 189 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 190 | ): 191 | if input_ids is not None: 192 | input_shape = input_ids.size() 193 | else: 194 | input_shape = inputs_embeds.size()[:-1] 195 | 196 | seq_length = input_shape[1] 197 | 198 | if position_ids is None: 199 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] 200 | 201 | # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs 202 | # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves 203 | # issue #5664 204 | if token_type_ids is None: 205 | if hasattr(self, "token_type_ids"): 206 | buffered_token_type_ids = self.token_type_ids[:, :seq_length] 207 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) 208 | token_type_ids = buffered_token_type_ids_expanded 209 | else: 210 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 211 | 212 | if inputs_embeds is None: 213 | inputs_embeds = self.word_embeddings(input_ids) 214 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 215 | 216 | embeddings = inputs_embeds + token_type_embeddings 217 | if self.position_embedding_type == "absolute": 218 | position_embeddings = self.position_embeddings(position_ids) 219 | embeddings += position_embeddings 220 | embeddings = self.LayerNorm(embeddings) 221 | embeddings = self.dropout(embeddings) 222 | return embeddings 223 | 224 | 225 | class BertSelfAttention(nn.Module): 226 | def __init__(self, config): 227 | super().__init__() 228 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 229 | raise ValueError( 230 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 231 | f"heads ({config.num_attention_heads})" 232 | ) 233 | 234 | self.num_attention_heads = config.num_attention_heads # 12 235 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 236 | self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 237 | 238 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 239 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 240 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 241 | 242 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 243 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 244 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 245 | self.max_position_embeddings = config.max_position_embeddings 246 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 247 | 248 | self.is_decoder = config.is_decoder 249 | 250 | def transpose_for_scores(self, x): 251 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 252 | x = x.view(*new_x_shape) 253 | return x.permute(0, 2, 1, 3) 254 | 255 | def forward( 256 | self, 257 | hidden_states, 258 | attention_mask=None, 259 | head_mask=None, 260 | encoder_hidden_states=None, 261 | encoder_attention_mask=None, 262 | past_key_value=None, 263 | output_attentions=False, 264 | ### add: 265 | current_layer=0 266 | ): 267 | 
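# Note on the visual prefix: bert_model.py passes the gated visual prompts as
# `past_key_value`, so the "elif past_key_value is not None" branch below
# prepends the prompt keys/values (bsz x num_heads x 16 x head_size, i.e.
# prompt_len=4 tokens for the full image plus 3 object images) to the text
# keys/values; every text token can then attend to the visual prefix.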
mixed_query_layer = self.query(hidden_states) 268 | 269 | # If this is instantiated as a cross-attention module, the keys 270 | # and values come from an encoder; the attention mask needs to be 271 | # such that the encoder's padding tokens are not attended to. 272 | is_cross_attention = encoder_hidden_states is not None 273 | 274 | if is_cross_attention and past_key_value is not None: 275 | # reuse k,v, cross_attentions 276 | key_layer = past_key_value[0] 277 | value_layer = past_key_value[1] 278 | attention_mask = encoder_attention_mask 279 | elif is_cross_attention: 280 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 281 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 282 | attention_mask = encoder_attention_mask 283 | elif past_key_value is not None: # go here 284 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 285 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 286 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 287 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 288 | else: 289 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 290 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 291 | 292 | query_layer = self.transpose_for_scores(mixed_query_layer) 293 | if self.is_decoder: 294 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 295 | # Further calls to cross_attention layer can then reuse all cross-attention 296 | # key/value_states (first "if" case) 297 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 298 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 299 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 300 | # if encoder bi-directional self-attention `past_key_value` is always `None` 301 | past_key_value = (key_layer, value_layer) 302 | 303 | # Take the dot product between "query" and "key" to get the raw attention scores. 304 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 305 | 306 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 307 | seq_length = hidden_states.size()[1] 308 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 309 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 310 | distance = position_ids_l - position_ids_r 311 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 312 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 313 | 314 | if self.position_embedding_type == "relative_key": 315 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 316 | attention_scores = attention_scores + relative_position_scores 317 | elif self.position_embedding_type == "relative_key_query": 318 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 319 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 320 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 321 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 322 | if attention_mask is not None: 323 | # 
Apply the attention mask is (precomputed for all layers in BertModel forward() function) 324 | attention_scores = attention_scores + attention_mask 325 | # Normalize the attention scores to probabilities. 326 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 327 | # This is actually dropping out entire tokens to attend to, which might 328 | # seem a bit unusual, but is taken from the original Transformer paper. 329 | attention_probs = self.dropout(attention_probs) 330 | 331 | # Mask heads if we want to 332 | if head_mask is not None: 333 | attention_probs = attention_probs * head_mask 334 | context_layer = torch.matmul(attention_probs, value_layer) 335 | 336 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 337 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 338 | context_layer = context_layer.view(*new_context_layer_shape) 339 | 340 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 341 | if self.is_decoder: 342 | outputs = outputs + (past_key_value,) 343 | return outputs 344 | 345 | 346 | class BertSelfOutput(nn.Module): 347 | def __init__(self, config): 348 | super().__init__() 349 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 350 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 351 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 352 | 353 | def forward(self, hidden_states, input_tensor): 354 | hidden_states = self.dense(hidden_states) 355 | hidden_states = self.dropout(hidden_states) 356 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 357 | return hidden_states 358 | 359 | 360 | class BertAttention(nn.Module): 361 | def __init__(self, config): 362 | super().__init__() 363 | self.self = BertSelfAttention(config) 364 | self.output = BertSelfOutput(config) 365 | self.pruned_heads = set() 366 | 367 | def prune_heads(self, heads): 368 | if len(heads) == 0: 369 | return 370 | heads, index = find_pruneable_heads_and_indices( 371 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 372 | ) 373 | 374 | # Prune linear layers 375 | self.self.query = prune_linear_layer(self.self.query, index) 376 | self.self.key = prune_linear_layer(self.self.key, index) 377 | self.self.value = prune_linear_layer(self.self.value, index) 378 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 379 | 380 | # Update hyper params and store pruned heads 381 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 382 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 383 | self.pruned_heads = self.pruned_heads.union(heads) 384 | 385 | def forward( 386 | self, 387 | hidden_states, 388 | attention_mask=None, 389 | head_mask=None, 390 | encoder_hidden_states=None, 391 | encoder_attention_mask=None, 392 | past_key_value=None, 393 | output_attentions=False, 394 | # add: 395 | current_layer=0 396 | ): 397 | self_outputs = self.self( 398 | hidden_states, 399 | attention_mask, 400 | head_mask, 401 | encoder_hidden_states, 402 | encoder_attention_mask, 403 | past_key_value, 404 | output_attentions, 405 | ### add: 406 | current_layer=current_layer 407 | ) 408 | attention_output = self.output(self_outputs[0], hidden_states) 409 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 410 | return outputs 411 | 412 | 413 | class BertIntermediate(nn.Module): 414 | def __init__(self, config): 415 | super().__init__() 416 | 
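# Position-wise feed-forward expansion: hidden_size -> intermediate_size
# (768 -> 3072 for bert-base); BertOutput below projects back to hidden_size.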
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 417 | if isinstance(config.hidden_act, str): 418 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 419 | else: 420 | self.intermediate_act_fn = config.hidden_act 421 | 422 | def forward(self, hidden_states): 423 | hidden_states = self.dense(hidden_states) 424 | hidden_states = self.intermediate_act_fn(hidden_states) 425 | return hidden_states 426 | 427 | 428 | class BertOutput(nn.Module): # FFN 429 | def __init__(self, config): 430 | super().__init__() 431 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 432 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 433 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 434 | 435 | def forward(self, hidden_states, input_tensor): 436 | hidden_states = self.dense(hidden_states) 437 | hidden_states = self.dropout(hidden_states) 438 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 439 | return hidden_states 440 | 441 | 442 | class BertLayer(nn.Module): 443 | def __init__(self, config): 444 | super().__init__() 445 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 446 | self.seq_len_dim = 1 447 | self.attention = BertAttention(config) 448 | self.is_decoder = config.is_decoder 449 | self.add_cross_attention = config.add_cross_attention 450 | if self.add_cross_attention: 451 | assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added" 452 | self.crossattention = BertAttention(config) 453 | self.intermediate = BertIntermediate(config) 454 | self.output = BertOutput(config) 455 | 456 | def forward( 457 | self, 458 | hidden_states, 459 | attention_mask=None, 460 | head_mask=None, 461 | encoder_hidden_states=None, 462 | encoder_attention_mask=None, 463 | past_key_value=None, 464 | output_attentions=False, 465 | ### add: 466 | current_layer=0 467 | ): 468 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 469 | # self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 470 | self_attention_outputs = self.attention( 471 | hidden_states, 472 | attention_mask, 473 | head_mask, 474 | output_attentions=output_attentions, 475 | # past_key_value=self_attn_past_key_value, 476 | past_key_value=past_key_value, 477 | ### add: 478 | current_layer=current_layer 479 | ) 480 | attention_output = self_attention_outputs[0] 481 | 482 | # if decoder, the last output is tuple of self-attn cache 483 | if self.is_decoder: 484 | outputs = self_attention_outputs[1:-1] 485 | present_key_value = self_attention_outputs[-1] 486 | else: 487 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 488 | 489 | cross_attn_present_key_value = None 490 | if self.is_decoder and encoder_hidden_states is not None: 491 | assert hasattr( 492 | self, "crossattention" 493 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 494 | 495 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple 496 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 497 | cross_attention_outputs = self.crossattention( 498 | attention_output, 499 | attention_mask, 500 | head_mask, 501 | encoder_hidden_states, 502 | encoder_attention_mask, 503 | cross_attn_past_key_value, 504 | output_attentions, 505 | ) 506 | attention_output = cross_attention_outputs[0] 507 | outputs = 
outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights 508 | 509 | # add cross-attn cache to positions 3,4 of present_key_value tuple 510 | cross_attn_present_key_value = cross_attention_outputs[-1] 511 | present_key_value = present_key_value + cross_attn_present_key_value 512 | 513 | layer_output = apply_chunking_to_forward( 514 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 515 | ) 516 | outputs = (layer_output,) + outputs 517 | 518 | # if decoder, return the attn key/values as the last output 519 | if self.is_decoder: 520 | outputs = outputs + (present_key_value,) 521 | 522 | return outputs 523 | 524 | def feed_forward_chunk(self, attention_output): 525 | intermediate_output = self.intermediate(attention_output) 526 | layer_output = self.output(intermediate_output, attention_output) 527 | return layer_output 528 | 529 | 530 | class BertEncoder(nn.Module): 531 | def __init__(self, config): 532 | super().__init__() 533 | self.config = config 534 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) 535 | self.gradient_checkpointing = False 536 | 537 | def forward( 538 | self, 539 | hidden_states, 540 | attention_mask=None, 541 | head_mask=None, 542 | encoder_hidden_states=None, 543 | encoder_attention_mask=None, 544 | past_key_values=None, 545 | use_cache=None, 546 | output_attentions=False, 547 | output_hidden_states=False, 548 | return_dict=True, 549 | ): 550 | all_hidden_states = () if output_hidden_states else None 551 | all_self_attentions = () if output_attentions else None 552 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 553 | 554 | next_decoder_cache = () if use_cache else None 555 | for i, layer_module in enumerate(self.layer): 556 | if output_hidden_states: 557 | all_hidden_states = all_hidden_states + (hidden_states,) 558 | 559 | layer_head_mask = head_mask[i] if head_mask is not None else None 560 | past_key_value = past_key_values[i] if past_key_values is not None else None 561 | 562 | if self.gradient_checkpointing and self.training: 563 | 564 | if use_cache: 565 | logger.warning( 566 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
567 | ) 568 | use_cache = False 569 | 570 | def create_custom_forward(module): 571 | def custom_forward(*inputs): 572 | return module(*inputs, past_key_value, output_attentions) 573 | 574 | return custom_forward 575 | 576 | layer_outputs = torch.utils.checkpoint.checkpoint( 577 | create_custom_forward(layer_module), 578 | hidden_states, 579 | attention_mask, 580 | layer_head_mask, 581 | encoder_hidden_states, 582 | encoder_attention_mask, 583 | ) 584 | else: 585 | layer_outputs = layer_module( 586 | hidden_states, 587 | attention_mask, 588 | layer_head_mask, 589 | encoder_hidden_states, 590 | encoder_attention_mask, 591 | past_key_value, 592 | output_attentions, 593 | ### add: 594 | current_layer=i 595 | ) 596 | 597 | hidden_states = layer_outputs[0] 598 | if use_cache: 599 | next_decoder_cache += (layer_outputs[-1],) 600 | if output_attentions: 601 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 602 | if self.config.add_cross_attention: 603 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 604 | 605 | if output_hidden_states: 606 | all_hidden_states = all_hidden_states + (hidden_states,) 607 | 608 | if not return_dict: 609 | return tuple( 610 | v 611 | for v in [ 612 | hidden_states, 613 | next_decoder_cache, 614 | all_hidden_states, 615 | all_self_attentions, 616 | all_cross_attentions, 617 | ] 618 | if v is not None 619 | ) 620 | return BaseModelOutputWithPastAndCrossAttentions( 621 | last_hidden_state=hidden_states, 622 | past_key_values=next_decoder_cache, 623 | hidden_states=all_hidden_states, 624 | attentions=all_self_attentions, 625 | cross_attentions=all_cross_attentions, 626 | ) 627 | 628 | 629 | class BertPooler(nn.Module): 630 | def __init__(self, config): 631 | super().__init__() 632 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 633 | self.activation = nn.Tanh() 634 | 635 | def forward(self, hidden_states): 636 | # We "pool" the model by simply taking the hidden state corresponding 637 | # to the first token. 638 | first_token_tensor = hidden_states[:, 0] 639 | pooled_output = self.dense(first_token_tensor) 640 | pooled_output = self.activation(pooled_output) 641 | return pooled_output 642 | 643 | 644 | class BertPredictionHeadTransform(nn.Module): 645 | def __init__(self, config): 646 | super().__init__() 647 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 648 | if isinstance(config.hidden_act, str): 649 | self.transform_act_fn = ACT2FN[config.hidden_act] 650 | else: 651 | self.transform_act_fn = config.hidden_act 652 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 653 | 654 | def forward(self, hidden_states): 655 | hidden_states = self.dense(hidden_states) 656 | hidden_states = self.transform_act_fn(hidden_states) 657 | hidden_states = self.LayerNorm(hidden_states) 658 | return hidden_states 659 | 660 | 661 | class BertLMPredictionHead(nn.Module): 662 | def __init__(self, config): 663 | super().__init__() 664 | self.transform = BertPredictionHeadTransform(config) 665 | 666 | # The output weights are the same as the input embeddings, but there is 667 | # an output-only bias for each token. 
668 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 669 | 670 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 671 | 672 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 673 | self.decoder.bias = self.bias 674 | 675 | def forward(self, hidden_states): 676 | hidden_states = self.transform(hidden_states) 677 | hidden_states = self.decoder(hidden_states) 678 | return hidden_states 679 | 680 | 681 | class BertOnlyMLMHead(nn.Module): 682 | def __init__(self, config): 683 | super().__init__() 684 | self.predictions = BertLMPredictionHead(config) 685 | 686 | def forward(self, sequence_output): 687 | prediction_scores = self.predictions(sequence_output) 688 | return prediction_scores 689 | 690 | 691 | class BertOnlyNSPHead(nn.Module): 692 | def __init__(self, config): 693 | super().__init__() 694 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 695 | 696 | def forward(self, pooled_output): 697 | seq_relationship_score = self.seq_relationship(pooled_output) 698 | return seq_relationship_score 699 | 700 | 701 | class BertPreTrainingHeads(nn.Module): 702 | def __init__(self, config): 703 | super().__init__() 704 | self.predictions = BertLMPredictionHead(config) 705 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 706 | 707 | def forward(self, sequence_output, pooled_output): 708 | prediction_scores = self.predictions(sequence_output) 709 | seq_relationship_score = self.seq_relationship(pooled_output) 710 | return prediction_scores, seq_relationship_score 711 | 712 | 713 | class BertPreTrainedModel(PreTrainedModel): 714 | """ 715 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 716 | models. 717 | """ 718 | 719 | config_class = BertConfig 720 | load_tf_weights = load_tf_weights_in_bert 721 | base_model_prefix = "bert" 722 | supports_gradient_checkpointing = True 723 | _keys_to_ignore_on_load_missing = [r"position_ids"] 724 | 725 | def _init_weights(self, module): 726 | """Initialize the weights""" 727 | if isinstance(module, nn.Linear): 728 | # Slightly different from the TF version which uses truncated_normal for initialization 729 | # cf https://github.com/pytorch/pytorch/pull/5617 730 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 731 | if module.bias is not None: 732 | module.bias.data.zero_() 733 | elif isinstance(module, nn.Embedding): 734 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 735 | if module.padding_idx is not None: 736 | module.weight.data[module.padding_idx].zero_() 737 | elif isinstance(module, nn.LayerNorm): 738 | module.bias.data.zero_() 739 | module.weight.data.fill_(1.0) 740 | 741 | def _set_gradient_checkpointing(self, module, value=False): 742 | if isinstance(module, BertEncoder): 743 | module.gradient_checkpointing = value 744 | 745 | 746 | @dataclass 747 | class BertForPreTrainingOutput(ModelOutput): 748 | """ 749 | Output type of :class:`~transformers.BertForPreTraining`. 750 | 751 | Args: 752 | loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): 753 | Total loss as the sum of the masked language modeling loss and the next sequence prediction 754 | (classification) loss. 755 | prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): 756 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
757 | seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): 758 | Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation 759 | before SoftMax). 760 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): 761 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 762 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 763 | 764 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 765 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): 766 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, 767 | sequence_length, sequence_length)`. 768 | 769 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 770 | heads. 771 | """ 772 | 773 | loss: Optional[torch.FloatTensor] = None 774 | prediction_logits: torch.FloatTensor = None 775 | seq_relationship_logits: torch.FloatTensor = None 776 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 777 | attentions: Optional[Tuple[torch.FloatTensor]] = None 778 | 779 | 780 | BERT_START_DOCSTRING = r""" 781 | 782 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 783 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 784 | pruning heads etc.) 785 | 786 | This model is also a PyTorch `torch.nn.Module `__ 787 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 788 | general usage and behavior. 789 | 790 | Parameters: 791 | config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. 792 | Initializing with a config file does not load the weights associated with the model, only the 793 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model 794 | weights. 795 | """ 796 | 797 | BERT_INPUTS_DOCSTRING = r""" 798 | Args: 799 | input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): 800 | Indices of input sequence tokens in the vocabulary. 801 | 802 | Indices can be obtained using :class:`~transformers.BertTokenizer`. See 803 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 804 | details. 805 | 806 | `What are input IDs? <../glossary.html#input-ids>`__ 807 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): 808 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 809 | 810 | - 1 for tokens that are **not masked**, 811 | - 0 for tokens that are **masked**. 812 | 813 | `What are attention masks? <../glossary.html#attention-mask>`__ 814 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): 815 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 816 | 1]``: 817 | 818 | - 0 corresponds to a `sentence A` token, 819 | - 1 corresponds to a `sentence B` token. 820 | 821 | `What are token type IDs? 
<../glossary.html#token-type-ids>`_ 822 | position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): 823 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, 824 | config.max_position_embeddings - 1]``. 825 | 826 | `What are position IDs? <../glossary.html#position-ids>`_ 827 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): 828 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: 829 | 830 | - 1 indicates the head is **not masked**, 831 | - 0 indicates the head is **masked**. 832 | 833 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): 834 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 835 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 836 | vectors than the model's internal embedding lookup matrix. 837 | output_attentions (:obj:`bool`, `optional`): 838 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 839 | tensors for more detail. 840 | output_hidden_states (:obj:`bool`, `optional`): 841 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 842 | more detail. 843 | return_dict (:obj:`bool`, `optional`): 844 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 845 | """ 846 | 847 | 848 | @add_start_docstrings( 849 | "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", 850 | BERT_START_DOCSTRING, 851 | ) 852 | class BertModel(BertPreTrainedModel): 853 | """ 854 | 855 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of 856 | cross-attention is added between the self-attention layers, following the architecture described in `Attention is 857 | all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, 858 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 859 | 860 | To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration 861 | set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` 862 | argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an 863 | input to the forward pass. 864 | """ 865 | 866 | def __init__(self, config, add_pooling_layer=True): 867 | super().__init__(config) 868 | self.config = config 869 | 870 | self.embeddings = BertEmbeddings(config) 871 | self.encoder = BertEncoder(config) 872 | 873 | self.pooler = BertPooler(config) if add_pooling_layer else None 874 | 875 | self.init_weights() 876 | 877 | def get_input_embeddings(self): 878 | return self.embeddings.word_embeddings 879 | 880 | def set_input_embeddings(self, value): 881 | self.embeddings.word_embeddings = value 882 | 883 | def _prune_heads(self, heads_to_prune): 884 | """ 885 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 886 | class PreTrainedModel 887 | """ 888 | for layer, heads in heads_to_prune.items(): 889 | self.encoder.layer[layer].attention.prune_heads(heads) 890 | 891 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 892 | @add_code_sample_docstrings( 893 | tokenizer_class=_TOKENIZER_FOR_DOC, 894 | checkpoint=_CHECKPOINT_FOR_DOC, 895 | output_type=BaseModelOutputWithPoolingAndCrossAttentions, 896 | config_class=_CONFIG_FOR_DOC, 897 | ) 898 | def forward( 899 | self, 900 | input_ids=None, 901 | attention_mask=None, 902 | token_type_ids=None, 903 | position_ids=None, 904 | head_mask=None, 905 | inputs_embeds=None, 906 | encoder_hidden_states=None, 907 | encoder_attention_mask=None, 908 | past_key_values=None, 909 | use_cache=None, 910 | output_attentions=None, 911 | output_hidden_states=None, 912 | return_dict=None, 913 | ): 914 | r""" 915 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 916 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 917 | the model is configured as a decoder. 918 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 919 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 920 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 921 | 922 | - 1 for tokens that are **not masked**, 923 | - 0 for tokens that are **masked**. 924 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 925 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 926 | 927 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 928 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 929 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 930 | use_cache (:obj:`bool`, `optional`): 931 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 932 | decoding (see :obj:`past_key_values`). 
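Example (an illustrative sketch added to this copy; any BERT checkpoint works, ``bert-base-uncased`` is only an assumption)::

    >>> from transformers import BertTokenizer, BertModel
    >>> import torch

    >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    >>> model = BertModel.from_pretrained('bert-base-uncased')

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    >>> outputs = model(**inputs)
    >>> outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)
    torch.Size([1, 8, 768])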
933 | """ 934 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 935 | output_hidden_states = ( 936 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 937 | ) 938 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 939 | 940 | if self.config.is_decoder: 941 | use_cache = use_cache if use_cache is not None else self.config.use_cache 942 | else: 943 | use_cache = False 944 | 945 | if input_ids is not None and inputs_embeds is not None: 946 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 947 | elif input_ids is not None: 948 | input_shape = input_ids.size() 949 | elif inputs_embeds is not None: 950 | input_shape = inputs_embeds.size()[:-1] 951 | else: 952 | raise ValueError("You have to specify either input_ids or inputs_embeds") 953 | 954 | batch_size, seq_length = input_shape 955 | device = input_ids.device if input_ids is not None else inputs_embeds.device 956 | 957 | # past_key_values_length 958 | # past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 959 | # add: 960 | past_key_values_length = 0 # position_id 961 | ## 962 | 963 | 964 | if attention_mask is None: 965 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 966 | if token_type_ids is None: 967 | if hasattr(self.embeddings, "token_type_ids"): 968 | buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] 969 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) 970 | token_type_ids = buffered_token_type_ids_expanded 971 | else: 972 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 973 | 974 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 975 | # ourselves in which case we just need to make it broadcastable to all heads. 
976 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) 977 | # If a 2D or 3D attention mask is provided for the cross-attention 978 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 979 | if self.config.is_decoder and encoder_hidden_states is not None: 980 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 981 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 982 | if encoder_attention_mask is None: 983 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 984 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 985 | else: 986 | encoder_extended_attention_mask = None 987 | 988 | # Prepare head mask if needed 989 | # 1.0 in head_mask indicate we keep the head 990 | # attention_probs has shape bsz x n_heads x N x N 991 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 992 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 993 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) # [None]*12 994 | 995 | embedding_output = self.embeddings( 996 | input_ids=input_ids, 997 | position_ids=position_ids, 998 | token_type_ids=token_type_ids, 999 | inputs_embeds=inputs_embeds, 1000 | past_key_values_length=past_key_values_length, 1001 | ) 1002 | encoder_outputs = self.encoder( 1003 | embedding_output, 1004 | attention_mask=extended_attention_mask, 1005 | head_mask=head_mask, 1006 | encoder_hidden_states=encoder_hidden_states, 1007 | encoder_attention_mask=encoder_extended_attention_mask, 1008 | past_key_values=past_key_values, 1009 | use_cache=use_cache, 1010 | output_attentions=output_attentions, 1011 | output_hidden_states=output_hidden_states, 1012 | return_dict=return_dict, 1013 | ) 1014 | sequence_output = encoder_outputs[0] 1015 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 1016 | 1017 | if not return_dict: 1018 | return (sequence_output, pooled_output) + encoder_outputs[1:] 1019 | 1020 | return BaseModelOutputWithPoolingAndCrossAttentions( 1021 | last_hidden_state=sequence_output, 1022 | pooler_output=pooled_output, 1023 | past_key_values=encoder_outputs.past_key_values, 1024 | hidden_states=encoder_outputs.hidden_states, 1025 | attentions=encoder_outputs.attentions, 1026 | cross_attentions=encoder_outputs.cross_attentions, 1027 | ) 1028 | 1029 | 1030 | @add_start_docstrings( 1031 | """ 1032 | Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next 1033 | sentence prediction (classification)` head. 
1034 | """, 1035 | BERT_START_DOCSTRING, 1036 | ) 1037 | class BertForPreTraining(BertPreTrainedModel): 1038 | def __init__(self, config): 1039 | super().__init__(config) 1040 | 1041 | self.bert = BertModel(config) 1042 | self.cls = BertPreTrainingHeads(config) 1043 | 1044 | self.init_weights() 1045 | 1046 | def get_output_embeddings(self): 1047 | return self.cls.predictions.decoder 1048 | 1049 | def set_output_embeddings(self, new_embeddings): 1050 | self.cls.predictions.decoder = new_embeddings 1051 | 1052 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1053 | @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) 1054 | def forward( 1055 | self, 1056 | input_ids=None, 1057 | attention_mask=None, 1058 | token_type_ids=None, 1059 | position_ids=None, 1060 | head_mask=None, 1061 | inputs_embeds=None, 1062 | labels=None, 1063 | next_sentence_label=None, 1064 | output_attentions=None, 1065 | output_hidden_states=None, 1066 | return_dict=None, 1067 | ): 1068 | r""" 1069 | labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): 1070 | Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., 1071 | config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored 1072 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` 1073 | next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): 1074 | Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair 1075 | (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: 1076 | 1077 | - 0 indicates sequence B is a continuation of sequence A, 1078 | - 1 indicates sequence B is a random sequence. 1079 | kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): 1080 | Used to hide legacy arguments that have been deprecated. 
1081 | 1082 | Returns: 1083 | 1084 | Example:: 1085 | 1086 | >>> from transformers import BertTokenizer, BertForPreTraining 1087 | >>> import torch 1088 | 1089 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 1090 | >>> model = BertForPreTraining.from_pretrained('bert-base-uncased') 1091 | 1092 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 1093 | >>> outputs = model(**inputs) 1094 | 1095 | >>> prediction_logits = outputs.prediction_logits 1096 | >>> seq_relationship_logits = outputs.seq_relationship_logits 1097 | """ 1098 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1099 | 1100 | outputs = self.bert( 1101 | input_ids, 1102 | attention_mask=attention_mask, 1103 | token_type_ids=token_type_ids, 1104 | position_ids=position_ids, 1105 | head_mask=head_mask, 1106 | inputs_embeds=inputs_embeds, 1107 | output_attentions=output_attentions, 1108 | output_hidden_states=output_hidden_states, 1109 | return_dict=return_dict, 1110 | ) 1111 | 1112 | sequence_output, pooled_output = outputs[:2] 1113 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 1114 | 1115 | total_loss = None 1116 | if labels is not None and next_sentence_label is not None: 1117 | loss_fct = CrossEntropyLoss() 1118 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1119 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 1120 | total_loss = masked_lm_loss + next_sentence_loss 1121 | 1122 | if not return_dict: 1123 | output = (prediction_scores, seq_relationship_score) + outputs[2:] 1124 | return ((total_loss,) + output) if total_loss is not None else output 1125 | 1126 | return BertForPreTrainingOutput( 1127 | loss=total_loss, 1128 | prediction_logits=prediction_scores, 1129 | seq_relationship_logits=seq_relationship_score, 1130 | hidden_states=outputs.hidden_states, 1131 | attentions=outputs.attentions, 1132 | ) 1133 | 1134 | 1135 | @add_start_docstrings( 1136 | """Bert Model with a `language modeling` head on top for CLM fine-tuning. 
""", BERT_START_DOCSTRING 1137 | ) 1138 | class BertLMHeadModel(BertPreTrainedModel): 1139 | 1140 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1141 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] 1142 | 1143 | def __init__(self, config): 1144 | super().__init__(config) 1145 | 1146 | if not config.is_decoder: 1147 | logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`") 1148 | 1149 | self.bert = BertModel(config, add_pooling_layer=False) 1150 | self.cls = BertOnlyMLMHead(config) 1151 | 1152 | self.init_weights() 1153 | 1154 | def get_output_embeddings(self): 1155 | return self.cls.predictions.decoder 1156 | 1157 | def set_output_embeddings(self, new_embeddings): 1158 | self.cls.predictions.decoder = new_embeddings 1159 | 1160 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1161 | @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) 1162 | def forward( 1163 | self, 1164 | input_ids=None, 1165 | attention_mask=None, 1166 | token_type_ids=None, 1167 | position_ids=None, 1168 | head_mask=None, 1169 | inputs_embeds=None, 1170 | encoder_hidden_states=None, 1171 | encoder_attention_mask=None, 1172 | labels=None, 1173 | past_key_values=None, 1174 | use_cache=None, 1175 | output_attentions=None, 1176 | output_hidden_states=None, 1177 | return_dict=None, 1178 | ): 1179 | r""" 1180 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 1181 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 1182 | the model is configured as a decoder. 1183 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1184 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 1185 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 1186 | 1187 | - 1 for tokens that are **not masked**, 1188 | - 0 for tokens that are **masked**. 1189 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1190 | Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in 1191 | ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are 1192 | ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` 1193 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 1194 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 1195 | 1196 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 1197 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 1198 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 1199 | use_cache (:obj:`bool`, `optional`): 1200 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 1201 | decoding (see :obj:`past_key_values`). 
1202 | 1203 | Returns: 1204 | 1205 | Example:: 1206 | 1207 | >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig 1208 | >>> import torch 1209 | 1210 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') 1211 | >>> config = BertConfig.from_pretrained("bert-base-cased") 1212 | >>> config.is_decoder = True 1213 | >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) 1214 | 1215 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 1216 | >>> outputs = model(**inputs) 1217 | 1218 | >>> prediction_logits = outputs.logits 1219 | """ 1220 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1221 | if labels is not None: 1222 | use_cache = False 1223 | 1224 | outputs = self.bert( 1225 | input_ids, 1226 | attention_mask=attention_mask, 1227 | token_type_ids=token_type_ids, 1228 | position_ids=position_ids, 1229 | head_mask=head_mask, 1230 | inputs_embeds=inputs_embeds, 1231 | encoder_hidden_states=encoder_hidden_states, 1232 | encoder_attention_mask=encoder_attention_mask, 1233 | past_key_values=past_key_values, 1234 | use_cache=use_cache, 1235 | output_attentions=output_attentions, 1236 | output_hidden_states=output_hidden_states, 1237 | return_dict=return_dict, 1238 | ) 1239 | 1240 | sequence_output = outputs[0] 1241 | prediction_scores = self.cls(sequence_output) 1242 | 1243 | lm_loss = None 1244 | if labels is not None: 1245 | # we are doing next-token prediction; shift prediction scores and input ids by one 1246 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() 1247 | labels = labels[:, 1:].contiguous() 1248 | loss_fct = CrossEntropyLoss() 1249 | lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1250 | 1251 | if not return_dict: 1252 | output = (prediction_scores,) + outputs[2:] 1253 | return ((lm_loss,) + output) if lm_loss is not None else output 1254 | 1255 | return CausalLMOutputWithCrossAttentions( 1256 | loss=lm_loss, 1257 | logits=prediction_scores, 1258 | past_key_values=outputs.past_key_values, 1259 | hidden_states=outputs.hidden_states, 1260 | attentions=outputs.attentions, 1261 | cross_attentions=outputs.cross_attentions, 1262 | ) 1263 | 1264 | def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): 1265 | input_shape = input_ids.shape 1266 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly 1267 | if attention_mask is None: 1268 | attention_mask = input_ids.new_ones(input_shape) 1269 | 1270 | # cut decoder_input_ids if past is used 1271 | if past is not None: 1272 | input_ids = input_ids[:, -1:] 1273 | 1274 | return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} 1275 | 1276 | def _reorder_cache(self, past, beam_idx): 1277 | reordered_past = () 1278 | for layer_past in past: 1279 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 1280 | return reordered_past 1281 | 1282 | 1283 | @add_start_docstrings("""Bert Model with a `language modeling` head on top. 
""", BERT_START_DOCSTRING) 1284 | class BertForMaskedLM(BertPreTrainedModel): 1285 | 1286 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1287 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] 1288 | 1289 | def __init__(self, config): 1290 | super().__init__(config) 1291 | 1292 | if config.is_decoder: 1293 | logger.warning( 1294 | "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " 1295 | "bi-directional self-attention." 1296 | ) 1297 | 1298 | self.bert = BertModel(config, add_pooling_layer=False) 1299 | self.cls = BertOnlyMLMHead(config) 1300 | 1301 | self.init_weights() 1302 | 1303 | def get_output_embeddings(self): 1304 | return self.cls.predictions.decoder 1305 | 1306 | def set_output_embeddings(self, new_embeddings): 1307 | self.cls.predictions.decoder = new_embeddings 1308 | 1309 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1310 | @add_code_sample_docstrings( 1311 | tokenizer_class=_TOKENIZER_FOR_DOC, 1312 | checkpoint=_CHECKPOINT_FOR_DOC, 1313 | output_type=MaskedLMOutput, 1314 | config_class=_CONFIG_FOR_DOC, 1315 | ) 1316 | def forward( 1317 | self, 1318 | input_ids=None, 1319 | attention_mask=None, 1320 | token_type_ids=None, 1321 | position_ids=None, 1322 | head_mask=None, 1323 | inputs_embeds=None, 1324 | encoder_hidden_states=None, 1325 | encoder_attention_mask=None, 1326 | labels=None, 1327 | output_attentions=None, 1328 | output_hidden_states=None, 1329 | return_dict=None, 1330 | ): 1331 | r""" 1332 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1333 | Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., 1334 | config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored 1335 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` 1336 | """ 1337 | 1338 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1339 | 1340 | outputs = self.bert( 1341 | input_ids, 1342 | attention_mask=attention_mask, 1343 | token_type_ids=token_type_ids, 1344 | position_ids=position_ids, 1345 | head_mask=head_mask, 1346 | inputs_embeds=inputs_embeds, 1347 | encoder_hidden_states=encoder_hidden_states, 1348 | encoder_attention_mask=encoder_attention_mask, 1349 | output_attentions=output_attentions, 1350 | output_hidden_states=output_hidden_states, 1351 | return_dict=return_dict, 1352 | ) 1353 | 1354 | sequence_output = outputs[0] 1355 | prediction_scores = self.cls(sequence_output) 1356 | 1357 | masked_lm_loss = None 1358 | if labels is not None: 1359 | loss_fct = CrossEntropyLoss() # -100 index = padding token 1360 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1361 | 1362 | if not return_dict: 1363 | output = (prediction_scores,) + outputs[2:] 1364 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1365 | 1366 | return MaskedLMOutput( 1367 | loss=masked_lm_loss, 1368 | logits=prediction_scores, 1369 | hidden_states=outputs.hidden_states, 1370 | attentions=outputs.attentions, 1371 | ) 1372 | 1373 | def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): 1374 | input_shape = input_ids.shape 1375 | effective_batch_size = input_shape[0] 1376 | 1377 | # add a dummy token 1378 | assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" 
1379 | attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) 1380 | dummy_token = torch.full( 1381 | (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device 1382 | ) 1383 | input_ids = torch.cat([input_ids, dummy_token], dim=1) 1384 | 1385 | return {"input_ids": input_ids, "attention_mask": attention_mask} 1386 | 1387 | 1388 | @add_start_docstrings( 1389 | """Bert Model with a `next sentence prediction (classification)` head on top. """, 1390 | BERT_START_DOCSTRING, 1391 | ) 1392 | class BertForNextSentencePrediction(BertPreTrainedModel): 1393 | def __init__(self, config): 1394 | super().__init__(config) 1395 | 1396 | self.bert = BertModel(config) 1397 | self.cls = BertOnlyNSPHead(config) 1398 | 1399 | self.init_weights() 1400 | 1401 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1402 | @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) 1403 | def forward( 1404 | self, 1405 | input_ids=None, 1406 | attention_mask=None, 1407 | token_type_ids=None, 1408 | position_ids=None, 1409 | head_mask=None, 1410 | inputs_embeds=None, 1411 | labels=None, 1412 | output_attentions=None, 1413 | output_hidden_states=None, 1414 | return_dict=None, 1415 | **kwargs, 1416 | ): 1417 | r""" 1418 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1419 | Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair 1420 | (see ``input_ids`` docstring). Indices should be in ``[0, 1]``: 1421 | 1422 | - 0 indicates sequence B is a continuation of sequence A, 1423 | - 1 indicates sequence B is a random sequence. 1424 | 1425 | Returns: 1426 | 1427 | Example:: 1428 | 1429 | >>> from transformers import BertTokenizer, BertForNextSentencePrediction 1430 | >>> import torch 1431 | 1432 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 1433 | >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased') 1434 | 1435 | >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." 1436 | >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." 
1437 | >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') 1438 | 1439 | >>> outputs = model(**encoding, labels=torch.LongTensor([1])) 1440 | >>> logits = outputs.logits 1441 | >>> assert logits[0, 0] < logits[0, 1] # next sentence was random 1442 | """ 1443 | 1444 | if "next_sentence_label" in kwargs: 1445 | warnings.warn( 1446 | "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", 1447 | FutureWarning, 1448 | ) 1449 | labels = kwargs.pop("next_sentence_label") 1450 | 1451 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1452 | 1453 | outputs = self.bert( 1454 | input_ids, 1455 | attention_mask=attention_mask, 1456 | token_type_ids=token_type_ids, 1457 | position_ids=position_ids, 1458 | head_mask=head_mask, 1459 | inputs_embeds=inputs_embeds, 1460 | output_attentions=output_attentions, 1461 | output_hidden_states=output_hidden_states, 1462 | return_dict=return_dict, 1463 | ) 1464 | 1465 | pooled_output = outputs[1] 1466 | 1467 | seq_relationship_scores = self.cls(pooled_output) 1468 | 1469 | next_sentence_loss = None 1470 | if labels is not None: 1471 | loss_fct = CrossEntropyLoss() 1472 | next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) 1473 | 1474 | if not return_dict: 1475 | output = (seq_relationship_scores,) + outputs[2:] 1476 | return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output 1477 | 1478 | return NextSentencePredictorOutput( 1479 | loss=next_sentence_loss, 1480 | logits=seq_relationship_scores, 1481 | hidden_states=outputs.hidden_states, 1482 | attentions=outputs.attentions, 1483 | ) 1484 | 1485 | 1486 | @add_start_docstrings( 1487 | """ 1488 | Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled 1489 | output) e.g. for GLUE tasks. 1490 | """, 1491 | BERT_START_DOCSTRING, 1492 | ) 1493 | class BertForSequenceClassification(BertPreTrainedModel): 1494 | def __init__(self, config): 1495 | super().__init__(config) 1496 | self.num_labels = config.num_labels 1497 | self.config = config 1498 | 1499 | self.bert = BertModel(config) 1500 | classifier_dropout = ( 1501 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 1502 | ) 1503 | self.dropout = nn.Dropout(classifier_dropout) 1504 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1505 | 1506 | self.init_weights() 1507 | 1508 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1509 | @add_code_sample_docstrings( 1510 | tokenizer_class=_TOKENIZER_FOR_DOC, 1511 | checkpoint=_CHECKPOINT_FOR_DOC, 1512 | output_type=SequenceClassifierOutput, 1513 | config_class=_CONFIG_FOR_DOC, 1514 | ) 1515 | def forward( 1516 | self, 1517 | input_ids=None, 1518 | attention_mask=None, 1519 | token_type_ids=None, 1520 | position_ids=None, 1521 | head_mask=None, 1522 | inputs_embeds=None, 1523 | labels=None, 1524 | output_attentions=None, 1525 | output_hidden_states=None, 1526 | return_dict=None, 1527 | ): 1528 | r""" 1529 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1530 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 1531 | config.num_labels - 1]`. 
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 1532 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1533 | """ 1534 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1535 | 1536 | outputs = self.bert( 1537 | input_ids, 1538 | attention_mask=attention_mask, 1539 | token_type_ids=token_type_ids, 1540 | position_ids=position_ids, 1541 | head_mask=head_mask, 1542 | inputs_embeds=inputs_embeds, 1543 | output_attentions=output_attentions, 1544 | output_hidden_states=output_hidden_states, 1545 | return_dict=return_dict, 1546 | ) 1547 | 1548 | pooled_output = outputs[1] 1549 | 1550 | pooled_output = self.dropout(pooled_output) 1551 | logits = self.classifier(pooled_output) 1552 | 1553 | loss = None 1554 | if labels is not None: 1555 | if self.config.problem_type is None: 1556 | if self.num_labels == 1: 1557 | self.config.problem_type = "regression" 1558 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1559 | self.config.problem_type = "single_label_classification" 1560 | else: 1561 | self.config.problem_type = "multi_label_classification" 1562 | 1563 | if self.config.problem_type == "regression": 1564 | loss_fct = MSELoss() 1565 | if self.num_labels == 1: 1566 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 1567 | else: 1568 | loss = loss_fct(logits, labels) 1569 | elif self.config.problem_type == "single_label_classification": 1570 | loss_fct = CrossEntropyLoss() 1571 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1572 | elif self.config.problem_type == "multi_label_classification": 1573 | loss_fct = BCEWithLogitsLoss() 1574 | loss = loss_fct(logits, labels) 1575 | if not return_dict: 1576 | output = (logits,) + outputs[2:] 1577 | return ((loss,) + output) if loss is not None else output 1578 | 1579 | return SequenceClassifierOutput( 1580 | loss=loss, 1581 | logits=logits, 1582 | hidden_states=outputs.hidden_states, 1583 | attentions=outputs.attentions, 1584 | ) 1585 | 1586 | 1587 | @add_start_docstrings( 1588 | """ 1589 | Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a 1590 | softmax) e.g. for RocStories/SWAG tasks. 
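(Added note: inputs arrive as ``(batch_size, num_choices, sequence_length)``; :meth:`forward`
flattens the choice dimension into the batch dimension before encoding and reshapes the logits
back to ``(batch_size, num_choices)`` before the loss. See the reshaping in the method body.)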
1591 | """, 1592 | BERT_START_DOCSTRING, 1593 | ) 1594 | class BertForMultipleChoice(BertPreTrainedModel): 1595 | def __init__(self, config): 1596 | super().__init__(config) 1597 | 1598 | self.bert = BertModel(config) 1599 | classifier_dropout = ( 1600 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 1601 | ) 1602 | self.dropout = nn.Dropout(classifier_dropout) 1603 | self.classifier = nn.Linear(config.hidden_size, 1) 1604 | 1605 | self.init_weights() 1606 | 1607 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) 1608 | @add_code_sample_docstrings( 1609 | tokenizer_class=_TOKENIZER_FOR_DOC, 1610 | checkpoint=_CHECKPOINT_FOR_DOC, 1611 | output_type=MultipleChoiceModelOutput, 1612 | config_class=_CONFIG_FOR_DOC, 1613 | ) 1614 | def forward( 1615 | self, 1616 | input_ids=None, 1617 | attention_mask=None, 1618 | token_type_ids=None, 1619 | position_ids=None, 1620 | head_mask=None, 1621 | inputs_embeds=None, 1622 | labels=None, 1623 | output_attentions=None, 1624 | output_hidden_states=None, 1625 | return_dict=None, 1626 | ): 1627 | r""" 1628 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1629 | Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., 1630 | num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See 1631 | :obj:`input_ids` above) 1632 | """ 1633 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1634 | num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] 1635 | 1636 | input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None 1637 | attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 1638 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 1639 | position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 1640 | inputs_embeds = ( 1641 | inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) 1642 | if inputs_embeds is not None 1643 | else None 1644 | ) 1645 | 1646 | outputs = self.bert( 1647 | input_ids, 1648 | attention_mask=attention_mask, 1649 | token_type_ids=token_type_ids, 1650 | position_ids=position_ids, 1651 | head_mask=head_mask, 1652 | inputs_embeds=inputs_embeds, 1653 | output_attentions=output_attentions, 1654 | output_hidden_states=output_hidden_states, 1655 | return_dict=return_dict, 1656 | ) 1657 | 1658 | pooled_output = outputs[1] 1659 | 1660 | pooled_output = self.dropout(pooled_output) 1661 | logits = self.classifier(pooled_output) 1662 | reshaped_logits = logits.view(-1, num_choices) 1663 | 1664 | loss = None 1665 | if labels is not None: 1666 | loss_fct = CrossEntropyLoss() 1667 | loss = loss_fct(reshaped_logits, labels) 1668 | 1669 | if not return_dict: 1670 | output = (reshaped_logits,) + outputs[2:] 1671 | return ((loss,) + output) if loss is not None else output 1672 | 1673 | return MultipleChoiceModelOutput( 1674 | loss=loss, 1675 | logits=reshaped_logits, 1676 | hidden_states=outputs.hidden_states, 1677 | attentions=outputs.attentions, 1678 | ) 1679 | 1680 | 1681 | @add_start_docstrings( 1682 | """ 1683 | Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for 1684 | Named-Entity-Recognition (NER) tasks. 
1685 | """, 1686 | BERT_START_DOCSTRING, 1687 | ) 1688 | class BertForTokenClassification(BertPreTrainedModel): 1689 | 1690 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1691 | 1692 | def __init__(self, config): 1693 | super().__init__(config) 1694 | self.num_labels = config.num_labels 1695 | 1696 | self.bert = BertModel(config, add_pooling_layer=False) 1697 | # classifier_dropout = ( 1698 | # config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 1699 | # ) 1700 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1701 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1702 | 1703 | self.init_weights() 1704 | 1705 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1706 | @add_code_sample_docstrings( 1707 | tokenizer_class=_TOKENIZER_FOR_DOC, 1708 | checkpoint=_CHECKPOINT_FOR_DOC, 1709 | output_type=TokenClassifierOutput, 1710 | config_class=_CONFIG_FOR_DOC, 1711 | ) 1712 | def forward( 1713 | self, 1714 | input_ids=None, 1715 | attention_mask=None, 1716 | token_type_ids=None, 1717 | position_ids=None, 1718 | head_mask=None, 1719 | inputs_embeds=None, 1720 | labels=None, 1721 | output_attentions=None, 1722 | output_hidden_states=None, 1723 | # add: 1724 | past_key_values=None, 1725 | prompt_attention_mask=None, 1726 | # 1727 | return_dict=None, 1728 | ): 1729 | r""" 1730 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1731 | Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - 1732 | 1]``. 1733 | """ 1734 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1735 | 1736 | outputs = self.bert( 1737 | input_ids, 1738 | attention_mask=prompt_attention_mask, 1739 | token_type_ids=token_type_ids, 1740 | position_ids=position_ids, 1741 | head_mask=head_mask, 1742 | inputs_embeds=inputs_embeds, 1743 | # add: 1744 | past_key_values=past_key_values, 1745 | # 1746 | output_attentions=output_attentions, 1747 | output_hidden_states=output_hidden_states, 1748 | return_dict=return_dict, 1749 | ) 1750 | 1751 | sequence_output = outputs[0] 1752 | 1753 | sequence_output = self.dropout(sequence_output) 1754 | logits = self.classifier(sequence_output) 1755 | 1756 | loss = None 1757 | if labels is not None: 1758 | loss_fct = CrossEntropyLoss() 1759 | # Only keep active parts of the loss 1760 | if attention_mask is not None: 1761 | active_loss = attention_mask.view(-1) == 1 1762 | active_logits = logits.view(-1, self.num_labels) 1763 | active_labels = torch.where( 1764 | active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) 1765 | ) 1766 | loss = loss_fct(active_logits, active_labels) 1767 | else: 1768 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1769 | 1770 | if not return_dict: 1771 | output = (logits,) + outputs[2:] 1772 | return ((loss,) + output) if loss is not None else output 1773 | 1774 | return TokenClassifierOutput( 1775 | loss=loss, 1776 | logits=logits, 1777 | hidden_states=outputs.hidden_states, 1778 | attentions=outputs.attentions, 1779 | ) 1780 | 1781 | 1782 | @add_start_docstrings( 1783 | """ 1784 | Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear 1785 | layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
1786 | """, 1787 | BERT_START_DOCSTRING, 1788 | ) 1789 | class BertForQuestionAnswering(BertPreTrainedModel): 1790 | 1791 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1792 | 1793 | def __init__(self, config): 1794 | super().__init__(config) 1795 | self.num_labels = config.num_labels 1796 | 1797 | self.bert = BertModel(config, add_pooling_layer=False) 1798 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 1799 | 1800 | self.init_weights() 1801 | 1802 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1803 | @add_code_sample_docstrings( 1804 | tokenizer_class=_TOKENIZER_FOR_DOC, 1805 | checkpoint=_CHECKPOINT_FOR_DOC, 1806 | output_type=QuestionAnsweringModelOutput, 1807 | config_class=_CONFIG_FOR_DOC, 1808 | ) 1809 | def forward( 1810 | self, 1811 | input_ids=None, 1812 | attention_mask=None, 1813 | token_type_ids=None, 1814 | position_ids=None, 1815 | head_mask=None, 1816 | inputs_embeds=None, 1817 | start_positions=None, 1818 | end_positions=None, 1819 | output_attentions=None, 1820 | output_hidden_states=None, 1821 | return_dict=None, 1822 | ): 1823 | r""" 1824 | start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1825 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1826 | Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the 1827 | sequence are not taken into account for computing the loss. 1828 | end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1829 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1830 | Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the 1831 | sequence are not taken into account for computing the loss. 
1832 | """ 1833 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1834 | 1835 | outputs = self.bert( 1836 | input_ids, 1837 | attention_mask=attention_mask, 1838 | token_type_ids=token_type_ids, 1839 | position_ids=position_ids, 1840 | head_mask=head_mask, 1841 | inputs_embeds=inputs_embeds, 1842 | output_attentions=output_attentions, 1843 | output_hidden_states=output_hidden_states, 1844 | return_dict=return_dict, 1845 | ) 1846 | 1847 | sequence_output = outputs[0] 1848 | 1849 | logits = self.qa_outputs(sequence_output) 1850 | start_logits, end_logits = logits.split(1, dim=-1) 1851 | start_logits = start_logits.squeeze(-1).contiguous() 1852 | end_logits = end_logits.squeeze(-1).contiguous() 1853 | 1854 | total_loss = None 1855 | if start_positions is not None and end_positions is not None: 1856 | # If we are on multi-GPU, split add a dimension 1857 | if len(start_positions.size()) > 1: 1858 | start_positions = start_positions.squeeze(-1) 1859 | if len(end_positions.size()) > 1: 1860 | end_positions = end_positions.squeeze(-1) 1861 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1862 | ignored_index = start_logits.size(1) 1863 | start_positions = start_positions.clamp(0, ignored_index) 1864 | end_positions = end_positions.clamp(0, ignored_index) 1865 | 1866 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1867 | start_loss = loss_fct(start_logits, start_positions) 1868 | end_loss = loss_fct(end_logits, end_positions) 1869 | total_loss = (start_loss + end_loss) / 2 1870 | 1871 | if not return_dict: 1872 | output = (start_logits, end_logits) + outputs[2:] 1873 | return ((total_loss,) + output) if total_loss is not None else output 1874 | 1875 | return QuestionAnsweringModelOutput( 1876 | loss=total_loss, 1877 | start_logits=start_logits, 1878 | end_logits=end_logits, 1879 | hidden_states=outputs.hidden_states, 1880 | attentions=outputs.attentions, 1881 | ) -------------------------------------------------------------------------------- /modules/metrics.py: -------------------------------------------------------------------------------- 1 | def eval_result(true_labels, pred_result, rel2id, logger, use_name=False): 2 | correct = 0 3 | total = len(true_labels) 4 | correct_positive = 0 5 | pred_positive = 0 6 | gold_positive = 0 7 | 8 | neg = -1 9 | for name in ['NA', 'na', 'no_relation', 'Other', 'Others', 'none', 'None']: 10 | if name in rel2id: 11 | if use_name: 12 | neg = name 13 | else: 14 | neg = rel2id[name] 15 | break 16 | for i in range(total): 17 | if use_name: 18 | golden = true_labels[i] 19 | else: 20 | golden = true_labels[i] 21 | 22 | if golden == pred_result[i]: 23 | correct += 1 24 | if golden != neg: 25 | correct_positive += 1 26 | if golden != neg: 27 | gold_positive += 1 28 | if pred_result[i] != neg: 29 | pred_positive += 1 30 | acc = float(correct) / float(total) 31 | try: 32 | micro_p = float(correct_positive) / float(pred_positive) 33 | except: 34 | micro_p = 0 35 | try: 36 | micro_r = float(correct_positive) / float(gold_positive) 37 | except: 38 | micro_r = 0 39 | try: 40 | micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) 41 | except: 42 | micro_f1 = 0 43 | 44 | result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1} 45 | logger.info('Evaluation result: {}.'.format(result)) 46 | return result -------------------------------------------------------------------------------- /modules/train.py: 
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim
3 | from tqdm import tqdm
4 | import random
5 | from sklearn.metrics import classification_report as sk_classification_report
6 | from seqeval.metrics import classification_report
7 | from transformers.optimization import get_linear_schedule_with_warmup
8 |
9 | from .metrics import eval_result
10 |
11 | class BaseTrainer(object):
12 | def train(self):
13 | raise NotImplementedError()
14 |
15 | def evaluate(self):
16 | raise NotImplementedError()
17 |
18 | def test(self):
19 | raise NotImplementedError()
20 |
21 | class RETrainer(BaseTrainer):
22 | def __init__(self, train_data=None, dev_data=None, test_data=None, model=None, processor=None, args=None, logger=None, writer=None) -> None:
23 | self.train_data = train_data
24 | self.dev_data = dev_data
25 | self.test_data = test_data
26 | self.model = model
27 | self.processor = processor
28 | self.re_dict = processor.get_relation_dict()
29 | self.logger = logger
30 | self.writer = writer
31 | self.refresh_step = 2
32 | self.best_dev_metric = 0
33 | self.best_test_metric = 0
34 | self.best_dev_epoch = None
35 | self.best_test_epoch = None
36 | self.optimizer = None
37 | if self.train_data is not None:
38 | self.train_num_steps = len(self.train_data) * args.num_epochs
39 | self.step = 0
40 | self.args = args
41 | if self.args.use_prompt:
42 | self.before_multimodal_train()
43 | else:
44 | self.before_train()
45 |
46 | def train(self):
47 | self.step = 0
48 | self.model.train()
49 | self.logger.info("***** Running training *****")
50 | self.logger.info("  Num instance = %d", len(self.train_data)*self.args.batch_size)
51 | self.logger.info("  Num epoch = %d", self.args.num_epochs)
52 | self.logger.info("  Batch size = %d", self.args.batch_size)
53 | self.logger.info("  Learning rate = {}".format(self.args.lr))
54 | self.logger.info("  Evaluate begin = %d", self.args.eval_begin_epoch)
55 |
56 | if self.args.load_path is not None:  # load model from load_path
57 | self.logger.info("Loading model from {}".format(self.args.load_path))
58 | self.model.load_state_dict(torch.load(self.args.load_path))
59 | self.logger.info("Model loaded successfully!")
60 |
61 | with tqdm(total=self.train_num_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True, initial=self.step) as pbar:
62 | self.pbar = pbar
63 | avg_loss = 0
64 | for epoch in range(1, self.args.num_epochs+1):
65 | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.args.num_epochs))
66 | for batch in self.train_data:
67 | self.step += 1
68 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)
69 | (loss, logits), labels = self._step(batch, mode="train")
70 | avg_loss += loss.detach().cpu().item()
71 |
72 | loss.backward()
73 | self.optimizer.step()
74 | self.scheduler.step()
75 | self.optimizer.zero_grad()
76 |
77 | if self.step % self.refresh_step == 0:
78 | avg_loss = float(avg_loss) / self.refresh_step
79 | print_output = "loss:{:<6.5f}".format(avg_loss)
80 | pbar.update(self.refresh_step)
81 | pbar.set_postfix_str(print_output)
82 | if self.writer:
83 | self.writer.add_scalar(tag='train_loss', scalar_value=avg_loss, global_step=self.step)    # tensorboardX
84 | avg_loss = 0
85 |
86 | if epoch >= self.args.eval_begin_epoch:
87 | self.evaluate(epoch)   # run evaluation on the dev set
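# (added note) evaluate() below switches the model to eval mode, tracks the best dev micro-F1,
# saves a checkpoint to args.save_path when it improves, and restores train mode on exit.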
88 |
89 | pbar.close()
90 | self.pbar = None
91 | self.logger.info("Best dev performance at epoch {}, best dev f1 score is {}".format(self.best_dev_epoch, self.best_dev_metric))
92 | self.logger.info("Best test performance at epoch {}, best test f1 score is {}".format(self.best_test_epoch, self.best_test_metric))
93 |
94 | def evaluate(self, epoch):
95 | self.model.eval()
96 | self.logger.info("***** Running evaluation *****")
97 | self.logger.info("  Num instance = %d", len(self.dev_data)*self.args.batch_size)
98 | self.logger.info("  Batch size = %d", self.args.batch_size)
99 | step = 0
100 | true_labels, pred_labels = [], []
101 | with torch.no_grad():
102 | with tqdm(total=len(self.dev_data), leave=False, dynamic_ncols=True) as pbar:
103 | pbar.set_description_str(desc="Dev")
104 | total_loss = 0
105 | for batch in self.dev_data:
106 | step += 1
107 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)  # to cpu/cuda device
108 | (loss, logits), labels = self._step(batch, mode="dev")    # logits: batch, 3
109 | total_loss += loss.detach().cpu().item()
110 |
111 | preds = logits.argmax(-1)
112 | true_labels.extend(labels.view(-1).detach().cpu().tolist())
113 | pred_labels.extend(preds.view(-1).detach().cpu().tolist())
114 | pbar.update()
115 | # evaluate done
116 | pbar.close()
117 | sk_result = sk_classification_report(y_true=true_labels, y_pred=pred_labels, labels=list(self.re_dict.values())[1:], target_names=list(self.re_dict.keys())[1:], digits=4)
118 | self.logger.info("%s\n", sk_result)
119 | result = eval_result(true_labels, pred_labels, self.re_dict, self.logger)
120 | acc, micro_f1 = round(result['acc']*100, 4), round(result['micro_f1']*100, 4)
121 | if self.writer:
122 | self.writer.add_scalar(tag='dev_acc', scalar_value=acc, global_step=epoch)    # tensorboardX
123 | self.writer.add_scalar(tag='dev_f1', scalar_value=micro_f1, global_step=epoch)    # tensorboardX
124 | self.writer.add_scalar(tag='dev_loss', scalar_value=total_loss/len(self.dev_data), global_step=epoch)    # tensorboardX
125 |
126 | self.logger.info("Epoch {}/{}, best dev f1: {}, best epoch: {}, current dev f1 score: {}, acc: {}."\
127 | .format(epoch, self.args.num_epochs, self.best_dev_metric, self.best_dev_epoch, micro_f1, acc))
128 | if micro_f1 >= self.best_dev_metric:  # this epoch achieved the best performance so far
129 | self.logger.info("Got better performance at epoch {}".format(epoch))
130 | self.best_dev_epoch = epoch
131 | self.best_dev_metric = micro_f1  # update best metric (f1 score)
132 | if self.args.save_path is not None:
133 | torch.save(self.model.state_dict(), self.args.save_path+"/best_model.pth")
134 | self.logger.info("Saved best model to {}".format(self.args.save_path))
135 |
136 |
137 | self.model.train()
138 |
139 | def test(self):
140 | self.model.eval()
141 | self.logger.info("\n***** Running testing *****")
142 | self.logger.info("  Num instance = %d", len(self.test_data)*self.args.batch_size)
143 | self.logger.info("  Batch size = %d", self.args.batch_size)
144 |
145 | if self.args.load_path is not None:  # load model from load_path
146 | self.logger.info("Loading model from {}".format(self.args.load_path))
147 | self.model.load_state_dict(torch.load(self.args.load_path))
148 | self.logger.info("Model loaded successfully!")
149 | true_labels, pred_labels = [], []
150 | with torch.no_grad():
151 | with tqdm(total=len(self.test_data), leave=False, dynamic_ncols=True) as pbar:
152 | pbar.set_description_str(desc="Testing")
153 | total_loss = 0
154 | for batch in self.test_data:
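# (added note) each batch unpacks to (input_ids, token_type_ids, attention_mask, labels), plus
# (images, aux_imgs) when --use_prompt is set, mirroring the unpacking in _step below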
155 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)  # to cpu/cuda device
156 | (loss, logits), labels = self._step(batch, mode="dev")    # logits: batch, 3
157 | total_loss += loss.detach().cpu().item()
158 |
159 | preds = logits.argmax(-1)
160 | true_labels.extend(labels.view(-1).detach().cpu().tolist())
161 | pred_labels.extend(preds.view(-1).detach().cpu().tolist())
162 |
163 | pbar.update()
164 | # evaluate done
165 | pbar.close()
166 | sk_result = sk_classification_report(y_true=true_labels, y_pred=pred_labels, labels=list(self.re_dict.values())[1:], target_names=list(self.re_dict.keys())[1:], digits=4)
167 | self.logger.info("%s\n", sk_result)
168 | result = eval_result(true_labels, pred_labels, self.re_dict, self.logger)
169 | acc, micro_f1 = round(result['acc']*100, 4), round(result['micro_f1']*100, 4)
170 | if self.writer:
171 | self.writer.add_scalar(tag='test_acc', scalar_value=acc)    # tensorboardX
172 | self.writer.add_scalar(tag='test_f1', scalar_value=micro_f1)    # tensorboardX
173 | self.writer.add_scalar(tag='test_loss', scalar_value=total_loss/len(self.test_data))    # tensorboardX
174 | total_loss = 0
175 | self.logger.info("Test f1 score: {}, acc: {}.".format(micro_f1, acc))
176 |
177 | self.model.train()
178 |
179 | def _step(self, batch, mode="train"):
180 | if mode != "predict":
181 | if self.args.use_prompt:
182 | input_ids, token_type_ids, attention_mask, labels, images, aux_imgs = batch
183 | else:
184 | images, aux_imgs = None, None
185 | input_ids, token_type_ids, attention_mask, labels = batch
186 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels, images=images, aux_imgs=aux_imgs)
187 | return outputs, labels
188 |
189 | def before_train(self):
190 | no_decay = ['bias', 'LayerNorm.weight']
191 | optimizer_grouped_parameters = [
192 | {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
193 | {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
194 | ]
195 | self.optimizer = optim.AdamW(optimizer_grouped_parameters, lr=self.args.lr)
196 | self.scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer,
197 | num_warmup_steps=self.args.warmup_ratio*self.train_num_steps,
198 | num_training_steps=self.train_num_steps)
199 | self.model.to(self.args.device)
200 |
201 |
202 | def before_multimodal_train(self):
203 | optimizer_grouped_parameters = []
204 | params = {'lr':self.args.lr, 'weight_decay':1e-2}
205 | params['params'] = []
206 | for name, param in self.model.named_parameters():
207 | if 'bert' in name:
208 | params['params'].append(param)
209 | optimizer_grouped_parameters.append(params)
210 |
211 | params = {'lr':self.args.lr, 'weight_decay':1e-2}
212 | params['params'] = []
213 | for name, param in self.model.named_parameters():
214 | if 'encoder_conv' in name or 'gates' in name:
215 | params['params'].append(param)
216 | optimizer_grouped_parameters.append(params)
217 |
218 | # freeze resnet
219 | for name, param in self.model.named_parameters():
220 | if 'image_model' in name:
221 | param.requires_grad = False
222 | self.optimizer = optim.AdamW(optimizer_grouped_parameters, lr=self.args.lr)
223 | self.scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer,
224 | num_warmup_steps=self.args.warmup_ratio*self.train_num_steps,
225 | num_training_steps=self.train_num_steps)
226 | self.model.to(self.args.device)
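# (illustrative usage sketch, added; the loader/argument names are assumptions following run.py)
# trainer = RETrainer(train_data=train_loader, dev_data=dev_loader, test_data=test_loader,
#                     model=model, processor=processor, args=args, logger=logger, writer=writer)
# trainer.train()   # fine-tunes, running dev evaluation from args.eval_begin_epoch onwards
# trainer.test()    # reloads args.load_path if given and logs test accuracy / micro-F1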
227 |
228 |
229 | class NERTrainer(BaseTrainer):
230 | def __init__(self, train_data=None, dev_data=None, test_data=None, model=None, processor=None, label_map=None, args=None, logger=None, writer=None) -> None:
231 | self.train_data = train_data
232 | self.dev_data = dev_data
233 | self.test_data = test_data
234 | self.model = model
235 | self.processor = processor
236 | self.logger = logger
237 | self.label_map = label_map
238 | self.writer = writer
239 | self.refresh_step = 2
240 | self.best_dev_metric = 0
241 | self.best_test_metric = 0
242 | self.best_train_metric = 0
243 | self.best_dev_epoch = None
244 | self.best_test_epoch = None
245 | self.best_train_epoch = None
246 | self.optimizer = None
247 | if self.train_data is not None:
248 | self.train_num_steps = len(self.train_data) * args.num_epochs
249 | self.step = 0
250 | self.args = args
251 |
252 | def train(self):
253 | if self.args.use_prompt:
254 | self.multiModal_before_train()
255 | else:
256 | self.bert_before_train()
257 |
258 | self.step = 0
259 | self.model.train()
260 | self.logger.info("***** Running training *****")
261 | self.logger.info("  Num instance = %d", len(self.train_data)*self.args.batch_size)
262 | self.logger.info("  Num epoch = %d", self.args.num_epochs)
263 | self.logger.info("  Batch size = %d", self.args.batch_size)
264 | self.logger.info("  Learning rate = {}".format(self.args.lr))
265 | self.logger.info("  Evaluate begin = %d", self.args.eval_begin_epoch)
266 |
267 | if self.args.load_path is not None:  # load model from load_path
268 | self.logger.info("Loading model from {}".format(self.args.load_path))
269 | self.model.load_state_dict(torch.load(self.args.load_path))
270 | self.logger.info("Model loaded successfully!")
271 |
272 | with tqdm(total=self.train_num_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True, initial=self.step) as pbar:
273 | self.pbar = pbar
274 | avg_loss = 0
275 | for epoch in range(1, self.args.num_epochs+1):
276 | y_true, y_pred = [], []
277 | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.args.num_epochs))
278 | for batch in self.train_data:
279 | self.step += 1
280 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)
281 | attention_mask, labels, logits, loss = self._step(batch, mode="train")
282 | avg_loss += loss.detach().cpu().item()
283 |
284 | loss.backward()
285 | self.optimizer.step()
286 | self.scheduler.step()
287 | self.optimizer.zero_grad()
288 |
289 | if isinstance(logits, torch.Tensor):  # the CRF decoder returns lists instead of tensors
290 | logits = logits.argmax(-1).detach().cpu().numpy()  # batch, seq, 1
291 | label_ids = labels.to('cpu').numpy()
292 | input_mask = attention_mask.to('cpu').numpy()
293 | label_map = {idx:label for label, idx in self.label_map.items()}
294 |
295 | for row, mask_line in enumerate(input_mask):
296 | true_label = []
297 | true_predict = []
298 | for column, mask in enumerate(mask_line):
299 | if column == 0:
300 | continue
301 | if mask:
302 | if label_map[label_ids[row][column]] != "X" and label_map[label_ids[row][column]] != "[SEP]":
303 | true_label.append(label_map[label_ids[row][column]])
304 | true_predict.append(label_map[logits[row][column]])
305 | else:
306 | break
307 | y_true.append(true_label)
308 | y_pred.append(true_predict)
309 |
310 | if self.step % self.refresh_step == 0:
311 | avg_loss = float(avg_loss) / self.refresh_step
312 | print_output = "loss:{:<6.5f}".format(avg_loss)
313 | pbar.update(self.refresh_step)
314 | pbar.set_postfix_str(print_output)
315 | if self.writer:
316 | self.writer.add_scalar(tag='train_loss', scalar_value=avg_loss, global_step=self.step)    # tensorboardX
317 | avg_loss = 0
318 | results = classification_report(y_true, y_pred, digits=4)
319 | self.logger.info("***** Train Eval results *****")
320 | self.logger.info("\n%s", results)
321 | f1_score = float(results.split('\n')[-4].split('    ')[0].split('    ')[3])  # parse micro-average F1 out of seqeval's text report (fragile: depends on the report layout)
322 | if self.writer:
323 | self.writer.add_scalar(tag='train_f1', scalar_value=f1_score, global_step=epoch)    # tensorboardX
324 | self.logger.info("Epoch {}/{}, best train f1: {}, best epoch: {}, current train f1 score: {}."\
325 | .format(epoch, self.args.num_epochs, self.best_train_metric, self.best_train_epoch, f1_score))
326 | if f1_score > self.best_train_metric:
327 | self.best_train_metric = f1_score
328 | self.best_train_epoch = epoch
329 |
330 | if epoch >= self.args.eval_begin_epoch:
331 | self.evaluate(epoch)   # run evaluation on the dev set
332 |
333 | torch.cuda.empty_cache()
334 |
335 | pbar.close()
336 | self.pbar = None
337 | self.logger.info("Best dev performance at epoch {}, best dev f1 score is {}".format(self.best_dev_epoch, self.best_dev_metric))
338 | self.logger.info("Best test performance at epoch {}, best test f1 score is {}".format(self.best_test_epoch, self.best_test_metric))
339 |
340 | def evaluate(self, epoch):
341 | self.model.eval()
342 | self.logger.info("***** Running evaluation *****")
343 | self.logger.info("  Num instance = %d", len(self.dev_data)*self.args.batch_size)
344 | self.logger.info("  Batch size = %d", self.args.batch_size)
345 |
346 | y_true, y_pred = [], []
347 | step = 0
348 | with torch.no_grad():
349 | with tqdm(total=len(self.dev_data), leave=False, dynamic_ncols=True) as pbar:
350 | pbar.set_description_str(desc="Dev")
351 | total_loss = 0
352 | for batch in self.dev_data:
353 | step += 1
354 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)  # to cpu/cuda device
355 | attention_mask, labels, logits, loss = self._step(batch, mode="dev")  # logits: batch, seq, num_labels
356 | total_loss += loss.detach().cpu().item()
357 |
358 | if isinstance(logits, torch.Tensor):
359 | logits = logits.argmax(-1).detach().cpu().numpy()  # batch, seq, 1
360 | label_ids = labels.detach().cpu().numpy()
361 | input_mask = attention_mask.detach().cpu().numpy()
362 | label_map = {idx:label for label, idx in self.label_map.items()}
363 | for row, mask_line in enumerate(input_mask):
364 | true_label = []
365 | true_predict = []
366 | for column, mask in enumerate(mask_line):
367 | if column == 0:
368 | continue
369 | if mask:
370 | if label_map[label_ids[row][column]] != "X" and label_map[label_ids[row][column]] != "[SEP]":
371 | true_label.append(label_map[label_ids[row][column]])
372 | true_predict.append(label_map[logits[row][column]])
373 | else:
374 | break
375 | y_true.append(true_label)
376 | y_pred.append(true_predict)
377 |
378 | pbar.update()
379 | pbar.close()
380 | results = classification_report(y_true, y_pred, digits=4)
381 | self.logger.info("***** Dev Eval results *****")
382 | self.logger.info("\n%s", results)
383 | f1_score = float(results.split('\n')[-4].split('    ')[-2].split('    ')[-1])  # parse micro-average F1 out of seqeval's text report
384 | if self.writer:
385 | self.writer.add_scalar(tag='dev_f1', scalar_value=f1_score, global_step=epoch)    # tensorboardX
386 | self.writer.add_scalar(tag='dev_loss', scalar_value=total_loss/step, global_step=epoch)    # tensorboardX
387 |
388 | self.logger.info("Epoch {}/{}, best dev f1: {}, best epoch: {}, current dev f1 score: {}."\
score: {}."\ 389 | .format(epoch, self.args.num_epochs, self.best_dev_metric, self.best_dev_epoch, f1_score)) 390 | if f1_score >= self.best_dev_metric: # this epoch get best performance 391 | self.logger.info("Get better performance at epoch {}".format(epoch)) 392 | self.best_dev_epoch = epoch 393 | self.best_dev_metric = f1_score # update best metric(f1 score) 394 | if self.args.save_path is not None: 395 | torch.save(self.model.state_dict(), self.args.save_path+"/best_model.pth") 396 | self.logger.info("Save best model at {}".format(self.args.save_path)) 397 | 398 | self.model.train() 399 | 400 | def test(self): 401 | self.model.eval() 402 | self.logger.info("\n***** Running testing *****") 403 | self.logger.info(" Num instance = %d", len(self.test_data)*self.args.batch_size) 404 | self.logger.info(" Batch size = %d", self.args.batch_size) 405 | 406 | if self.args.load_path is not None: # load model from load_path 407 | self.logger.info("Loading model from {}".format(self.args.load_path)) 408 | self.model.load_state_dict(torch.load(self.args.load_path)) 409 | self.logger.info("Load model successful!") 410 | y_true, y_pred = [], [] 411 | with torch.no_grad(): 412 | with tqdm(total=len(self.test_data), leave=False, dynamic_ncols=True) as pbar: 413 | pbar.set_description_str(desc="Testing") 414 | total_loss = 0 415 | for batch in self.test_data: 416 | batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch) # to cpu/cuda device 417 | attention_mask, labels, logits, loss = self._step(batch, mode="dev") # logits: batch, seq, num_labels 418 | total_loss += loss.detach().cpu().item() 419 | 420 | if isinstance(logits, torch.Tensor): 421 | logits = logits.argmax(-1).detach().cpu().tolist() # batch, seq, 1 422 | label_ids = labels.detach().cpu().numpy() 423 | input_mask = attention_mask.detach().cpu().numpy() 424 | label_map = {idx:label for label, idx in self.label_map.items()} 425 | for row, mask_line in enumerate(input_mask): 426 | true_label = [] 427 | true_predict = [] 428 | for column, mask in enumerate(mask_line): 429 | if column == 0: 430 | continue 431 | if mask: 432 | if label_map[label_ids[row][column]] != "X" and label_map[label_ids[row][column]] != "[SEP]": 433 | true_label.append(label_map[label_ids[row][column]]) 434 | true_predict.append(label_map[logits[row][column]]) 435 | else: 436 | break 437 | y_true.append(true_label) 438 | y_pred.append(true_predict) 439 | pbar.update() 440 | # evaluate done 441 | pbar.close() 442 | 443 | results = classification_report(y_true, y_pred, digits=4) 444 | 445 | self.logger.info("***** Test Eval results *****") 446 | self.logger.info("\n%s", results) 447 | f1_score = float(results.split('\n')[-4].split(' ')[-2].split(' ')[-1]) 448 | if self.writer: 449 | self.writer.add_scalar(tag='test_f1', scalar_value=f1_score) # tensorbordx 450 | self.writer.add_scalar(tag='test_loss', scalar_value=total_loss/len(self.test_data)) # tensorbordx 451 | total_loss = 0 452 | self.logger.info("Test f1 score: {}.".format(f1_score)) 453 | 454 | self.model.train() 455 | 456 | def _step(self, batch, mode="train"): 457 | if self.args.use_prompt: 458 | input_ids, token_type_ids, attention_mask, labels, images, aux_imgs = batch 459 | else: 460 | images, aux_imgs = None, None 461 | input_ids, token_type_ids, attention_mask, labels = batch 462 | output = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels, images=images, aux_imgs=aux_imgs) 463 | logits, loss = output.logits, 
465 | 
466 | 
467 | 
468 | def bert_before_train(self):
469 | self.optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr)
470 | 
471 | self.model.to(self.args.device)
472 | self.scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer,
473 | num_warmup_steps=self.args.warmup_ratio*self.train_num_steps,
474 | num_training_steps=self.train_num_steps)
475 | 
476 | def multiModal_before_train(self):
477 | # BERT parameters: base learning rate
478 | parameters = []
479 | params = {'lr':self.args.lr, 'weight_decay':1e-2}
480 | params['params'] = []
481 | for name, param in self.model.named_parameters():
482 | if 'bert' in name:
483 | params['params'].append(param)
484 | parameters.append(params)
485 | 
486 | # visual prompt parameters (conv encoder and gates): base learning rate
487 | params = {'lr':self.args.lr, 'weight_decay':1e-2}
488 | params['params'] = []
489 | for name, param in self.model.named_parameters():
490 | if 'encoder_conv' in name or 'gates' in name:
491 | params['params'].append(param)
492 | parameters.append(params)
493 | 
494 | # CRF and classification head: larger learning rate
495 | params = {'lr':5e-2, 'weight_decay':1e-2}
496 | params['params'] = []
497 | for name, param in self.model.named_parameters():
498 | if 'crf' in name or name.startswith('fc'):
499 | params['params'].append(param)
500 | parameters.append(params)
501 | 
502 | self.optimizer = optim.AdamW(parameters)
503 | 
504 | for name, par in self.model.named_parameters(): # freeze the ResNet image encoder
505 | if 'image_model' in name: par.requires_grad = False
506 | 
507 | self.scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer,
508 | num_warmup_steps=self.args.warmup_ratio*self.train_num_steps,
509 | num_training_steps=self.train_num_steps)
510 | self.model.to(self.args.device)
511 | 
--------------------------------------------------------------------------------
/processor/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import torch
4 | import json
5 | import ast
6 | from PIL import Image
7 | from torch.utils.data import Dataset, DataLoader
8 | from transformers import BertTokenizer
9 | from torchvision import transforms
10 | import logging
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | class MMREProcessor(object):
15 | def __init__(self, data_path, bert_name):
16 | self.data_path = data_path
17 | self.re_path = data_path['re_path']
18 | self.tokenizer = BertTokenizer.from_pretrained(bert_name, do_lower_case=True)
19 | self.tokenizer.add_special_tokens({'additional_special_tokens':['<s>', '</s>', '<o>', '</o>']}) # head and tail entity marker tokens
20 | 
21 | def load_from_file(self, mode="train", sample_ratio=1.0):
22 | """
23 | Args:
24 | mode (str, optional): dataset mode. Defaults to "train".
25 | sample_ratio (float, optional): sample ratio in low resource. Defaults to 1.0.
26 | """
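# Each line of the load file is a Python-literal dict. An illustrative,
# editor-added example of the expected shape (field names are taken from the
# parsing below; the values are hypothetical):
# {'token': ['Obama', 'visited', 'Beijing'], 'relation': '...',
#  'h': {'name': 'Obama', 'pos': [0, 1]}, 't': {'name': 'Beijing', 'pos': [2, 3]},
#  'img_id': 'xxxx.jpg'}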
26 | """ 27 | load_file = self.data_path[mode] 28 | logger.info("Loading data from {}".format(load_file)) 29 | with open(load_file, "r", encoding="utf-8") as f: 30 | lines = f.readlines() 31 | words, relations, heads, tails, imgids, dataid = [], [], [], [], [], [] 32 | for i, line in enumerate(lines): 33 | line = ast.literal_eval(line) # str to dict 34 | words.append(line['token']) 35 | relations.append(line['relation']) 36 | heads.append(line['h']) # {name, pos} 37 | tails.append(line['t']) 38 | imgids.append(line['img_id']) 39 | dataid.append(i) 40 | 41 | assert len(words) == len(relations) == len(heads) == len(tails) == (len(imgids)) 42 | 43 | # aux image 44 | aux_path = self.data_path[mode+"_auximgs"] 45 | aux_imgs = torch.load(aux_path) 46 | 47 | # sample 48 | if sample_ratio != 1.0: 49 | sample_indexes = random.choices(list(range(len(words))), k=int(len(words)*sample_ratio)) 50 | sample_words = [words[idx] for idx in sample_indexes] 51 | sample_relations = [relations[idx] for idx in sample_indexes] 52 | sample_heads = [heads[idx] for idx in sample_indexes] 53 | sample_tails = [tails[idx] for idx in sample_indexes] 54 | sample_imgids = [imgids[idx] for idx in sample_indexes] 55 | sample_dataid = [dataid[idx] for idx in sample_indexes] 56 | assert len(sample_words) == len(sample_relations) == len(sample_imgids), "{}, {}, {}".format(len(sample_words), len(sample_relations), len(sample_imgids)) 57 | return {'words':sample_words, 'relations':sample_relations, 'heads':sample_heads, 'tails':sample_tails, \ 58 | 'imgids':sample_imgids, 'dataid': sample_dataid, 'aux_imgs':aux_imgs} 59 | 60 | return {'words':words, 'relations':relations, 'heads':heads, 'tails':tails, 'imgids': imgids, 'dataid': dataid, 'aux_imgs':aux_imgs} 61 | 62 | 63 | def get_relation_dict(self): 64 | with open(self.re_path, 'r', encoding="utf-8") as f: 65 | line = f.readlines()[0] 66 | re_dict = json.loads(line) 67 | return re_dict 68 | 69 | class MMPNERProcessor(object): 70 | def __init__(self, data_path, bert_name) -> None: 71 | self.data_path = data_path 72 | self.tokenizer = BertTokenizer.from_pretrained(bert_name, do_lower_case=True) 73 | 74 | def load_from_file(self, mode="train", sample_ratio=1.0): 75 | """ 76 | Args: 77 | mode (str, optional): dataset mode. Defaults to "train". 78 | sample_ratio (float, optional): sample ratio in low resouce. Defaults to 1.0. 
79 | """ 80 | load_file = self.data_path[mode] 81 | logger.info("Loading data from {}".format(load_file)) 82 | with open(load_file, "r", encoding="utf-8") as f: 83 | lines = f.readlines() 84 | raw_words, raw_targets = [], [] 85 | raw_word, raw_target = [], [] 86 | imgs = [] 87 | for line in lines: 88 | if line.startswith("IMGID:"): 89 | img_id = line.strip().split('IMGID:')[1] + '.jpg' 90 | imgs.append(img_id) 91 | continue 92 | if line != "\n": 93 | raw_word.append(line.split('\t')[0]) 94 | label = line.split('\t')[1][:-1] 95 | if 'OTHER' in label: 96 | label = label[:2] + 'MISC' 97 | raw_target.append(label) 98 | else: 99 | raw_words.append(raw_word) 100 | raw_targets.append(raw_target) 101 | raw_word, raw_target = [], [] 102 | 103 | assert len(raw_words) == len(raw_targets) == len(imgs), "{}, {}, {}".format(len(raw_words), len(raw_targets), len(imgs)) 104 | # load aux image 105 | aux_path = self.data_path[mode+"_auximgs"] 106 | aux_imgs = torch.load(aux_path) 107 | 108 | # sample data, only for low-resource 109 | if sample_ratio != 1.0: 110 | sample_indexes = random.choices(list(range(len(raw_words))), k=int(len(raw_words)*sample_ratio)) 111 | sample_raw_words = [raw_words[idx] for idx in sample_indexes] 112 | sample_raw_targets = [raw_targets[idx] for idx in sample_indexes] 113 | sample_imgs = [imgs[idx] for idx in sample_indexes] 114 | assert len(sample_raw_words) == len(sample_raw_targets) == len(sample_imgs), "{}, {}, {}".format(len(sample_raw_words), len(sample_raw_targets), len(sample_imgs)) 115 | return {"words": sample_raw_words, "targets": sample_raw_targets, "imgs": sample_imgs, "aux_imgs":aux_imgs} 116 | 117 | return {"words": raw_words, "targets": raw_targets, "imgs": imgs, "aux_imgs":aux_imgs} 118 | 119 | def get_label_mapping(self): 120 | LABEL_LIST = ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"] 121 | label_mapping = {label:idx for idx, label in enumerate(LABEL_LIST, 1)} 122 | label_mapping["PAD"] = 0 123 | return label_mapping 124 | 125 | class MMREDataset(Dataset): 126 | def __init__(self, processor, transform, img_path=None, aux_img_path=None, max_seq=40, sample_ratio=1.0, mode="train") -> None: 127 | self.processor = processor 128 | self.transform = transform 129 | self.max_seq = max_seq 130 | self.img_path = img_path[mode] if img_path is not None else img_path 131 | self.aux_img_path = aux_img_path[mode] if aux_img_path is not None else aux_img_path 132 | self.mode = mode 133 | 134 | self.data_dict = self.processor.load_from_file(mode, sample_ratio) 135 | self.re_dict = self.processor.get_relation_dict() 136 | self.tokenizer = self.processor.tokenizer 137 | 138 | def __len__(self): 139 | return len(self.data_dict['words']) 140 | 141 | def __getitem__(self, idx): 142 | word_list, relation, head_d, tail_d, imgid = self.data_dict['words'][idx], self.data_dict['relations'][idx], self.data_dict['heads'][idx], self.data_dict['tails'][idx], self.data_dict['imgids'][idx] 143 | item_id = self.data_dict['dataid'][idx] 144 | # [CLS] ... head ... tail .. 
145 | head_pos, tail_pos = head_d['pos'], tail_d['pos']
146 | # insert entity markers around the head and tail spans
147 | extend_word_list = []
148 | for i in range(len(word_list)):
149 | if i == head_pos[0]:
150 | extend_word_list.append('<s>')
151 | if i == head_pos[1]:
152 | extend_word_list.append('</s>')
153 | if i == tail_pos[0]:
154 | extend_word_list.append('<o>')
155 | if i == tail_pos[1]:
156 | extend_word_list.append('</o>')
157 | extend_word_list.append(word_list[i])
158 | extend_word_list = " ".join(extend_word_list)
159 | encode_dict = self.tokenizer.encode_plus(text=extend_word_list, max_length=self.max_seq, truncation=True, padding='max_length')
160 | input_ids, token_type_ids, attention_mask = encode_dict['input_ids'], encode_dict['token_type_ids'], encode_dict['attention_mask']
161 | input_ids, token_type_ids, attention_mask = torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask)
162 | 
163 | re_label = self.re_dict[relation] # label to id
164 | 
165 | # image process
166 | if self.img_path is not None:
167 | try:
168 | img_path = os.path.join(self.img_path, imgid)
169 | image = Image.open(img_path).convert('RGB')
170 | image = self.transform(image)
171 | except Exception: # fall back to the placeholder image if the raw image is missing or unreadable
172 | img_path = os.path.join(self.img_path, 'inf.png')
173 | image = Image.open(img_path).convert('RGB')
174 | image = self.transform(image)
175 | if self.aux_img_path is not None:
176 | # process aux images
177 | aux_imgs = []
178 | aux_img_paths = []
179 | imgid = imgid.split(".")[0]
180 | if item_id in self.data_dict['aux_imgs']:
181 | aux_img_paths = self.data_dict['aux_imgs'][item_id]
182 | aux_img_paths = [os.path.join(self.aux_img_path, path) for path in aux_img_paths]
183 | # discard aux images beyond the first 3
184 | for i in range(min(3, len(aux_img_paths))):
185 | aux_img = Image.open(aux_img_paths[i]).convert('RGB')
186 | aux_img = self.transform(aux_img)
187 | aux_imgs.append(aux_img)
188 | 
189 | # pad with zero tensors if there are fewer than 3
190 | for i in range(3-len(aux_img_paths)):
191 | aux_imgs.append(torch.zeros((3, 224, 224)))
192 | 
193 | aux_imgs = torch.stack(aux_imgs, dim=0) # (3, 3, 224, 224)
194 | assert len(aux_imgs) == 3
195 | return input_ids, token_type_ids, attention_mask, torch.tensor(re_label), image, aux_imgs
196 | return input_ids, token_type_ids, attention_mask, torch.tensor(re_label)
197 | 
198 | class MMPNERDataset(Dataset):
199 | def __init__(self, processor, transform, img_path=None, aux_img_path=None, max_seq=40, sample_ratio=1, mode='train', ignore_idx=0) -> None:
200 | self.processor = processor
201 | self.transform = transform
202 | self.data_dict = processor.load_from_file(mode, sample_ratio)
203 | self.tokenizer = processor.tokenizer
204 | self.label_mapping = processor.get_label_mapping()
205 | self.max_seq = max_seq
206 | self.ignore_idx = ignore_idx
207 | self.img_path = img_path
208 | self.aux_img_path = aux_img_path[mode] if aux_img_path is not None else None
209 | self.mode = mode
210 | self.sample_ratio = sample_ratio
211 | 
212 | def __len__(self):
213 | return len(self.data_dict['words'])
214 | 
215 | def __getitem__(self, idx):
216 | word_list, label_list, img = self.data_dict['words'][idx], self.data_dict['targets'][idx], self.data_dict['imgs'][idx]
217 | tokens, labels = [], []
218 | for i, word in enumerate(word_list):
219 | token = self.tokenizer.tokenize(word)
220 | tokens.extend(token)
221 | label = label_list[i]
222 | for m in range(len(token)):
223 | if m == 0: # the first wordpiece keeps the gold label
224 | labels.append(self.label_mapping[label])
225 | else: # continuation wordpieces get "X" and are skipped at evaluation
226 | labels.append(self.label_mapping["X"])
227 | if len(tokens) >= self.max_seq - 1:
228 | tokens = tokens[0:(self.max_seq - 2)]
229 | labels = labels[0:(self.max_seq - 2)]
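# Editor-added illustration: with max_seq=8 and gold label ids for
# [B-PER, X, O], the label row built below is
# [CLS], B-PER, X, O, [SEP], PAD, PAD, PAD (as ids), where the PAD positions
# use ignore_idx.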
230 | 
231 | encode_dict = self.tokenizer.encode_plus(tokens, max_length=self.max_seq, truncation=True, padding='max_length')
232 | input_ids, token_type_ids, attention_mask = encode_dict['input_ids'], encode_dict['token_type_ids'], encode_dict['attention_mask']
233 | labels = [self.label_mapping["[CLS]"]] + labels + [self.label_mapping["[SEP]"]] + [self.ignore_idx]*(self.max_seq-len(labels)-2)
234 | 
235 | if self.img_path is not None:
236 | # image process
237 | try:
238 | img_path = os.path.join(self.img_path, img)
239 | image = Image.open(img_path).convert('RGB')
240 | image = self.transform(image)
241 | except Exception: # fall back to the placeholder image if the raw image is missing or unreadable
242 | img_path = os.path.join(self.img_path, 'inf.png')
243 | image = Image.open(img_path).convert('RGB')
244 | image = self.transform(image)
245 | 
246 | if self.aux_img_path is not None:
247 | aux_imgs = []
248 | aux_img_paths = []
249 | if img in self.data_dict['aux_imgs']:
250 | aux_img_paths = self.data_dict['aux_imgs'][img]
251 | aux_img_paths = [os.path.join(self.aux_img_path, path) for path in aux_img_paths]
252 | for i in range(min(3, len(aux_img_paths))): # keep at most 3 aux images
253 | aux_img = Image.open(aux_img_paths[i]).convert('RGB')
254 | aux_img = self.transform(aux_img)
255 | aux_imgs.append(aux_img)
256 | 
257 | for i in range(3-len(aux_img_paths)): # pad with zero tensors if fewer than 3
258 | aux_imgs.append(torch.zeros((3, 224, 224)))
259 | 
260 | aux_imgs = torch.stack(aux_imgs, dim=0)
261 | assert len(aux_imgs) == 3
262 | return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), torch.tensor(labels), image, aux_imgs
263 | 
264 | assert len(input_ids) == len(token_type_ids) == len(attention_mask) == len(labels)
265 | return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), torch.tensor(labels)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.11.3
2 | # note: "pytorch" is not a pip package; torch is pinned below
3 | tensorboardX==2.4
4 | TorchCRF==1.1.0
5 | wandb==0.12.1
6 | torchvision==0.8.2
7 | torch==1.7.1
8 | scikit-learn==1.0
9 | seqeval==1.2.2
--------------------------------------------------------------------------------
/resource/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/HVPNeT/410bd0c54325376cd308bd5ee8558d61ee3d9553/resource/model.png
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | import sys
5 | sys.path.append("..")
6 | 
7 | import torch
8 | import numpy as np
9 | import random
10 | from torchvision import transforms
11 | from torch.utils.data import DataLoader
12 | from models.bert_model import HMNeTREModel, HMNeTNERModel
13 | from processor.dataset import MMREProcessor, MMPNERProcessor, MMREDataset, MMPNERDataset
14 | from modules.train import RETrainer, NERTrainer
15 | 
16 | import warnings
17 | warnings.filterwarnings("ignore", category=UserWarning)
18 | # from tensorboardX import SummaryWriter
19 | 
20 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
21 | datefmt = '%m/%d/%Y %H:%M:%S',
22 | level = logging.INFO)
23 | logger = logging.getLogger(__name__)
24 | 
25 | 
26 | MODEL_CLASSES = {
27 | 'MRE': HMNeTREModel,
28 | 'twitter15': HMNeTNERModel,
29 | 
'twitter17': HMNeTNERModel 30 | } 31 | 32 | TRAINER_CLASSES = { 33 | 'MRE': RETrainer, 34 | 'twitter15': NERTrainer, 35 | 'twitter17': NERTrainer 36 | } 37 | DATA_PROCESS = { 38 | 'MRE': (MMREProcessor, MMREDataset), 39 | 'twitter15': (MMPNERProcessor, MMPNERDataset), 40 | 'twitter17': (MMPNERProcessor, MMPNERDataset) 41 | } 42 | 43 | DATA_PATH = { 44 | 'MRE': { 45 | # text data 46 | 'train': 'data/RE_data/txt/ours_train.txt', 47 | 'dev': 'data/RE_data/txt/ours_val.txt', 48 | 'test': 'data/RE_data/txt/ours_test.txt', 49 | # {data_id : object_crop_img_path} 50 | 'train_auximgs': 'data/RE_data/txt/mre_train_dict.pth', 51 | 'dev_auximgs': 'data/RE_data/txt/mre_dev_dict.pth', 52 | 'test_auximgs': 'data/RE_data/txt/mre_test_dict.pth', 53 | # relation json data 54 | 're_path': 'data/RE_data/ours_rel2id.json' 55 | }, 56 | 57 | 'twitter15': { 58 | # text data 59 | 'train': 'data/NER_data/twitter2015/train.txt', 60 | 'dev': 'data/NER_data/twitter2015/valid.txt', 61 | 'test': 'data/NER_data/twitter2015/test.txt', 62 | # {data_id : object_crop_img_path} 63 | 'train_auximgs': 'data/NER_data/twitter2015/twitter2015_train_dict.pth', 64 | 'dev_auximgs': 'data/NER_data/twitter2015/twitter2015_val_dict.pth', 65 | 'test_auximgs': 'data/NER_data/twitter2015/twitter2015_test_dict.pth' 66 | }, 67 | 68 | 'twitter17': { 69 | # text data 70 | 'train': 'data/NER_data/twitter2017/train.txt', 71 | 'dev': 'data/NER_data/twitter2017/valid.txt', 72 | 'test': 'data/NER_data/twitter2017/test.txt', 73 | # {data_id : object_crop_img_path} 74 | 'train_auximgs': 'data/NER_data/twitter2017/twitter2017_train_dict.pth', 75 | 'dev_auximgs': 'data/NER_data/twitter2017/twitter2017_val_dict.pth', 76 | 'test_auximgs': 'data/NER_data/twitter2017/twitter2017_test_dict.pth' 77 | }, 78 | 79 | } 80 | 81 | # image data 82 | IMG_PATH = { 83 | 'MRE': {'train': 'data/RE_data/img_org/train/', 84 | 'dev': 'data/RE_data/img_org/val/', 85 | 'test': 'data/RE_data/img_org/test'}, 86 | 'twitter15': 'data/NER_data/twitter2015_images', 87 | 'twitter17': 'data/NER_data/twitter2017_images', 88 | } 89 | 90 | # auxiliary images 91 | AUX_PATH = { 92 | 'MRE':{ 93 | 'train': 'data/RE_data/img_vg/train/crops', 94 | 'dev': 'data/RE_data/img_vg/val/crops', 95 | 'test': 'data/RE_data/img_vg/test/crops' 96 | }, 97 | 'twitter15': { 98 | 'train': 'data/NER_data/twitter2015_aux_images/train/crops', 99 | 'dev': 'data/NER_data/twitter2015_aux_images/val/crops', 100 | 'test': 'data/NER_data/twitter2015_aux_images/test/crops', 101 | }, 102 | 103 | 'twitter17': { 104 | 'train': 'data/NER_data/twitter2017_aux_images/train/crops', 105 | 'dev': 'data/NER_data/twitter2017_aux_images/val/crops', 106 | 'test': 'data/NER_data/twitter2017_aux_images/test/crops', 107 | } 108 | } 109 | 110 | def set_seed(seed=2021): 111 | """set random seed""" 112 | torch.manual_seed(seed) 113 | torch.cuda.manual_seed_all(seed) 114 | torch.backends.cudnn.deterministic = True 115 | np.random.seed(seed) 116 | random.seed(seed) 117 | 118 | def main(): 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument('--dataset_name', default='twitter15', type=str, help="The name of dataset.") 121 | parser.add_argument('--bert_name', default='bert-base-uncased', type=str, help="Pretrained language model path") 122 | parser.add_argument('--num_epochs', default=30, type=int, help="num training epochs") 123 | parser.add_argument('--device', default='cuda', type=str, help="cuda or cpu") 124 | parser.add_argument('--batch_size', default=32, type=int, help="batch size") 125 | parser.add_argument('--lr', 
default=1e-5, type=float, help="learning rate")
126 | parser.add_argument('--warmup_ratio', default=0.01, type=float)
127 | parser.add_argument('--eval_begin_epoch', default=16, type=int, help="epoch from which to start evaluation")
128 | parser.add_argument('--seed', default=1, type=int, help="random seed, default is 1")
129 | parser.add_argument('--prompt_len', default=10, type=int, help="prompt length")
130 | parser.add_argument('--prompt_dim', default=800, type=int, help="mid dimension of the prompt projection layer")
131 | parser.add_argument('--load_path', default=None, type=str, help="load model from load_path")
132 | parser.add_argument('--save_path', default=None, type=str, help="save model at save_path")
133 | parser.add_argument('--write_path', default=None, type=str, help="if do_test=True, predictions will be written to write_path")
134 | parser.add_argument('--notes', default="", type=str, help="remarks appended to the save/log dir name")
135 | parser.add_argument('--use_prompt', action='store_true')
136 | parser.add_argument('--do_train', action='store_true')
137 | parser.add_argument('--only_test', action='store_true')
138 | parser.add_argument('--max_seq', default=128, type=int)
139 | parser.add_argument('--ignore_idx', default=-100, type=int)
140 | parser.add_argument('--sample_ratio', default=1.0, type=float, help="only for low-resource experiments.")
141 | 
142 | args = parser.parse_args()
143 | 
144 | data_path, img_path, aux_path = DATA_PATH[args.dataset_name], IMG_PATH[args.dataset_name], AUX_PATH[args.dataset_name]
145 | model_class, Trainer = MODEL_CLASSES[args.dataset_name], TRAINER_CLASSES[args.dataset_name]
146 | data_process, dataset_class = DATA_PROCESS[args.dataset_name]
147 | 
148 | transform = transforms.Compose([
149 | transforms.Resize(256),
150 | transforms.CenterCrop(224),
151 | transforms.ToTensor(),
152 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
153 | std=[0.229, 0.224, 0.225])])
154 | 
155 | set_seed(args.seed) # set seed, default is 1
156 | if args.save_path is not None: # make save_path dir
157 | # args.save_path = os.path.join(args.save_path, args.dataset_name+"_"+str(args.batch_size)+"_"+str(args.lr)+"_"+args.notes)
158 | if not os.path.exists(args.save_path):
159 | os.makedirs(args.save_path, exist_ok=True)
160 | print(args)
161 | logdir = "logs/" + args.dataset_name+ "_"+str(args.batch_size) + "_" + str(args.lr) + args.notes
162 | # writer = SummaryWriter(logdir=logdir)
163 | writer = None # tensorboardX logging is disabled by default
164 | 
165 | if not args.use_prompt: # without visual prompts, no image data is loaded
166 | img_path, aux_path = None, None
167 | 
168 | processor = data_process(data_path, args.bert_name)
169 | train_dataset = dataset_class(processor, transform, img_path, aux_path, args.max_seq, sample_ratio=args.sample_ratio, mode='train')
170 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
171 | 
172 | dev_dataset = dataset_class(processor, transform, img_path, aux_path, args.max_seq, mode='dev')
173 | dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
174 | 
175 | test_dataset = dataset_class(processor, transform, img_path, aux_path, args.max_seq, mode='test')
176 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
177 | 
178 | if args.dataset_name == 'MRE': # RE task
179 | re_dict = processor.get_relation_dict()
180 | num_labels = len(re_dict)
181 | tokenizer = processor.tokenizer
182 | model = HMNeTREModel(num_labels, tokenizer, args=args)
183 | 
184 | 
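# Note: the RE trainer receives the processor (which exposes the relation dict
# and tokenizer), while the NER branch below passes label_map instead; both
# share the same train/dev/test dataloaders.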
trainer = Trainer(train_data=train_dataloader, dev_data=dev_dataloader, test_data=test_dataloader, model=model, processor=processor, args=args, logger=logger, writer=writer)
185 | else: # NER task
186 | label_mapping = processor.get_label_mapping()
187 | label_list = list(label_mapping.keys())
188 | model = HMNeTNERModel(label_list, args)
189 | 
190 | trainer = Trainer(train_data=train_dataloader, dev_data=dev_dataloader, test_data=test_dataloader, model=model, label_map=label_mapping, args=args, logger=logger, writer=writer)
191 | 
192 | if args.do_train:
193 | # train
194 | trainer.train()
195 | # test the best checkpoint (requires --save_path to have been set)
196 | args.load_path = os.path.join(args.save_path, 'best_model.pth')
197 | trainer.test()
198 | 
199 | if args.only_test:
200 | # only do test
201 | trainer.test()
202 | 
203 | torch.cuda.empty_cache()
204 | # writer.close()
205 | 
206 | 
207 | if __name__ == "__main__":
208 | main()
--------------------------------------------------------------------------------
/run_re_task.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | DATASET_NAME="MRE"
4 | BERT_NAME="bert-base-uncased"
5 | 
6 | CUDA_VISIBLE_DEVICES=2 python -u run.py \
7 | --dataset_name=${DATASET_NAME} \
8 | --bert_name=${BERT_NAME} \
9 | --num_epochs=15 \
10 | --batch_size=16 \
11 | --lr=3e-5 \
12 | --warmup_ratio=0.06 \
13 | --eval_begin_epoch=1 \
14 | --seed=1234 \
15 | --do_train \
16 | --max_seq=80 \
17 | --use_prompt \
18 | --prompt_len=4 \
19 | --sample_ratio=1.0 \
20 | --save_path='ckpt/re/'
--------------------------------------------------------------------------------
/run_twitter15.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Configurable variables (set below):
3 | # batch_size (recommendation: 8 / 16)
4 | # lr: learning rate (recommendation: 3e-5 / 5e-5)
5 | # seed: random seed, default is 1234
6 | # BERT_NAME: pre-trained text model name ( bert-*)
7 | # max_seq: max sequence length
8 | # sample_ratio: few-shot learning, default is 1.0
9 | # save_path: model save path
10 | 
11 | DATASET_NAME="twitter15"
12 | BERT_NAME="bert-base-uncased"
13 | 
14 | lr=3e-5
15 | 
16 | CUDA_VISIBLE_DEVICES=0 python -u run.py \
17 | --dataset_name=${DATASET_NAME} \
18 | --bert_name=${BERT_NAME} \
19 | --num_epochs=30 \
20 | --batch_size=8 \
21 | --lr=$lr \
22 | --warmup_ratio=0.01 \
23 | --eval_begin_epoch=3 \
24 | --seed=1234 \
25 | --do_train \
26 | --ignore_idx=0 \
27 | --max_seq=80 \
28 | --use_prompt \
29 | --prompt_len=4 \
30 | --sample_ratio=1.0 \
31 | --save_path=your_ckpt_path
32 | 
--------------------------------------------------------------------------------
/run_twitter17.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Configurable variables (set below):
3 | # batch_size (recommendation: 8 / 16)
4 | # lr: learning rate (recommendation: 3e-5 / 5e-5)
5 | # seed: random seed, default is 1234
6 | # BERT_NAME: pre-trained text model name ( bert-*)
7 | # max_seq: max sequence length
8 | # sample_ratio: few-shot learning, default is 1.0
9 | # save_path: model save path
10 | 
11 | DATASET_NAME="twitter17"
12 | BERT_NAME="bert-base-uncased"
13 | 
14 | CUDA_VISIBLE_DEVICES=0 python -u run.py \
15 | --dataset_name=${DATASET_NAME} \
16 | --bert_name=${BERT_NAME} \
17 | --num_epochs=30 \
18 | --batch_size=8 \
19 | --lr=3e-5 \
20 | --warmup_ratio=0.01 \
21 | --eval_begin_epoch=3 \
22 | --seed=1234 \
23 | --do_train \
24 | --ignore_idx=0 \
25 | 
--max_seq=128 \ 26 | --use_prompt \ 27 | --prompt_len=4 \ 28 | --sample_ratio=1.0 \ 29 | --save_path=your_ckpt_path --------------------------------------------------------------------------------