├── ANP_data
│   └── image label_folder
├── README.md
├── data
│   ├── politic_data
│   │   ├── a.txt
│   │   ├── test_opinion.txt
│   │   ├── train_opinion.txt
│   │   └── valid_opinion.txt
│   ├── twitter2015
│   │   ├── test.txt
│   │   ├── train.txt
│   │   └── valid.txt
│   └── twitter2017
│       ├── test.txt
│       ├── train.txt
│       └── valid.txt
├── image_data
│   └── image data folder.txt
├── log files
│   ├── twitter2015_output
│   │   ├── config.json
│   │   ├── eval_results.txt
│   │   └── model_config.json
│   ├── twitter2017_output
│   │   ├── config.json
│   │   ├── eval_results.txt
│   │   └── model_config.json
│   └── twitterpolitic_output
│       ├── config.json
│       ├── eval_results.txt
│       └── model_config.json
├── model
│   ├── resnet
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   ├── resnet_model_folder.txt
│   │   └── resnet_utils.py
│   └── roberta-base-cased
│       └── Roberta_model_folder.txt
├── my_bert
│   ├── __init__.py
│   ├── __main__.py
│   ├── bichannel_modeling.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── file_utils.py
│   ├── mner_modeling.py
│   ├── optimization.py
│   └── tokenization.py
├── ner_evaluate.py
├── run_cmmt_crf.py
└── run_cmmt_crf.sh


/ANP_data/image label_folder:
--------------------------------------------------------------------------------
1 | Download the image label file via this link (https://drive.google.com/drive/folders/1i-EiyLS0RwsOw8cQIMMAKRRqy9VgXOhT?usp=sharing), and then put the associated image label files into the folder "./ANP_data/".
2 | 
3 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-Modal Multitask Transformer for End-to-End Multimodal Aspect-Based Sentiment Analysis
2 | 
3 | #### Author: Li YANG, yang0666@e.ntu.edu.sg
4 | 
5 | #### The Corresponding Paper:
6 | ##### Cross-modal multitask transformer for end-to-end multimodal aspect-based sentiment analysis
7 | ##### [https://www.sciencedirect.com/science/article/abs/pii/S0306457322001479](https://www.sciencedirect.com/science/article/abs/pii/S0306457322001479)
8 | 
9 | ##### The framework of the CMMT model:
10 | [Figure: framework of the CMMT model]
11 | 
12 | 
13 | 
14 | ## Data
15 | - The preprocessed CoNLL-format files are provided in this repo. For each tweet, the first line is its image ID, and the following lines are its textual content.
16 | - Step 1: Download each tweet's associated images via this link (https://drive.google.com/file/d/1PpvvncnQkgDNeBMKVgG2zFYuRhbL873g/view), and then put them into the folders "./image_data/twitter2015/" and "./image_data/twitter2017/".
17 | - The politician dataset can be downloaded via: https://drive.google.com/file/d/1oa029MLk8I_J99pxBs7X9RaIbUHhhTNG/view?usp=sharing
18 | - Step 2: Download the image label file via this link (https://drive.google.com/drive/folders/1UaeSYJQCQzszRmBWdhA11LqXnWousn4G?usp=sharing), and then put the associated image label files into the folder "./ANP_data/".
19 | - Step 3: Download the pre-trained ResNet-152 via this link (https://download.pytorch.org/models/resnet152-b121ed2d.pth), and put it under the folder "./model/resnet/".
20 | 
21 | - Step 4: Download the pre-trained roberta-base-cased model from Hugging Face and put it under the folder "./model/roberta-base-cased/".
22 | - Step 5: ANP information can be downloaded via https://drive.google.com/drive/folders/1UaeSYJQCQzszRmBWdhA11LqXnWousn4G?usp=share_link
23 | 
24 | ## Requirements
25 | * PyTorch 1.0.0
26 | * Python 3.7
27 | * pytorch-crf 0.7.2
28 | 
29 | ## Code Usage
30 | 
31 | ### Training for CMMT
32 | - The script below trains CMMT, tunes parameters on the dev set, and evaluates on the test set. You can change "CUDA_VISIBLE_DEVICES=2" to match the GPUs available on your machine.
33 | 
34 | ```sh
35 | sh run_cmmt_crf.sh
36 | ```
37 | 
38 | - Our training logs for Twitter-2015, Twitter-2017, and Political Twitter are in the folder "log files". These results are slightly lower than those reported in our paper because the experiments were run on different servers.
39 | 
40 | 
41 | ## Acknowledgements
42 | - Using these two datasets means that you have read and accepted the copyrights set by Twitter and the dataset providers.
43 | - Most of the code is based on the code provided by Hugging Face: https://github.com/huggingface/transformers.
44 | 
45 | ## Citation Information:
46 | Yang, L., Na, J. C., & Yu, J. (2022). Cross-modal multitask transformer for end-to-end multimodal aspect-based sentiment analysis. Information Processing & Management, 59(5), 103038.
47 | 


--------------------------------------------------------------------------------
/data/politic_data/a.txt:
--------------------------------------------------------------------------------
1 | 
2 | 


--------------------------------------------------------------------------------
/image_data/image data folder.txt:
--------------------------------------------------------------------------------
1 | Download each tweet's associated images via this link (https://drive.google.com/file/d/1PpvvncnQkgDNeBMKVgG2zFYuRhbL873g/view), and then put them into the folders "./image_data/twitter2015/" and "./image_data/twitter2017/".
2 | 
-------------------------------------------------------------------------------- /log files/twitter2015_output/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 |
"max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /log files/twitter2015_output/eval_results.txt: -------------------------------------------------------------------------------- 1 | Overall: 0.6457765667574932 0.6856316297010607 0.6651075771749297 2 | Positive: 0.6655290102389079 0.6151419558359621 0.6393442622950819 3 | Neutral: 0.65 0.7495881383855024 0.6962509563886764 4 | Negative: 0.5648148148148148 0.5398230088495575 0.5520361990950227 5 | -------------------------------------------------------------------------------- /log files/twitter2015_output/model_config.json: -------------------------------------------------------------------------------- 1 | {"bert_model": "/mnt/nfs-storage-titan/yoli_projects/roberta-base-cased", "do_lower": false, "max_seq_length": 128, "num_labels": 11, "label_map": {"1": "O", "2": "B-NEU", "3": "I-NEU", "4": "B-POS", "5": "I-POS", "6": "B-NEG", "7": "I-NEG", "8": "X", "9": "", "10": ""}} -------------------------------------------------------------------------------- /log files/twitter2017_output/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | 
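A note on reading the logs: each row of `eval_results.txt` holds three unlabeled numbers, and the third matches the harmonic mean of the first two (for the twitter2015 `Overall` row, 2 × 0.6458 × 0.6856 / (0.6458 + 0.6856) ≈ 0.6651). That is consistent with precision, recall, and F1 in that order. Combined with the `label_map` in `model_config.json`, where BIO tags are joined with the sentiments NEU, POS, and NEG, this suggests exact-match scoring of (aspect span, sentiment) pairs. The Python sketch below shows one common way to compute such scores. It is an assumption, not the logic of `ner_evaluate.py` (which is not shown here). `decode_spans`, `harmonic_f1`, `prf`, and `score` are hypothetical names, and the sketch expects word-level tags with the sub-word `X` positions already removed.

```python
def decode_spans(tags):
    """Convert word-level BIO-sentiment tags into (start, end, sentiment) spans.

    `end` is exclusive. Assumes sub-word "X" positions were already removed.
    An I- tag that does not continue an open span of the same sentiment starts
    a new span (a lenient convention similar to conlleval).
    """
    spans = []
    start, kind = None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel "O" closes a final span
        prefix, _, label = tag.partition("-")
        if prefix == "I" and kind == label:
            continue  # the open span keeps growing
        if kind is not None:
            spans.append((start, i, kind))  # close the open span
            start, kind = None, None
        if prefix in ("B", "I"):
            start, kind = i, label  # open a new span
    return spans


def harmonic_f1(p, r):
    """F1 is the harmonic mean of precision and recall."""
    return 2 * p * r / (p + r) if p + r else 0.0


def prf(gold, pred):
    """Exact-match precision, recall and F1 between two sets of items."""
    tp = len(gold & pred)
    p = tp / len(pred) if pred else 0.0
    r = tp / len(gold) if gold else 0.0
    return p, r, harmonic_f1(p, r)


def score(gold_tags, pred_tags):
    """Overall and per-sentiment scores over parallel lists of tag sequences."""
    def items(all_tags):
        # Each item is (sentence index, start, end, sentiment).
        return {(s, *span) for s, tags in enumerate(all_tags) for span in decode_spans(tags)}

    gold, pred = items(gold_tags), items(pred_tags)
    results = {"Overall": prf(gold, pred)}
    for name, label in (("Positive", "POS"), ("Neutral", "NEU"), ("Negative", "NEG")):
        results[name] = prf({g for g in gold if g[3] == label},
                            {q for q in pred if q[3] == label})
    return results


if __name__ == "__main__":
    # Overall row of log files/twitter2015_output/eval_results.txt
    print(round(harmonic_f1(0.6457765667574932, 0.6856316297010607), 6))  # 0.665108
    gold = [["B-POS", "I-POS", "O", "B-NEG"]]
    pred = [["B-POS", "I-POS", "O", "B-NEU"]]
    for name, (p, r, f) in score(gold, pred).items():
        print(f"{name}: {p:.2f} {r:.2f} {f:.2f}")
```

With the toy tags in the `__main__` block, the positive aspect is recovered exactly but the negative one is predicted as neutral, so the Overall precision, recall, and F1 are all 0.50. Under exact matching, a correct span with the wrong sentiment counts as both a false positive and a false negative.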
-------------------------------------------------------------------------------- /log files/twitter2017_output/eval_results.txt: -------------------------------------------------------------------------------- 1 | Overall: 0.6758675078864353 0.6944894651539708 0.6850519584332534 2 | Positive: 0.6893203883495146 0.7200811359026369 0.7043650793650794 3 | Neutral: 0.678082191780822 0.6910994764397905 0.6845289541918756 4 | Negative: 0.6272189349112426 0.6309523809523809 0.629080118694362 5 | -------------------------------------------------------------------------------- /log files/twitter2017_output/model_config.json: -------------------------------------------------------------------------------- 1 | {"bert_model": "/mnt/nfs-storage-titan/yoli_projects/roberta-base-cased", "do_lower": false, "max_seq_length": 128, "num_labels": 11, "label_map": {"1": "O", "2": "B-NEU", "3": "I-NEU", "4": "B-POS", "5": "I-POS", "6": "B-NEG", "7": "I-NEG", "8": "X", "9": "", "10": ""}} -------------------------------------------------------------------------------- /log files/twitterpolitic_output/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /log files/twitterpolitic_output/eval_results.txt: -------------------------------------------------------------------------------- 1 | Overall: 0.6526315789473685 0.657243816254417 
0.6549295774647887 2 | Positive: 0.6197916666666666 0.6761363636363636 0.6467391304347826 3 | Neutral: 0.6260162601626016 0.6277173913043478 0.6268656716417911 4 | Negative: 0.7074829931972789 0.6819672131147541 0.6944908180300502 5 | -------------------------------------------------------------------------------- /log files/twitterpolitic_output/model_config.json: -------------------------------------------------------------------------------- 1 | {"bert_model": "/mnt/nfs-storage-titan/yoli_projects/roberta-base-cased", "do_lower": false, "max_seq_length": 128, "num_labels": 11, "label_map": {"1": "O", "2": "B-NEU", "3": "I-NEU", "4": "B-POS", "5": "I-POS", "6": "B-NEG", "7": "I-NEG", "8": "X", "9": "", "10": ""}} -------------------------------------------------------------------------------- /model/resnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangli-hub/CMMT-Code/06d38658b0c0e3c585ca8fa17d354c353c676e84/model/resnet/__init__.py -------------------------------------------------------------------------------- /model/resnet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | "3x3 convolution with padding" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 
stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, block, layers, 
num_classes=1000): 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 110 | self.avgpool = nn.AvgPool2d(7, stride=1) 111 | self.fc = nn.Linear(512 * block.expansion, num_classes) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | 149 | x = self.avgpool(x) 150 | x = x.view(x.size(0), -1) 151 | x = self.fc(x) 152 | 153 | return x 154 | 155 | 156 | def resnet18(pretrained=False, **kwargs): 
157 | """Constructs a ResNet-18 model. 158 | 159 | Args: 160 | pretrained (bool): If True, returns a model pre-trained on ImageNet 161 | """ 162 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 163 | if pretrained: 164 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 165 | return model 166 | 167 | 168 | def resnet34(pretrained=False, **kwargs): 169 | """Constructs a ResNet-34 model. 170 | 171 | Args: 172 | pretrained (bool): If True, returns a model pre-trained on ImageNet 173 | """ 174 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 175 | if pretrained: 176 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 177 | return model 178 | 179 | 180 | def resnet50(pretrained=False, **kwargs): 181 | """Constructs a ResNet-50 model. 182 | 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 187 | if pretrained: 188 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 189 | return model 190 | 191 | 192 | def resnet101(pretrained=False, **kwargs): 193 | """Constructs a ResNet-101 model. 194 | 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 201 | return model 202 | 203 | 204 | def resnet152(pretrained=False, **kwargs): 205 | """Constructs a ResNet-152 model. 
206 | 207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 211 | if pretrained: 212 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 213 | return model 214 | -------------------------------------------------------------------------------- /model/resnet/resnet_model_folder.txt: -------------------------------------------------------------------------------- 1 | 2 | Download the pre-trained ResNet-152 via this link (https://download.pytorch.org/models/resnet152-b121ed2d.pth), and put it under the folder "./model/resnet/". 3 | -------------------------------------------------------------------------------- /model/resnet/resnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | class myResnet(nn.Module): 7 | def __init__(self, resnet, if_fine_tune, device): 8 | super(myResnet, self).__init__() 9 | self.resnet = resnet 10 | self.if_fine_tune = if_fine_tune 11 | self.device = device 12 | 13 | def forward(self, x, att_size=7): 14 | x = self.resnet.conv1(x) 15 | x = self.resnet.bn1(x) 16 | x = self.resnet.relu(x) 17 | x = self.resnet.maxpool(x) 18 | 19 | x = self.resnet.layer1(x) 20 | x = self.resnet.layer2(x) 21 | x = self.resnet.layer3(x) 22 | x = self.resnet.layer4(x) 23 | 24 | fc = x.mean(3).mean(2) 25 | att = F.adaptive_avg_pool2d(x,[att_size,att_size]) 26 | 27 | x = self.resnet.avgpool(x) 28 | x = x.view(x.size(0), -1) 29 | 30 | if not self.if_fine_tune: 31 | 32 | x= Variable(x.data) 33 | fc = Variable(fc.data) 34 | att = Variable(att.data) 35 | 36 | return x, fc, att 37 | 38 | 39 | -------------------------------------------------------------------------------- /model/roberta-base-cased/Roberta_model_folder.txt:
-------------------------------------------------------------------------------- 1 | Download the pre-trained roberta-base-cased model and put it under the folder "./model/roberta-base-cased/". 2 | -------------------------------------------------------------------------------- /my_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.4.0" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .optimization import BertAdam 4 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE -------------------------------------------------------------------------------- /my_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | try: 5 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 6 | except ModuleNotFoundError: 7 | print("pytorch_pretrained_bert can only be used from the command line to convert TensorFlow models to PyTorch. " 8 | "In that case, it requires TensorFlow to be installed. Please see " 9 | "https://www.tensorflow.org/install/ for installation instructions.") 10 | raise 11 | 12 | if len(sys.argv) != 5: 13 | # pylint: disable=line-too-long 14 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 15 | else: 16 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 17 | TF_CONFIG = sys.argv.pop() 18 | TF_CHECKPOINT = sys.argv.pop() 19 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /my_bert/bichannel_modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc.
team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import copy 24 | import json 25 | import math 26 | import logging 27 | import tarfile 28 | import tempfile 29 | import shutil 30 | 31 | import torch 32 | from torch import nn 33 | from torch.nn import CrossEntropyLoss 34 | 35 | from .file_utils import cached_path 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | PRETRAINED_MODEL_ARCHIVE_MAP = { 40 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", 41 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", 42 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", 43 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", 44 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", 45 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 46 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", 47 | 
} 48 | CONFIG_NAME = 'bert_config.json' 49 | WEIGHTS_NAME = 'pytorch_model.bin' 50 | 51 | def gelu(x): 52 | """Implementation of the gelu activation function. 53 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 54 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 55 | """ 56 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 57 | 58 | 59 | def swish(x): 60 | return x * torch.sigmoid(x) 61 | 62 | 63 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 64 | 65 | 66 | class BertConfig(object): 67 | """Configuration class to store the configuration of a `BertModel`. 68 | """ 69 | def __init__(self, 70 | vocab_size_or_config_json_file, 71 | hidden_size=768, 72 | num_hidden_layers=12, 73 | num_attention_heads=12, 74 | intermediate_size=3072, 75 | hidden_act="gelu", 76 | hidden_dropout_prob=0.1, 77 | attention_probs_dropout_prob=0.1, 78 | max_position_embeddings=512, 79 | type_vocab_size=2, 80 | initializer_range=0.02): 81 | """Constructs BertConfig. 82 | 83 | Args: 84 | vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`. 85 | hidden_size: Size of the encoder layers and the pooler layer. 86 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 87 | num_attention_heads: Number of attention heads for each attention layer in 88 | the Transformer encoder. 89 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 90 | layer in the Transformer encoder. 91 | hidden_act: The non-linear activation function (function or string) in the 92 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 93 | hidden_dropout_prob: The dropout probability for all fully connected 94 | layers in the embeddings, encoder, and pooler. 95 | attention_probs_dropout_prob: The dropout ratio for the attention 96 | probabilities.
97 | max_position_embeddings: The maximum sequence length that this model might 98 | ever be used with. Typically set this to something large just in case 99 | (e.g., 512 or 1024 or 2048). 100 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 101 | `BertModel`. 102 | initializer_range: The stddev of the truncated_normal_initializer for 103 | initializing all weight matrices. 104 | """ 105 | if isinstance(vocab_size_or_config_json_file, str): 106 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 107 | json_config = json.loads(reader.read()) 108 | for key, value in json_config.items(): 109 | self.__dict__[key] = value 110 | elif isinstance(vocab_size_or_config_json_file, int): 111 | self.vocab_size = vocab_size_or_config_json_file 112 | self.hidden_size = hidden_size 113 | self.num_hidden_layers = num_hidden_layers 114 | self.num_attention_heads = num_attention_heads 115 | self.hidden_act = hidden_act 116 | self.intermediate_size = intermediate_size 117 | self.hidden_dropout_prob = hidden_dropout_prob 118 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 119 | self.max_position_embeddings = max_position_embeddings 120 | self.type_vocab_size = type_vocab_size 121 | self.initializer_range = initializer_range 122 | else: 123 | raise ValueError("First argument must be either a vocabulary size (int) " 124 | "or the path to a pretrained model config file (str)") 125 | 126 | @classmethod 127 | def from_dict(cls, json_object): 128 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 129 | config = BertConfig(vocab_size_or_config_json_file=-1) 130 | for key, value in json_object.items(): 131 | config.__dict__[key] = value 132 | return config 133 | 134 | @classmethod 135 | def from_json_file(cls, json_file): 136 | """Constructs a `BertConfig` from a json file of parameters.""" 137 | with open(json_file, "r", encoding='utf-8') as reader: 138 | text = reader.read() 139 | return
cls.from_dict(json.loads(text)) 140 | 141 | def __repr__(self): 142 | return str(self.to_json_string()) 143 | 144 | def to_dict(self): 145 | """Serializes this instance to a Python dictionary.""" 146 | output = copy.deepcopy(self.__dict__) 147 | return output 148 | 149 | def to_json_string(self): 150 | """Serializes this instance to a JSON string.""" 151 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 152 | 153 | try: 154 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 155 | except ImportError: 156 | print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") 157 | class BertLayerNorm(nn.Module): 158 | def __init__(self, hidden_size, eps=1e-12): 159 | """Construct a layernorm module in the TF style (epsilon inside the square root). 160 | """ 161 | super(BertLayerNorm, self).__init__() 162 | self.weight = nn.Parameter(torch.ones(hidden_size)) 163 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 164 | self.variance_epsilon = eps 165 | 166 | def forward(self, x): 167 | u = x.mean(-1, keepdim=True) 168 | s = (x - u).pow(2).mean(-1, keepdim=True) 169 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 170 | return self.weight * x + self.bias 171 | 172 | class BertEmbeddings(nn.Module): 173 | """Construct the embeddings from word, position and token_type embeddings. 
174 | """ 175 | def __init__(self, config): 176 | super(BertEmbeddings, self).__init__() 177 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 178 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 179 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 180 | 181 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 182 | # any TensorFlow checkpoint file 183 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 184 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 185 | 186 | def forward(self, input_ids, token_type_ids=None): 187 | seq_length = input_ids.size(1) 188 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 189 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 190 | if token_type_ids is None: 191 | token_type_ids = torch.zeros_like(input_ids) 192 | 193 | words_embeddings = self.word_embeddings(input_ids) 194 | position_embeddings = self.position_embeddings(position_ids) 195 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 196 | 197 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 198 | embeddings = self.LayerNorm(embeddings) 199 | embeddings = self.dropout(embeddings) 200 | return embeddings 201 | 202 | 203 | class BertSelfAttention(nn.Module): 204 | def __init__(self, config): 205 | super(BertSelfAttention, self).__init__() 206 | if config.hidden_size % config.num_attention_heads != 0: 207 | raise ValueError( 208 | "The hidden size (%d) is not a multiple of the number of attention " 209 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 210 | self.num_attention_heads = config.num_attention_heads 211 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 212 | self.all_head_size = self.num_attention_heads * self.attention_head_size 213 | 214 | 
214 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 215 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 216 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 217 | 218 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 219 | 220 | def transpose_for_scores(self, x): 221 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 222 | x = x.view(*new_x_shape) 223 | return x.permute(0, 2, 1, 3) 224 | 225 | def forward(self, hidden_states, attention_mask): 226 | mixed_query_layer = self.query(hidden_states) 227 | mixed_key_layer = self.key(hidden_states) 228 | mixed_value_layer = self.value(hidden_states) 229 | 230 | query_layer = self.transpose_for_scores(mixed_query_layer) 231 | key_layer = self.transpose_for_scores(mixed_key_layer) 232 | value_layer = self.transpose_for_scores(mixed_value_layer) 233 | 234 | # Take the dot product between "query" and "key" to get the raw attention scores. 235 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 236 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 237 | # Apply the attention mask (precomputed for all layers in BertModel forward() function) 238 | attention_scores = attention_scores + attention_mask 239 | 240 | # Normalize the attention scores to probabilities. 241 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 242 | 243 | # This is actually dropping out entire tokens to attend to, which might 244 | # seem a bit unusual, but is taken from the original Transformer paper.
245 | attention_probs = self.dropout(attention_probs) 246 | 247 | context_layer = torch.matmul(attention_probs, value_layer) 248 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 249 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 250 | context_layer = context_layer.view(*new_context_layer_shape) 251 | return context_layer 252 | 253 | class BertCoAttention(nn.Module): 254 | def __init__(self, config): 255 | super(BertCoAttention, self).__init__() 256 | if config.hidden_size % config.num_attention_heads != 0: 257 | raise ValueError( 258 | "The hidden size (%d) is not a multiple of the number of attention " 259 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 260 | self.num_attention_heads = config.num_attention_heads 261 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 262 | self.all_head_size = self.num_attention_heads * self.attention_head_size 263 | 264 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 265 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 266 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 267 | 268 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 269 | 270 | def transpose_for_scores(self, x): 271 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 272 | x = x.view(*new_x_shape) 273 | return x.permute(0, 2, 1, 3) 274 | 275 | def forward(self, s1_hidden_states, s2_hidden_states, s2_attention_mask): 276 | mixed_query_layer = self.query(s1_hidden_states) 277 | mixed_key_layer = self.key(s2_hidden_states) 278 | mixed_value_layer = self.value(s2_hidden_states) 279 | 280 | query_layer = self.transpose_for_scores(mixed_query_layer) 281 | key_layer = self.transpose_for_scores(mixed_key_layer) 282 | value_layer = self.transpose_for_scores(mixed_value_layer) 283 | 284 | # Take the dot product between "query" and "key" to get the raw attention scores. 
285 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 286 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 287 | # Apply the attention mask (precomputed for all layers in BertModel forward() function) 288 | attention_scores = attention_scores + s2_attention_mask 289 | 290 | # Normalize the attention scores to probabilities. 291 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 292 | 293 | # This is actually dropping out entire tokens to attend to, which might 294 | # seem a bit unusual, but is taken from the original Transformer paper. 295 | attention_probs = self.dropout(attention_probs) 296 | 297 | context_layer = torch.matmul(attention_probs, value_layer) 298 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 299 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 300 | context_layer = context_layer.view(*new_context_layer_shape) 301 | return context_layer 302 | 303 | 304 | class BertSelfOutput(nn.Module): 305 | def __init__(self, config): 306 | super(BertSelfOutput, self).__init__() 307 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 308 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 309 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 310 | 311 | def forward(self, hidden_states, input_tensor): 312 | hidden_states = self.dense(hidden_states) 313 | hidden_states = self.dropout(hidden_states) 314 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 315 | return hidden_states 316 | 317 | 318 | class BertAttention(nn.Module): 319 | def __init__(self, config): 320 | super(BertAttention, self).__init__() 321 | self.self = BertSelfAttention(config) 322 | self.output = BertSelfOutput(config) 323 | 324 | def forward(self, input_tensor, attention_mask): 325 | self_output = self.self(input_tensor, attention_mask) 326 | attention_output = self.output(self_output, input_tensor) 327 | return attention_output 328 | 329 | 330
| class BertCrossAttention(nn.Module): 331 | def __init__(self, config): 332 | super(BertCrossAttention, self).__init__() 333 | self.self = BertCoAttention(config) 334 | self.output = BertSelfOutput(config) 335 | 336 | def forward(self, s1_input_tensor, s2_input_tensor, s2_attention_mask): 337 | self_output = self.self(s1_input_tensor, s2_input_tensor, s2_attention_mask) 338 | attention_output = self.output(self_output, s1_input_tensor) 339 | return attention_output 340 | 341 | 342 | class BertIntermediate(nn.Module): 343 | def __init__(self, config): 344 | super(BertIntermediate, self).__init__() 345 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 346 | self.intermediate_act_fn = ACT2FN[config.hidden_act] \ 347 | if isinstance(config.hidden_act, str) else config.hidden_act 348 | 349 | def forward(self, hidden_states): 350 | hidden_states = self.dense(hidden_states) 351 | hidden_states = self.intermediate_act_fn(hidden_states) 352 | return hidden_states 353 | 354 | 355 | class BertOutput(nn.Module): 356 | def __init__(self, config): 357 | super(BertOutput, self).__init__() 358 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 359 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 360 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 361 | 362 | def forward(self, hidden_states, input_tensor): 363 | hidden_states = self.dense(hidden_states) 364 | hidden_states = self.dropout(hidden_states) 365 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 366 | return hidden_states 367 | 368 | 369 | class BertLayer(nn.Module): 370 | def __init__(self, config): 371 | super(BertLayer, self).__init__() 372 | self.attention = BertAttention(config) 373 | self.intermediate = BertIntermediate(config) 374 | self.output = BertOutput(config) 375 | 376 | def forward(self, hidden_states, attention_mask): 377 | attention_output = self.attention(hidden_states, attention_mask) 378 | intermediate_output = 
self.intermediate(attention_output) 379 | layer_output = self.output(intermediate_output, attention_output) 380 | return layer_output 381 | 382 | class BertCrossAttentionLayer(nn.Module): 383 | def __init__(self, config): 384 | super(BertCrossAttentionLayer, self).__init__() 385 | self.attention = BertCrossAttention(config) 386 | self.intermediate = BertIntermediate(config) 387 | self.output = BertOutput(config) 388 | 389 | def forward(self, s1_hidden_states, s2_hidden_states, s2_attention_mask): 390 | attention_output = self.attention(s1_hidden_states, s2_hidden_states, s2_attention_mask) 391 | intermediate_output = self.intermediate(attention_output) 392 | layer_output = self.output(intermediate_output, attention_output) 393 | return layer_output 394 | 395 | class BertEncoder(nn.Module): 396 | def __init__(self, config): 397 | super(BertEncoder, self).__init__() 398 | layer = BertLayer(config) 399 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 400 | 401 | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): 402 | all_encoder_layers = [] 403 | for layer_module in self.layer: 404 | hidden_states = layer_module(hidden_states, attention_mask) 405 | if output_all_encoded_layers: 406 | all_encoder_layers.append(hidden_states) 407 | if not output_all_encoded_layers: 408 | all_encoder_layers.append(hidden_states) 409 | return all_encoder_layers 410 | 411 | 412 | class BertPooler(nn.Module): 413 | def __init__(self, config): 414 | super(BertPooler, self).__init__() 415 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 416 | self.activation = nn.Tanh() 417 | 418 | def forward(self, hidden_states): 419 | # We "pool" the model by simply taking the hidden state corresponding 420 | # to the first token. 
421 | first_token_tensor = hidden_states[:, 0] 422 | pooled_output = self.dense(first_token_tensor) 423 | pooled_output = self.activation(pooled_output) 424 | return pooled_output 425 | 426 | 427 | class BertPredictionHeadTransform(nn.Module): 428 | def __init__(self, config): 429 | super(BertPredictionHeadTransform, self).__init__() 430 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 431 | self.transform_act_fn = ACT2FN[config.hidden_act] \ 432 | if isinstance(config.hidden_act, str) else config.hidden_act 433 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 434 | 435 | def forward(self, hidden_states): 436 | hidden_states = self.dense(hidden_states) 437 | hidden_states = self.transform_act_fn(hidden_states) 438 | hidden_states = self.LayerNorm(hidden_states) 439 | return hidden_states 440 | 441 | 442 | class BertLMPredictionHead(nn.Module): 443 | def __init__(self, config, bert_model_embedding_weights): 444 | super(BertLMPredictionHead, self).__init__() 445 | self.transform = BertPredictionHeadTransform(config) 446 | 447 | # The output weights are the same as the input embeddings, but there is 448 | # an output-only bias for each token. 
449 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 450 | bert_model_embedding_weights.size(0), 451 | bias=False) 452 | self.decoder.weight = bert_model_embedding_weights 453 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) 454 | 455 | def forward(self, hidden_states): 456 | hidden_states = self.transform(hidden_states) 457 | hidden_states = self.decoder(hidden_states) + self.bias 458 | return hidden_states 459 | 460 | 461 | class BertOnlyMLMHead(nn.Module): 462 | def __init__(self, config, bert_model_embedding_weights): 463 | super(BertOnlyMLMHead, self).__init__() 464 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 465 | 466 | def forward(self, sequence_output): 467 | prediction_scores = self.predictions(sequence_output) 468 | return prediction_scores 469 | 470 | 471 | class BertOnlyNSPHead(nn.Module): 472 | def __init__(self, config): 473 | super(BertOnlyNSPHead, self).__init__() 474 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 475 | 476 | def forward(self, pooled_output): 477 | seq_relationship_score = self.seq_relationship(pooled_output) 478 | return seq_relationship_score 479 | 480 | 481 | class BertPreTrainingHeads(nn.Module): 482 | def __init__(self, config, bert_model_embedding_weights): 483 | super(BertPreTrainingHeads, self).__init__() 484 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 485 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 486 | 487 | def forward(self, sequence_output, pooled_output): 488 | prediction_scores = self.predictions(sequence_output) 489 | seq_relationship_score = self.seq_relationship(pooled_output) 490 | return prediction_scores, seq_relationship_score 491 | 492 | 493 | class PreTrainedBertModel(nn.Module): 494 | """ An abstract class to handle weights initialization and 495 | a simple interface for downloading and loading pretrained models.
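Before loading weights, `from_pretrained` renames legacy LayerNorm parameter names (`gamma`, `beta`) to the PyTorch names (`weight`, `bias`). A minimal pure-Python sketch of that renaming on an ordinary dict, assuming string keys (`rename_legacy_keys` is a hypothetical helper; the real method does this in place on the loaded state dict):

```python
from collections import OrderedDict


def rename_legacy_keys(state_dict):
    # Hypothetical helper mirroring the renaming loop in from_pretrained:
    # 'gamma' -> 'weight' and 'beta' -> 'bias' inside parameter names.
    old_keys, new_keys = [], []
    for key in state_dict.keys():
        new_key = None
        if 'gamma' in key:
            new_key = key.replace('gamma', 'weight')
        if 'beta' in key:
            new_key = key.replace('beta', 'bias')
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    # Rename only after iterating, so the dict is not mutated while being read.
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)
    return state_dict


sd = OrderedDict([('bert.embeddings.LayerNorm.gamma', 1.0),
                  ('bert.embeddings.LayerNorm.beta', 0.0),
                  ('bert.pooler.dense.weight', 2.0)])
print(sorted(rename_legacy_keys(sd)))
```

Collecting the old/new pairs first and renaming afterwards avoids changing a dict's keys while iterating over it, which Python does not allow.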
496 | """ 497 | def __init__(self, config, *inputs, **kwargs): 498 | super(PreTrainedBertModel, self).__init__() 499 | if not isinstance(config, BertConfig): 500 | raise ValueError( 501 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " 502 | "To create a model from a Google pretrained model use " 503 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 504 | self.__class__.__name__, self.__class__.__name__ 505 | )) 506 | self.config = config 507 | 508 | def init_bert_weights(self, module): 509 | """ Initialize the weights. 510 | """ 511 | if isinstance(module, (nn.Linear, nn.Embedding)): 512 | # Slightly different from the TF version which uses truncated_normal for initialization 513 | # cf https://github.com/pytorch/pytorch/pull/5617 514 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 515 | elif isinstance(module, BertLayerNorm): 516 | module.bias.data.zero_() 517 | module.weight.data.fill_(1.0) 518 | if isinstance(module, nn.Linear) and module.bias is not None: 519 | module.bias.data.zero_() 520 | 521 | @classmethod 522 | def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): 523 | """ 524 | Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. 525 | Download and cache the pre-trained model file if needed. 526 | 527 | Params: 528 | pretrained_model_name: either: 529 | - a str with the name of a pre-trained model to load selected in the list of: 530 | . `bert-base-uncased` 531 | . `bert-large-uncased` 532 | . `bert-base-cased` 533 | . `bert-large-cased` 534 | . `bert-base-multilingual-uncased` 535 | . `bert-base-multilingual-cased` 536 | . `bert-base-chinese` 537 | - a path or url to a pretrained model archive containing: 538 | . `bert_config.json` a configuration file for the model 539 | . 
`pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance 540 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 541 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models 542 | *inputs, **kwargs: additional input for the specific Bert class 543 | (ex: num_labels for BertForSequenceClassification) 544 | """ 545 | if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: 546 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] 547 | else: 548 | archive_file = pretrained_model_name 549 | # redirect to the cache, if necessary 550 | try: 551 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 552 | except FileNotFoundError: 553 | logger.error( 554 | "Model name '{}' was not found in model name list ({}). " 555 | "We assumed '{}' was a path or url but couldn't find any file " 556 | "associated to this path or url.".format( 557 | pretrained_model_name, 558 | ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), 559 | archive_file)) 560 | return None 561 | if resolved_archive_file == archive_file: 562 | logger.info("loading archive file {}".format(archive_file)) 563 | else: 564 | logger.info("loading archive file {} from cache at {}".format( 565 | archive_file, resolved_archive_file)) 566 | tempdir = None 567 | if os.path.isdir(resolved_archive_file): 568 | serialization_dir = resolved_archive_file 569 | else: 570 | # Extract archive to temp dir 571 | tempdir = tempfile.mkdtemp() 572 | logger.info("extracting archive file {} to temp dir {}".format( 573 | resolved_archive_file, tempdir)) 574 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 575 | archive.extractall(tempdir) 576 | serialization_dir = tempdir 577 | # Load config 578 | config_file = os.path.join(serialization_dir, CONFIG_NAME) 579 | config = BertConfig.from_json_file(config_file) 580 | logger.info("Model config {}".format(config)) 581 | # Instantiate
model. 582 | model = cls(config, *inputs, **kwargs) 583 | if state_dict is None: 584 | weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) 585 | state_dict = torch.load(weights_path) 586 | 587 | old_keys = [] 588 | new_keys = [] 589 | for key in state_dict.keys(): 590 | new_key = None 591 | if 'gamma' in key: 592 | new_key = key.replace('gamma', 'weight') 593 | if 'beta' in key: 594 | new_key = key.replace('beta', 'bias') 595 | if new_key: 596 | old_keys.append(key) 597 | new_keys.append(new_key) 598 | for old_key, new_key in zip(old_keys, new_keys): 599 | state_dict[new_key] = state_dict.pop(old_key) 600 | 601 | missing_keys = [] 602 | unexpected_keys = [] 603 | error_msgs = [] 604 | # copy state_dict so _load_from_state_dict can modify it 605 | metadata = getattr(state_dict, '_metadata', None) 606 | state_dict = state_dict.copy() 607 | if metadata is not None: 608 | state_dict._metadata = metadata 609 | 610 | def load(module, prefix=''): 611 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 612 | module._load_from_state_dict( 613 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 614 | for name, child in module._modules.items(): 615 | if child is not None: 616 | load(child, prefix + name + '.') 617 | load(model, prefix='' if hasattr(model, 'bert') else 'bert.') 618 | if len(missing_keys) > 0: 619 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 620 | model.__class__.__name__, missing_keys)) 621 | if len(unexpected_keys) > 0: 622 | logger.info("Weights from pretrained model not used in {}: {}".format( 623 | model.__class__.__name__, unexpected_keys)) 624 | if tempdir: 625 | # Clean up temp dir 626 | shutil.rmtree(tempdir) 627 | return model 628 | 629 | 630 | class BertModel(PreTrainedBertModel): 631 | """BERT model ("Bidirectional Encoder Representations from Transformers").
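`forward` turns the 0/1 `attention_mask` into an additive bias, 0.0 for positions to attend to and -10000.0 for padding, and every attention layer adds it to the raw scores before the softmax. A minimal pure-Python sketch of the effect on a single row of scores (`additive_mask` and `masked_softmax` are hypothetical helpers; the real code works on broadcast tensors of shape [batch_size, 1, 1, sequence_length]):

```python
import math


def additive_mask(mask):
    # 1 -> 0.0 (attend), 0 -> -10000.0 (padding), as in BertModel.forward
    return [(1.0 - m) * -10000.0 for m in mask]


def masked_softmax(scores, mask):
    # Add the bias to the raw scores, then apply a numerically stable softmax.
    biased = [s + b for s, b in zip(scores, additive_mask(mask))]
    top = max(biased)
    exps = [math.exp(x - top) for x in biased]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_softmax([2.0, 1.0, 3.0], [1, 1, 0])
print([round(p, 4) for p in probs])
```

The padded position receives essentially zero probability even though its raw score is the largest, which is why the bias is described as equivalent to removing those positions.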
632 | 633 | Params: 634 | config: a BertConfig class instance with the configuration to build a new model 635 | 636 | Inputs: 637 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 638 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 639 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 640 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 641 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 642 | a `sentence B` token (see BERT paper for more details). 643 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 644 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 645 | input sequence length in the current batch. It's the mask that we typically use for attention when 646 | a batch has varying length sentences. 647 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 648 | 649 | Outputs: Tuple of (encoded_layers, pooled_output) 650 | `encoded_layers`: controlled by `output_all_encoded_layers` argument: 651 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end 652 | of each attention block (i.e.
12 full sequences for BERT-base, 24 for BERT-large), each 653 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], 654 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding 655 | to the last attention block of shape [batch_size, sequence_length, hidden_size], 656 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a 657 | classifier pretrained on top of the hidden state associated to the first token of the 658 | input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper). 659 | 660 | Example usage: 661 | ```python 662 | # Already been converted into WordPiece token ids 663 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 664 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 665 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 666 | 667 | config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 668 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 669 | 670 | model = modeling.BertModel(config=config) 671 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 672 | ``` 673 | """ 674 | def __init__(self, config): 675 | super(BertModel, self).__init__(config) 676 | self.embeddings = BertEmbeddings(config) 677 | self.encoder = BertEncoder(config) 678 | self.pooler = BertPooler(config) 679 | self.apply(self.init_bert_weights) 680 | 681 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): 682 | if attention_mask is None: 683 | attention_mask = torch.ones_like(input_ids) 684 | if token_type_ids is None: 685 | token_type_ids = torch.zeros_like(input_ids) 686 | 687 | # We create a 4D attention mask from a 2D tensor mask.
688 | # Sizes are [batch_size, 1, 1, to_seq_length] 689 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 690 | # this attention mask is simpler than the triangular masking of causal attention 691 | # used in OpenAI GPT; we just need to prepare the broadcast dimension here. 692 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 693 | 694 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 695 | # masked positions, this operation will create a tensor which is 0.0 for 696 | # positions we want to attend and -10000.0 for masked positions. 697 | # Since we are adding it to the raw scores before the softmax, this is 698 | # effectively the same as removing these entirely. 699 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 700 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 701 | 702 | embedding_output = self.embeddings(input_ids, token_type_ids) 703 | encoded_layers = self.encoder(embedding_output, 704 | extended_attention_mask, 705 | output_all_encoded_layers=output_all_encoded_layers) 706 | sequence_output = encoded_layers[-1] 707 | pooled_output = self.pooler(sequence_output) 708 | if not output_all_encoded_layers: 709 | encoded_layers = encoded_layers[-1] 710 | return encoded_layers, pooled_output 711 | 712 | 713 | class BertForPreTraining(PreTrainedBertModel): 714 | """BERT model with pre-training heads. 715 | This module comprises the BERT model followed by the two pre-training heads: 716 | - the masked language modeling head, and 717 | - the next sentence classification head. 718 | 719 | Params: 720 | config: a BertConfig class instance with the configuration to build a new model.
721 | 722 | Inputs: 723 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 724 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 725 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 726 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 727 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 728 | a `sentence B` token (see BERT paper for more details). 729 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 730 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 731 | input sequence length in the current batch. It's the mask that we typically use for attention when 732 | a batch has varying length sentences. 733 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 734 | with indices selected in [-1, 0, ..., vocab_size - 1]. All labels set to -1 are ignored (masked); the loss 735 | is only computed for the labels set in [0, ..., vocab_size - 1] 736 | `next_sentence_label`: next sentence classification labels: torch.LongTensor of shape [batch_size] 737 | with indices selected in [0, 1]. 738 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 739 | 740 | Outputs: 741 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 742 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 743 | sentence classification loss. 744 | if `masked_lm_labels` or `next_sentence_label` is `None`: 745 | Outputs a tuple comprising 746 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 747 | - the next sentence classification logits of shape [batch_size, 2].
748 | 749 | Example usage: 750 | ```python 751 | # Already been converted into WordPiece token ids 752 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 753 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 754 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 755 | 756 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 757 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 758 | 759 | model = BertForPreTraining(config) 760 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 761 | ``` 762 | """ 763 | def __init__(self, config): 764 | super(BertForPreTraining, self).__init__(config) 765 | self.bert = BertModel(config) 766 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) 767 | self.apply(self.init_bert_weights) 768 | 769 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None): 770 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 771 | output_all_encoded_layers=False) 772 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 773 | 774 | if masked_lm_labels is not None and next_sentence_label is not None: 775 | loss_fct = CrossEntropyLoss(ignore_index=-1) 776 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 777 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 778 | total_loss = masked_lm_loss + next_sentence_loss 779 | return total_loss 780 | else: 781 | return prediction_scores, seq_relationship_score 782 | 783 | 784 | class BertForMaskedLM(PreTrainedBertModel): 785 | """BERT model with the masked language modeling head. 786 | This module comprises the BERT model followed by the masked language modeling head. 
787 | 788 | Params: 789 | config: a BertConfig class instance with the configuration to build a new model. 790 | 791 | Inputs: 792 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 793 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 794 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 795 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 796 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 797 | a `sentence B` token (see BERT paper for more details). 798 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 799 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 800 | input sequence length in the current batch. It's the mask that we typically use for attention when 801 | a batch has varying length sentences. 802 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 803 | with indices selected in [-1, 0, ..., vocab_size - 1]. All labels set to -1 are ignored (masked); the loss 804 | is only computed for the labels set in [0, ..., vocab_size - 1] 805 | 806 | Outputs: 807 | if `masked_lm_labels` is not `None`: 808 | Outputs the masked language modeling loss. 809 | if `masked_lm_labels` is `None`: 810 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
811 | 812 | Example usage: 813 | ```python 814 | # Already been converted into WordPiece token ids 815 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 816 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 817 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 818 | 819 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 820 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 821 | 822 | model = BertForMaskedLM(config) 823 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) 824 | ``` 825 | """ 826 | def __init__(self, config): 827 | super(BertForMaskedLM, self).__init__(config) 828 | self.bert = BertModel(config) 829 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 830 | self.apply(self.init_bert_weights) 831 | 832 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None): 833 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, 834 | output_all_encoded_layers=False) 835 | prediction_scores = self.cls(sequence_output) 836 | 837 | if masked_lm_labels is not None: 838 | loss_fct = CrossEntropyLoss(ignore_index=-1) 839 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 840 | return masked_lm_loss 841 | else: 842 | return prediction_scores 843 | 844 | 845 | class BertForNextSentencePrediction(PreTrainedBertModel): 846 | """BERT model with next sentence prediction head. 847 | This module comprises the BERT model followed by the next sentence classification head. 848 | 849 | Params: 850 | config: a BertConfig class instance with the configuration to build a new model. 
851 | 852 | Inputs: 853 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 854 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 855 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 856 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 857 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 858 | a `sentence B` token (see BERT paper for more details). 859 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 860 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 861 | input sequence length in the current batch. It's the mask that we typically use for attention when 862 | a batch has varying length sentences. 863 | `next_sentence_label`: next sentence classification labels: torch.LongTensor of shape [batch_size] 864 | with indices selected in [0, 1]. 865 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 866 | 867 | Outputs: 868 | if `next_sentence_label` is not `None`: 869 | Outputs the next sentence classification loss (cross-entropy between the logits and 870 | `next_sentence_label`). 871 | if `next_sentence_label` is `None`: 872 | Outputs the next sentence classification logits of shape [batch_size, 2].
873 | 874 | Example usage: 875 | ```python 876 | # Already been converted into WordPiece token ids 877 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 878 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 879 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 880 | 881 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 882 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 883 | 884 | model = BertForNextSentencePrediction(config) 885 | seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 886 | ``` 887 | """ 888 | def __init__(self, config): 889 | super(BertForNextSentencePrediction, self).__init__(config) 890 | self.bert = BertModel(config) 891 | self.cls = BertOnlyNSPHead(config) 892 | self.apply(self.init_bert_weights) 893 | 894 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None): 895 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 896 | output_all_encoded_layers=False) 897 | seq_relationship_score = self.cls(pooled_output) 898 | 899 | if next_sentence_label is not None: 900 | loss_fct = CrossEntropyLoss(ignore_index=-1) 901 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 902 | return next_sentence_loss 903 | else: 904 | return seq_relationship_score 905 | 906 | ''' 907 | class BertForSequenceClassification(PreTrainedBertModel): 908 | """BERT model for classification with one additional layer.
909 | """ 910 | def __init__(self, config, num_labels=2): 911 | super(BertForSequenceClassification, self).__init__(config) 912 | self.num_labels = num_labels 913 | self.bert = BertModel(config) 914 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 915 | self.img_attention = BertLayer(config) 916 | self.img_pooler = BertPooler(config) 917 | self.classifier = nn.Linear(config.hidden_size, num_labels) 918 | self.apply(self.init_bert_weights) 919 | 920 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 921 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 922 | 923 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 924 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 925 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 926 | 927 | img_att_text_output_layer = self.img_attention(sequence_output, extended_attention_mask) 928 | img_att_text_output = self.img_pooler(img_att_text_output_layer) 929 | 930 | pooled_output = self.dropout(img_att_text_output) 931 | logits = self.classifier(pooled_output) 932 | 933 | if labels is not None: 934 | loss_fct = CrossEntropyLoss() 935 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 936 | return loss 937 | else: 938 | return logits 939 | ''' 940 | 941 | class BertForSequenceClassification(PreTrainedBertModel): 942 | """BERT model for classification with one additional layer. 
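In `forward`, the main input and two auxiliary sequences (`s1`, `s2`) are encoded by separate `BertModel`s; `s1_attention` lets s1 attend over s2 (masked by `s2_mask`), `s2_attention` does the reverse (masked by `s1_mask`), and both pooled cross-attention outputs are concatenated with the main pooled output, which is why the classifier takes `config.hidden_size*3` features. A minimal pure-Python sketch of the core scaled dot-product step inside `BertCoAttention`, assuming a single head and omitting the learned query/key/value projections, dropout and the residual LayerNorm (`cross_attend` is a hypothetical helper, not part of this repo):

```python
import math


def cross_attend(queries, keys, values, key_mask):
    # Each query vector (from s1) attends over the key/value vectors of the
    # other sequence (s2). key_mask holds 1 for real tokens and 0 for padding.
    d = len(queries[0])
    outputs = []
    for q in queries:
        # Raw scores scaled by sqrt(d), like math.sqrt(self.attention_head_size).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        # Additive mask: 0.0 for kept positions, -10000.0 for padding.
        scores = [s + (1.0 - m) * -10000.0 for s, m in zip(scores, key_mask)]
        # Numerically stable softmax over the key positions.
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Attention-weighted sum of the value vectors.
        outputs.append([sum(w * v[j] for w, v in zip(weights, values))
                        for j in range(len(values[0]))])
    return outputs


s1 = [[1.0, 0.0]]
s2 = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # last position is padding
out = cross_attend(s1, s2, s2, key_mask=[1, 1, 0])
print(out)
```

The padded s2 position has the largest raw score but contributes nothing to the output once the mask is added.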
943 | """ 944 | def __init__(self, config, num_labels=2): 945 | super(BertForSequenceClassification, self).__init__(config) 946 | self.num_labels = num_labels 947 | self.bert = BertModel(config) 948 | self.s1_bert = BertModel(config) 949 | self.s2_bert = BertModel(config) 950 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 951 | self.s1_attention = BertCrossAttentionLayer(config) 952 | self.s1_pooler = BertPooler(config) 953 | self.s2_attention = BertCrossAttentionLayer(config) 954 | self.s2_pooler = BertPooler(config) 955 | self.classifier = nn.Linear(config.hidden_size*3, num_labels) 956 | self.apply(self.init_bert_weights) 957 | 958 | def forward(self, input_ids, s1_input_ids, s2_input_ids, token_type_ids=None, s1_type_ids=None, s2_type_ids=None, \ 959 | attention_mask=None, s1_mask=None, s2_mask=None, labels=None, copy_flag=False): 960 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 961 | if copy_flag: 962 | self.s1_bert = copy.deepcopy(self.bert) 963 | self.s2_bert = copy.deepcopy(self.bert) 964 | s1_output, s1_pooled_output = self.s1_bert(s1_input_ids, s1_type_ids, s1_mask, output_all_encoded_layers=False) 965 | s2_output, s2_pooled_output = self.s2_bert(s2_input_ids, s2_type_ids, s2_mask, output_all_encoded_layers=False) 966 | 967 | extended_s1_mask = s1_mask.unsqueeze(1).unsqueeze(2) 968 | extended_s1_mask = extended_s1_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 969 | extended_s1_mask = (1.0 - extended_s1_mask) * -10000.0 970 | 971 | extended_s2_mask = s2_mask.unsqueeze(1).unsqueeze(2) 972 | extended_s2_mask = extended_s2_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 973 | extended_s2_mask = (1.0 - extended_s2_mask) * -10000.0 974 | 975 | s1_cross_output_layer = self.s1_attention(s1_output, s2_output, extended_s2_mask) 976 | s1_cross_output = self.s1_pooler(s1_cross_output_layer) 977 | 978 | s2_cross_output_layer = 
self.s2_attention(s2_output, s1_output, extended_s1_mask) 979 | s2_cross_output = self.s2_pooler(s2_cross_output_layer) 980 | 981 | final_output = torch.cat((pooled_output, s1_cross_output, s2_cross_output), dim=-1) 982 | pooled_output = self.dropout(final_output) 983 | logits = self.classifier(pooled_output) 984 | 985 | if labels is not None: 986 | loss_fct = CrossEntropyLoss() 987 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 988 | return loss 989 | else: 990 | return logits 991 | 992 | class BertForMultipleChoice(PreTrainedBertModel): 993 | """BERT model for multiple choice tasks. 994 | This module is composed of the BERT model with a linear layer on top of 995 | the pooled output. 996 | 997 | Params: 998 | `config`: a BertConfig class instance with the configuration to build a new model. 999 | `num_choices`: the number of answer choices per example. Default = 2. 1000 | 1001 | Inputs: 1002 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1003 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 1004 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1005 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1006 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 1007 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 1008 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 1009 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1010 | input sequence length in the current batch. It's the mask that we typically use for attention when 1011 | a batch has varying length sentences. 1012 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 1013 | with indices selected in [0, ..., num_choices - 1].
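`forward` scores every choice independently in three steps:

1. It flattens `[batch_size, num_choices, sequence_length]` inputs to `[batch_size * num_choices, sequence_length]`.
2. It produces one logit per row with `nn.Linear(hidden_size, 1)`.
3. It reshapes the logits back to `[batch_size, num_choices]`.

The sketch below covers that bookkeeping, assuming nested lists in place of tensors. `score_fn` is a hypothetical stand-in for BERT plus the linear layer.

```python
# Hypothetical helpers: nested lists stand in for tensors.
def flatten_choices(input_ids):
    # [batch, num_choices, seq_len] -> [batch * num_choices, seq_len]
    return [choice for example in input_ids for choice in example]


def reshape_logits(flat_logits, num_choices):
    # [batch * num_choices] -> [batch, num_choices], mirroring logits.view(-1, num_choices)
    return [flat_logits[i:i + num_choices] for i in range(0, len(flat_logits), num_choices)]


def multiple_choice_logits(input_ids, score_fn, num_choices):
    flat = flatten_choices(input_ids)
    return reshape_logits([score_fn(row) for row in flat], num_choices)


input_ids = [[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]
# Hypothetical scorer: the sum of token ids stands in for BERT + classifier.
logits = multiple_choice_logits(input_ids, score_fn=sum, num_choices=2)
```

Each row of `logits` then feeds a softmax over the choices, which is why `labels` holds one choice index per example.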
1014 | 1015 | Outputs: 1016 | if `labels` is not `None`: 1017 | Outputs the CrossEntropy classification loss of the output with the labels. 1018 | if `labels` is `None`: 1019 | Outputs the classification logits of shape [batch_size, num_choices]. 1020 | 1021 | Example usage: 1022 | ```python 1023 | # Already been converted into WordPiece token ids 1024 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) 1025 | input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) 1026 | token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) 1027 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1028 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1029 | 1030 | num_choices = 2 1031 | 1032 | model = BertForMultipleChoice(config, num_choices) 1033 | logits = model(input_ids, token_type_ids, input_mask) 1034 | ``` 1035 | """ 1036 | def __init__(self, config, num_choices=2): 1037 | super(BertForMultipleChoice, self).__init__(config) 1038 | self.num_choices = num_choices 1039 | self.bert = BertModel(config) 1040 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1041 | self.classifier = nn.Linear(config.hidden_size, 1) 1042 | self.apply(self.init_bert_weights) 1043 | 1044 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 1045 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 1046 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 1047 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 1048 | _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) 1049 | pooled_output = self.dropout(pooled_output) 1050 | logits = self.classifier(pooled_output) 1051 | reshaped_logits = logits.view(-1, self.num_choices) 1052 | 1053 | if labels is not None: 1054 | loss_fct = CrossEntropyLoss() 1055 | loss =
loss_fct(reshaped_logits, labels) 1056 | return loss 1057 | else: 1058 | return reshaped_logits 1059 | 1060 | 1061 | class BertForTokenClassification(PreTrainedBertModel): 1062 | """BERT model for token-level classification. 1063 | This module is composed of the BERT model with a linear layer on top of 1064 | the full hidden state of the last layer. 1065 | 1066 | Params: 1067 | `config`: a BertConfig class instance with the configuration to build a new model. 1068 | `num_labels`: the number of classes for the classifier. Default = 2. 1069 | 1070 | Inputs: 1071 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1072 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 1073 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1074 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1075 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1076 | a `sentence B` token (see BERT paper for more details). 1077 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1078 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1079 | input sequence length in the current batch. It's the mask that we typically use for attention when 1080 | a batch has varying length sentences. 1081 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] 1082 | with indices selected in [0, ..., num_labels - 1]. 1083 | 1084 | Outputs: 1085 | if `labels` is not `None`: 1086 | Outputs the CrossEntropy classification loss of the output with the labels. 1087 | if `labels` is `None`: 1088 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
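When `labels` is given, `forward` flattens the logits from `[batch_size, sequence_length, num_labels]` to `[batch_size * sequence_length, num_labels]`. It flattens the labels to one id per token, then averages `CrossEntropyLoss` over the positions. The attention mask is not used at that point. Padded positions therefore contribute to the loss unless their label equals the loss's default `ignore_index` (-100).

The sketch below reproduces that computation, assuming nested lists instead of tensors. `token_cross_entropy` is a hypothetical name.

```python
import math


def token_cross_entropy(logits, labels, ignore_index=-100):
    # logits: [batch][seq_len][num_labels]; labels: [batch][seq_len]
    flat_logits = [tok for seq in logits for tok in seq]  # logits.view(-1, num_labels)
    flat_labels = [lab for seq in labels for lab in seq]  # labels.view(-1)
    losses = []
    for scores, label in zip(flat_logits, flat_labels):
        if label == ignore_index:
            continue  # mirrors the default ignore_index=-100 of CrossEntropyLoss
        top = max(scores)
        log_z = top + math.log(sum(math.exp(s - top) for s in scores))
        losses.append(log_z - scores[label])  # -log softmax(scores)[label]
    return sum(losses) / len(losses)  # mean over the non-ignored tokens


logits = [[[0.0, 0.0], [2.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]
labels = [[0, 0], [1, -100]]
loss = token_cross_entropy(logits, labels)
```

Setting padded (or sub-word continuation) labels to -100 during preprocessing is one way to keep them out of the average.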
1089 | 1090 | Example usage: 1091 | ```python 1092 | # Already been converted into WordPiece token ids 1093 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1094 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1095 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1096 | 1097 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1098 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1099 | 1100 | num_labels = 2 1101 | 1102 | model = BertForTokenClassification(config, num_labels) 1103 | logits = model(input_ids, token_type_ids, input_mask) 1104 | ``` 1105 | """ 1106 | def __init__(self, config, num_labels=2): 1107 | super(BertForTokenClassification, self).__init__(config) 1108 | self.num_labels = num_labels 1109 | self.bert = BertModel(config) 1110 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1111 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1112 | self.apply(self.init_bert_weights) 1113 | 1114 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 1115 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1116 | sequence_output = self.dropout(sequence_output) 1117 | logits = self.classifier(sequence_output) 1118 | 1119 | if labels is not None: 1120 | loss_fct = CrossEntropyLoss() 1121 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1122 | return loss 1123 | else: 1124 | return logits 1125 | 1126 | 1127 | class BertForQuestionAnswering(PreTrainedBertModel): 1128 | """BERT model for Question Answering (span extraction). 1129 | This module is composed of the BERT model with a linear layer on top of 1130 | the sequence output that computes start_logits and end_logits 1131 | 1132 | Params: 1133 | `config`: a BertConfig class instance with the configuration to build a new model. 
1134 | 1135 | Inputs: 1136 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1137 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 1138 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1139 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1140 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1141 | a `sentence B` token (see BERT paper for more details). 1142 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1143 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1144 | input sequence length in the current batch. It's the mask that we typically use for attention when 1145 | a batch has varying length sentences. 1146 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. 1147 | Positions are clamped to the length of the sequence and positions outside of the sequence are not taken 1148 | into account for computing the loss. 1149 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. 1150 | Positions are clamped to the length of the sequence and positions outside of the sequence are not taken 1151 | into account for computing the loss. 1152 | 1153 | Outputs: 1154 | if `start_positions` and `end_positions` are not `None`: 1155 | Outputs the total_loss, which is the average of the CrossEntropy losses for the start and end token positions. 1156 | if `start_positions` or `end_positions` is `None`: 1157 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end 1158 | position tokens of shape [batch_size, sequence_length].
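When both position tensors are given, `forward` does three things:

- It clamps them into `[0, sequence_length]`.
- It passes `sequence_length` itself as `ignore_index`, so spans that fell outside the input are skipped.
- It returns the mean of the start and end losses.

The sketch below covers that bookkeeping, assuming plain lists for the logits. The helper names are hypothetical, and the empty-batch guard is an illustrative choice.

```python
import math


def span_cross_entropy(logits_batch, positions, ignored_index):
    # Mean cross-entropy over the batch, skipping targets equal to ignored_index.
    losses = []
    for scores, pos in zip(logits_batch, positions):
        if pos == ignored_index:
            continue
        top = max(scores)
        log_z = top + math.log(sum(math.exp(s - top) for s in scores))
        losses.append(log_z - scores[pos])
    # Illustrative guard for a batch in which every target was ignored.
    return sum(losses) / len(losses) if losses else float("nan")


def clamp_position(pos, ignored_index):
    # Mirrors positions.clamp_(0, ignored_index).
    return min(max(pos, 0), ignored_index)


def qa_loss(start_logits, end_logits, start_positions, end_positions):
    ignored_index = len(start_logits[0])  # start_logits.size(1), the sequence length
    starts = [clamp_position(p, ignored_index) for p in start_positions]
    ends = [clamp_position(p, ignored_index) for p in end_positions]
    start_loss = span_cross_entropy(start_logits, starts, ignored_index)
    end_loss = span_cross_entropy(end_logits, ends, ignored_index)
    return (start_loss + end_loss) / 2


uniform = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
# The second span (7, 9) lies past the 3-token sequence, is clamped to 3, and is ignored.
loss = qa_loss(uniform, uniform, start_positions=[1, 7], end_positions=[2, 9])
```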
1159 | 1160 | Example usage: 1161 | ```python 1162 | # Already been converted into WordPiece token ids 1163 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1164 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1165 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1166 | 1167 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1168 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1169 | 1170 | model = BertForQuestionAnswering(config) 1171 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 1172 | ``` 1173 | """ 1174 | def __init__(self, config): 1175 | super(BertForQuestionAnswering, self).__init__(config) 1176 | self.bert = BertModel(config) 1177 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 1178 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 1179 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 1180 | self.apply(self.init_bert_weights) 1181 | 1182 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): 1183 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1184 | logits = self.qa_outputs(sequence_output) 1185 | start_logits, end_logits = logits.split(1, dim=-1) 1186 | start_logits = start_logits.squeeze(-1) 1187 | end_logits = end_logits.squeeze(-1) 1188 | 1189 | if start_positions is not None and end_positions is not None: 1190 | # If we are on multi-GPU, split add a dimension 1191 | if len(start_positions.size()) > 1: 1192 | start_positions = start_positions.squeeze(-1) 1193 | if len(end_positions.size()) > 1: 1194 | end_positions = end_positions.squeeze(-1) 1195 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1196 | ignored_index = start_logits.size(1) 1197 | start_positions.clamp_(0, ignored_index) 1198 | 
end_positions.clamp_(0, ignored_index) 1199 | 1200 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1201 | start_loss = loss_fct(start_logits, start_positions) 1202 | end_loss = loss_fct(end_logits, end_positions) 1203 | total_loss = (start_loss + end_loss) / 2 1204 | return total_loss 1205 | else: 1206 | return start_logits, end_logits 1207 | -------------------------------------------------------------------------------- /my_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from .modeling import BertConfig, BertForPreTraining 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | config_path = os.path.abspath(bert_config_file) 32 | tf_path = os.path.abspath(tf_checkpoint_path) 33 | print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path)) 34 | # Load weights from TF model 35 | init_vars = tf.train.list_variables(tf_path) 36 | names = [] 37 | arrays = [] 38 | for name, shape in init_vars: 39 | print("Loading TF weight {} with shape {}".format(name, shape)) 40 | array = tf.train.load_variable(tf_path, name) 41 | names.append(name) 42 | arrays.append(array) 43 | 44 | # Initialise PyTorch model 45 | config = BertConfig.from_json_file(bert_config_file) 46 | print("Building PyTorch model from configuration: {}".format(str(config))) 47 | model = BertForPreTraining(config) 48 | 49 | for name, array in zip(names, arrays): 50 | name = name.split('/') 51 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 52 | # which are not required for using pretrained model 53 | if any(n in ["adam_v", "adam_m", "global_step"] for n in name): 54 | print("Skipping {}".format("/".join(name))) 55 | continue 56 | pointer = model 57 | for m_name in name: 58 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 59 | l = re.split(r'_(\d+)', m_name) 60 | else: 61 | l = [m_name] 62 | if l[0] == 'kernel' or l[0] == 'gamma': 63 | pointer = getattr(pointer, 'weight') 64 | elif l[0] == 'output_bias' or l[0] == 'beta': 65 | pointer = getattr(pointer, 'bias') 66 | elif l[0] == 'output_weights': 67 | pointer = getattr(pointer, 'weight') 68 | else: 69 | 
pointer = getattr(pointer, l[0]) 70 | if len(l) >= 2: 71 | num = int(l[1]) 72 | pointer = pointer[num] 73 | if m_name[-11:] == '_embeddings': 74 | pointer = getattr(pointer, 'weight') 75 | elif m_name == 'kernel': 76 | array = np.transpose(array) 77 | try: 78 | assert pointer.shape == array.shape 79 | except AssertionError as e: 80 | e.args += (pointer.shape, array.shape) 81 | raise 82 | print("Initialize PyTorch weight {}".format(name)) 83 | pointer.data = torch.from_numpy(array) 84 | 85 | # Save pytorch-model 86 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 87 | torch.save(model.state_dict(), pytorch_dump_path) 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser() 92 | ## Required parameters 93 | parser.add_argument("--tf_checkpoint_path", 94 | default = None, 95 | type = str, 96 | required = True, 97 | help = "Path to the TensorFlow checkpoint.") 98 | parser.add_argument("--bert_config_file", 99 | default = None, 100 | type = str, 101 | required = True, 102 | help = "The config json file corresponding to the pre-trained BERT model. \n" 103 | "This specifies the model architecture.") 104 | parser.add_argument("--pytorch_dump_path", 105 | default = None, 106 | type = str, 107 | required = True, 108 | help = "Path to the output PyTorch model.") 109 | args = parser.parse_args() 110 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 111 | args.bert_config_file, 112 | args.pytorch_dump_path) 113 | -------------------------------------------------------------------------------- /my_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors.
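Cached files are named by hashing, as implemented in `url_to_filename`:

- It takes the SHA-256 hex digest of the UTF-8 encoded URL.
- When an ETag is known, it appends `.` plus the digest of the ETag.

`get_from_cache` then writes a `.json` sidecar next to the cached file that records the URL and ETag. The sketch below shows the naming scheme using only `hashlib`. The function name and URL are hypothetical, not this module's API.

```python
from hashlib import sha256


def cache_filename(url, etag=None):
    # Same scheme as url_to_filename: sha256(url), optionally followed by "." + sha256(etag).
    name = sha256(url.encode("utf-8")).hexdigest()
    if etag:
        name += "." + sha256(etag.encode("utf-8")).hexdigest()
    return name


url = "https://example.com/vocab.txt"  # hypothetical URL
plain = cache_filename(url)
tagged = cache_filename(url, etag='"abc123"')
```

Hashing keeps the names filesystem-safe and deterministic. Including the ETag means a changed remote file gets a new cache entry instead of overwriting the old one.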
5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from torch.hub import _get_torch_home 27 | torch_cache_home = _get_torch_home() 28 | except ImportError: 29 | torch_cache_home = os.path.expanduser( 30 | os.getenv('TORCH_HOME', os.path.join( 31 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 32 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert') 33 | 34 | try: 35 | from urllib.parse import urlparse 36 | except ImportError: 37 | from urlparse import urlparse 38 | 39 | try: 40 | from pathlib import Path 41 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 42 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) 43 | except (AttributeError, ImportError): 44 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 45 | default_cache_path) 46 | 47 | CONFIG_NAME = "config.json" 48 | WEIGHTS_NAME = "pytorch_model.bin" 49 | 50 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 51 | 52 | 53 | def url_to_filename(url, etag=None): 54 | """ 55 | Convert `url` into a hashed filename in a repeatable way. 56 | If `etag` is specified, append its hash to the url's, delimited 57 | by a period. 58 | """ 59 | url_bytes = url.encode('utf-8') 60 | url_hash = sha256(url_bytes) 61 | filename = url_hash.hexdigest() 62 | 63 | if etag: 64 | etag_bytes = etag.encode('utf-8') 65 | etag_hash = sha256(etag_bytes) 66 | filename += '.' 
+ etag_hash.hexdigest() 67 | 68 | return filename 69 | 70 | 71 | def filename_to_url(filename, cache_dir=None): 72 | """ 73 | Return the url and etag (which may be ``None``) stored for `filename`. 74 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 75 | """ 76 | if cache_dir is None: 77 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 78 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 79 | cache_dir = str(cache_dir) 80 | 81 | cache_path = os.path.join(cache_dir, filename) 82 | if not os.path.exists(cache_path): 83 | raise EnvironmentError("file {} not found".format(cache_path)) 84 | 85 | meta_path = cache_path + '.json' 86 | if not os.path.exists(meta_path): 87 | raise EnvironmentError("file {} not found".format(meta_path)) 88 | 89 | with open(meta_path, encoding="utf-8") as meta_file: 90 | metadata = json.load(meta_file) 91 | url = metadata['url'] 92 | etag = metadata['etag'] 93 | 94 | return url, etag 95 | 96 | 97 | def cached_path(url_or_filename, cache_dir=None): 98 | """ 99 | Given something that might be a URL (or might be a local path), 100 | determine which. If it's a URL, download the file and cache it, and 101 | return the path to the cached file. If it's already a local path, 102 | make sure the file exists and then return the path. 103 | """ 104 | if cache_dir is None: 105 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 106 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 107 | url_or_filename = str(url_or_filename) 108 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 109 | cache_dir = str(cache_dir) 110 | 111 | parsed = urlparse(url_or_filename) 112 | 113 | if parsed.scheme in ('http', 'https', 's3'): 114 | # URL, so get it from the cache (downloading if necessary) 115 | return get_from_cache(url_or_filename, cache_dir) 116 | elif os.path.exists(url_or_filename): 117 | # File, and it exists. 
118 | return url_or_filename 119 | elif parsed.scheme == '': 120 | # File, but it doesn't exist. 121 | raise EnvironmentError("file {} not found".format(url_or_filename)) 122 | else: 123 | # Something unknown 124 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 125 | 126 | 127 | def split_s3_path(url): 128 | """Split a full s3 path into the bucket name and path.""" 129 | parsed = urlparse(url) 130 | if not parsed.netloc or not parsed.path: 131 | raise ValueError("bad s3 path {}".format(url)) 132 | bucket_name = parsed.netloc 133 | s3_path = parsed.path 134 | # Remove '/' at beginning of path. 135 | if s3_path.startswith("/"): 136 | s3_path = s3_path[1:] 137 | return bucket_name, s3_path 138 | 139 | 140 | def s3_request(func): 141 | """ 142 | Wrapper function for s3 requests in order to create more helpful error 143 | messages. 144 | """ 145 | 146 | @wraps(func) 147 | def wrapper(url, *args, **kwargs): 148 | try: 149 | return func(url, *args, **kwargs) 150 | except ClientError as exc: 151 | if int(exc.response["Error"]["Code"]) == 404: 152 | raise EnvironmentError("file {} not found".format(url)) 153 | else: 154 | raise 155 | 156 | return wrapper 157 | 158 | 159 | @s3_request 160 | def s3_etag(url): 161 | """Check ETag on S3 object.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_object = s3_resource.Object(bucket_name, s3_path) 165 | return s3_object.e_tag 166 | 167 | 168 | @s3_request 169 | def s3_get(url, temp_file): 170 | """Pull a file directly from S3.""" 171 | s3_resource = boto3.resource("s3") 172 | bucket_name, s3_path = split_s3_path(url) 173 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 174 | 175 | 176 | def http_get(url, temp_file): 177 | req = requests.get(url, stream=True) 178 | content_length = req.headers.get('Content-Length') 179 | total = int(content_length) if content_length is not None else None 180 | progress = tqdm(unit="B", 
total=total) 181 | for chunk in req.iter_content(chunk_size=1024): 182 | if chunk: # filter out keep-alive new chunks 183 | progress.update(len(chunk)) 184 | temp_file.write(chunk) 185 | progress.close() 186 | 187 | 188 | def get_from_cache(url, cache_dir=None): 189 | """ 190 | Given a URL, look for the corresponding dataset in the local cache. 191 | If it's not there, download it. Then return the path to the cached file. 192 | """ 193 | if cache_dir is None: 194 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 195 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 196 | cache_dir = str(cache_dir) 197 | 198 | if not os.path.exists(cache_dir): 199 | os.makedirs(cache_dir) 200 | 201 | # Get eTag to add to filename, if it exists. 202 | if url.startswith("s3://"): 203 | etag = s3_etag(url) 204 | else: 205 | try: 206 | response = requests.head(url, allow_redirects=True) 207 | if response.status_code != 200: 208 | etag = None 209 | else: 210 | etag = response.headers.get("ETag") 211 | except EnvironmentError: 212 | etag = None 213 | 214 | if sys.version_info[0] == 2 and etag is not None: 215 | etag = etag.decode('utf-8') 216 | filename = url_to_filename(url, etag) 217 | 218 | # get cache path to put the file 219 | cache_path = os.path.join(cache_dir, filename) 220 | 221 | # If we don't have a connection (etag is None) and can't identify the file 222 | # try to get the last downloaded one 223 | if not os.path.exists(cache_path) and etag is None: 224 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 225 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 226 | if matching_files: 227 | cache_path = os.path.join(cache_dir, matching_files[-1]) 228 | 229 | if not os.path.exists(cache_path): 230 | # Download to temporary file, then copy to cache dir once finished. 231 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
232 | with tempfile.NamedTemporaryFile() as temp_file: 233 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 234 | 235 | # GET file object 236 | if url.startswith("s3://"): 237 | s3_get(url, temp_file) 238 | else: 239 | http_get(url, temp_file) 240 | 241 | # we are copying the file before closing it, so flush to avoid truncation 242 | temp_file.flush() 243 | # shutil.copyfileobj() starts at the current position, so go to the start 244 | temp_file.seek(0) 245 | 246 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 247 | with open(cache_path, 'wb') as cache_file: 248 | shutil.copyfileobj(temp_file, cache_file) 249 | 250 | logger.info("creating metadata file for %s", cache_path) 251 | meta = {'url': url, 'etag': etag} 252 | meta_path = cache_path + '.json' 253 | with open(meta_path, 'w') as meta_file: 254 | output_string = json.dumps(meta) 255 | if sys.version_info[0] == 2 and isinstance(output_string, str): 256 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 257 | meta_file.write(output_string) 258 | 259 | logger.info("removing temp file %s", temp_file.name) 260 | 261 | return cache_path 262 | 263 | 264 | def read_set_from_file(filename): 265 | ''' 266 | Extract a de-duped collection (set) of text from a file. 267 | Expected file format is one item per line. 
268 | ''' 269 | collection = set() 270 | with open(filename, 'r', encoding='utf-8') as file_: 271 | for line in file_: 272 | collection.add(line.rstrip()) 273 | return collection 274 | 275 | 276 | def get_file_extension(path, dot=True, lower=True): 277 | ext = os.path.splitext(path)[1] 278 | ext = ext if dot else ext[1:] 279 | return ext.lower() if lower else ext -------------------------------------------------------------------------------- /my_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | 23 | def warmup_cosine(x, warmup=0.002): 24 | if x < warmup: 25 | return x/warmup 26 | return 0.5 * (1.0 + math.cos(math.pi * x)) # x is a Python float, so use math.cos (torch.cos needs a tensor) 27 | 28 | def warmup_constant(x, warmup=0.002): 29 | if x < warmup: 30 | return x/warmup 31 | return 1.0 32 | 33 | def warmup_linear(x, warmup=0.002): 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 - x 37 | 38 | SCHEDULES = { 39 | 'warmup_cosine':warmup_cosine, 40 | 'warmup_constant':warmup_constant, 41 | 'warmup_linear':warmup_linear, 42 | } 43 | 44 | 45 | class BertAdam(Optimizer): 46 | """Implements BERT version of Adam algorithm with weight decay fix. 47 | Params: 48 | lr: learning rate 49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 50 | t_total: total number of training steps for the learning 51 | rate schedule, -1 means constant learning rate. Default: -1 52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 53 | b1: Adam's b1. Default: 0.9 54 | b2: Adam's b2. Default: 0.999 55 | e: Adam's epsilon. Default: 1e-6 56 | weight_decay: Weight decay. Default: 0.01 57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping).
Default: 1.0 58 | """ 59 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 60 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 61 | max_grad_norm=1.0): 62 | if lr is not required and lr < 0.0: 63 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 64 | if schedule not in SCHEDULES: 65 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 66 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 67 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 68 | if not 0.0 <= b1 < 1.0: 69 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 70 | if not 0.0 <= b2 < 1.0: 71 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 72 | if not e >= 0.0: 73 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 74 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 75 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 76 | max_grad_norm=max_grad_norm) 77 | super(BertAdam, self).__init__(params, defaults) 78 | 79 | def get_lr(self): 80 | lr = [] 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | state = self.state[p] 84 | if len(state) == 0: 85 | return [0] 86 | if group['t_total'] != -1: 87 | schedule_fct = SCHEDULES[group['schedule']] 88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 89 | else: 90 | lr_scheduled = group['lr'] 91 | lr.append(lr_scheduled) 92 | return lr 93 | 94 | def step(self, closure=None): 95 | """Performs a single optimization step. 96 | 97 | Arguments: 98 | closure (callable, optional): A closure that reevaluates the model 99 | and returns the loss. 
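For each parameter, `step` performs these operations:

1. It updates exponential moving averages `next_m` and `next_v` of the gradient and the squared gradient.
2. It forms `next_m / (sqrt(next_v) + e)`.
3. It adds decoupled weight decay `weight_decay * p`.
4. It scales the result by the scheduled learning rate `lr * schedule(step / t_total, warmup)`.

Unlike standard Adam, no bias correction is applied. The per-parameter gradient clipping to `max_grad_norm` is omitted here.

The scalar sketch below shows one such update with the `warmup_linear` schedule. It assumes a single float parameter, and `bert_adam_step` is a hypothetical name.

```python
import math


def warmup_linear(x, warmup=0.002):
    # Linear warm-up to 1.0 over the first `warmup` fraction, then linear decay 1 - x.
    return x / warmup if x < warmup else 1.0 - x


def bert_adam_step(p, grad, m, v, step, lr, t_total, warmup=0.1,
                   b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01):
    # Moment estimates (updated in place on tensors in the real optimizer).
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    update = m / (math.sqrt(v) + e)  # no bias correction, as in BertAdam
    update += weight_decay * p  # decoupled weight decay
    if t_total != -1:
        lr_scheduled = lr * warmup_linear(step / t_total, warmup)
    else:
        lr_scheduled = lr
    return p - lr_scheduled * update, m, v


p, m, v = 1.0, 0.0, 0.0
p, m, v = bert_adam_step(p, grad=0.5, m=m, v=v, step=5, lr=1e-3, t_total=100)
```

At `step=5` of `t_total=100` with `warmup=0.1`, the schedule is still in warm-up, so the effective learning rate is half of `lr`.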
100 | """ 101 | loss = None 102 | if closure is not None: 103 | loss = closure() 104 | 105 | for group in self.param_groups: 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data 110 | if grad.is_sparse: 111 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 112 | 113 | state = self.state[p] 114 | 115 | # State initialization 116 | if len(state) == 0: 117 | state['step'] = 0 118 | # Exponential moving average of gradient values 119 | state['next_m'] = torch.zeros_like(p.data) 120 | # Exponential moving average of squared gradient values 121 | state['next_v'] = torch.zeros_like(p.data) 122 | 123 | next_m, next_v = state['next_m'], state['next_v'] 124 | beta1, beta2 = group['b1'], group['b2'] 125 | 126 | # Add grad clipping 127 | if group['max_grad_norm'] > 0: 128 | clip_grad_norm_(p, group['max_grad_norm']) 129 | 130 | # Decay the first and second moment running average coefficient 131 | # In-place operations to update the averages at the same time 132 | next_m.mul_(beta1).add_(1 - beta1, grad) 133 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 134 | update = next_m / (next_v.sqrt() + group['e']) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want to decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 
143 | if group['weight_decay'] > 0.0: 144 | update += group['weight_decay'] * p.data 145 | 146 | if group['t_total'] != -1: 147 | schedule_fct = SCHEDULES[group['schedule']] 148 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 149 | else: 150 | lr_scheduled = group['lr'] 151 | 152 | update_with_lr = lr_scheduled * update 153 | p.data.add_(-update_with_lr) 154 | 155 | state['step'] += 1 156 | 157 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 158 | # No bias correction 159 | # bias_correction1 = 1 - beta1 ** state['step'] 160 | # bias_correction2 = 1 - beta2 ** state['step'] 161 | 162 | return loss 163 | -------------------------------------------------------------------------------- /my_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import os 24 | import logging 25 | 26 | from .file_utils import cached_path 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'bert-base-uncased': 512, 41 | 'bert-large-uncased': 512, 42 | 'bert-base-cased': 512, 43 | 'bert-large-cased': 512, 44 | 'bert-base-multilingual-uncased': 512, 45 | 'bert-base-multilingual-cased': 512, 46 | 'bert-base-chinese': 512, 47 | } 48 | VOCAB_NAME = 'vocab.txt' 49 | 50 | 51 | def load_vocab(vocab_file): 52 | """Loads a vocabulary file into a dictionary.""" 53 | vocab = collections.OrderedDict() 54 | index = 0 55 | with open(vocab_file, "r", encoding="utf-8") as reader: 56 | while True: 57 | token = reader.readline() 58 | if not token: 59 | break 60 | token = token.strip() 61 | vocab[token] = index 62 | index += 1 63 | return vocab 64 | 65 | 66 | def whitespace_tokenize(text): 67 | """Runs basic whitespace cleaning and 
splitting on a piece of text.""" 68 | text = text.strip() 69 | if not text: 70 | return [] 71 | tokens = text.split() 72 | return tokens 73 | 74 | 75 | class BertTokenizer(object): 76 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 77 | 78 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, 79 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 80 | if not os.path.isfile(vocab_file): 81 | raise ValueError( 82 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 83 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 84 | self.vocab = load_vocab(vocab_file) 85 | self.ids_to_tokens = collections.OrderedDict( 86 | [(ids, tok) for tok, ids in self.vocab.items()]) 87 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 88 | never_split=never_split) 89 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 90 | self.max_len = max_len if max_len is not None else int(1e12) 91 | 92 | def tokenize(self, text): 93 | split_tokens = [] 94 | for token in self.basic_tokenizer.tokenize(text): 95 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 96 | split_tokens.append(sub_token) 97 | return split_tokens 98 | 99 | def convert_tokens_to_ids(self, tokens): 100 | """Converts a sequence of tokens into ids using the vocab.""" 101 | ids = [] 102 | for token in tokens: 103 | ids.append(self.vocab[token]) 104 | if len(ids) > self.max_len: 105 | raise ValueError( 106 | "Token indices sequence length is longer than the specified maximum " 107 | " sequence length for this BERT model ({} > {}).
Running this" 108 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 109 | ) 110 | return ids 111 | 112 | def convert_ids_to_tokens(self, ids): 113 | """Converts a sequence of ids into wordpiece tokens using the vocab.""" 114 | tokens = [] 115 | for i in ids: 116 | tokens.append(self.ids_to_tokens[i]) 117 | return tokens 118 | 119 | @classmethod 120 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): 121 | """ 122 | Instantiate a BertTokenizer from a pre-trained vocabulary file. 123 | Download and cache the vocabulary file if needed. 124 | """ 125 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 126 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 127 | else: 128 | vocab_file = pretrained_model_name 129 | if os.path.isdir(vocab_file): 130 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 131 | # redirect to the cache, if necessary 132 | try: 133 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 134 | except FileNotFoundError: 135 | logger.error( 136 | "Model name '{}' was not found in model name list ({}).
" 137 | "We assumed '{}' was a path or url but couldn't find any file " 138 | "associated to this path or url.".format( 139 | pretrained_model_name, 140 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 141 | vocab_file)) 142 | return None 143 | if resolved_vocab_file == vocab_file: 144 | logger.info("loading vocabulary file {}".format(vocab_file)) 145 | else: 146 | logger.info("loading vocabulary file {} from cache at {}".format( 147 | vocab_file, resolved_vocab_file)) 148 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 149 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 150 | # than the number of positional embeddings 151 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] 152 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 153 | # Instantiate tokenizer. 154 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 155 | return tokenizer 156 | 157 | 158 | class BasicTokenizer(object): 159 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 160 | 161 | def __init__(self, 162 | do_lower_case=True, 163 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 164 | """Constructs a BasicTokenizer. 165 | 166 | Args: 167 | do_lower_case: Whether to lower case the input. 168 | """ 169 | self.do_lower_case = do_lower_case 170 | self.never_split = never_split 171 | 172 | def tokenize(self, text): 173 | """Tokenizes a piece of text.""" 174 | text = self._clean_text(text) 175 | # This was added on November 1st, 2018 for the multilingual and Chinese 176 | # models. This is also applied to the English models now, but it doesn't 177 | # matter since the English models were not trained on any Chinese data 178 | # and generally don't have any Chinese data in them (there are Chinese 179 | # characters in the vocabulary because Wikipedia does have some Chinese 180 | # words in the English Wikipedia.). 
181 | text = self._tokenize_chinese_chars(text) 182 | orig_tokens = whitespace_tokenize(text) 183 | split_tokens = [] 184 | for token in orig_tokens: 185 | if self.do_lower_case and token not in self.never_split: 186 | token = token.lower() 187 | token = self._run_strip_accents(token) 188 | split_tokens.extend(self._run_split_on_punc(token)) 189 | 190 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 191 | return output_tokens 192 | 193 | def _run_strip_accents(self, text): 194 | """Strips accents from a piece of text.""" 195 | text = unicodedata.normalize("NFD", text) 196 | output = [] 197 | for char in text: 198 | cat = unicodedata.category(char) 199 | if cat == "Mn": 200 | continue 201 | output.append(char) 202 | return "".join(output) 203 | 204 | def _run_split_on_punc(self, text): 205 | """Splits punctuation on a piece of text.""" 206 | if text in self.never_split: 207 | return [text] 208 | chars = list(text) 209 | i = 0 210 | start_new_word = True 211 | output = [] 212 | while i < len(chars): 213 | char = chars[i] 214 | if _is_punctuation(char): 215 | output.append([char]) 216 | start_new_word = True 217 | else: 218 | if start_new_word: 219 | output.append([]) 220 | start_new_word = False 221 | output[-1].append(char) 222 | i += 1 223 | 224 | return ["".join(x) for x in output] 225 | 226 | def _tokenize_chinese_chars(self, text): 227 | """Adds whitespace around any CJK character.""" 228 | output = [] 229 | for char in text: 230 | cp = ord(char) 231 | if self._is_chinese_char(cp): 232 | output.append(" ") 233 | output.append(char) 234 | output.append(" ") 235 | else: 236 | output.append(char) 237 | return "".join(output) 238 | 239 | def _is_chinese_char(self, cp): 240 | """Checks whether CP is the codepoint of a CJK character.""" 241 | # This defines a "chinese character" as anything in the CJK Unicode block: 242 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 243 | # 244 | # Note that the CJK Unicode block is NOT all 
Japanese and Korean characters, 245 | # despite its name. The modern Korean Hangul alphabet is a different block, 246 | # as are Japanese Hiragana and Katakana. Those alphabets are used to write 247 | # space-separated words, so they are not treated specially and are handled 248 | # like all of the other languages. 249 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 250 | (cp >= 0x3400 and cp <= 0x4DBF) or # 251 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 252 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 253 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 254 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 255 | (cp >= 0xF900 and cp <= 0xFAFF) or # 256 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 257 | return True 258 | 259 | return False 260 | 261 | def _clean_text(self, text): 262 | """Performs invalid character removal and whitespace cleanup on text.""" 263 | output = [] 264 | for char in text: 265 | cp = ord(char) 266 | if cp == 0 or cp == 0xfffd or _is_control(char): 267 | continue 268 | if _is_whitespace(char): 269 | output.append(" ") 270 | else: 271 | output.append(char) 272 | return "".join(output) 273 | 274 | 275 | class WordpieceTokenizer(object): 276 | """Runs WordPiece tokenization.""" 277 | 278 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 279 | self.vocab = vocab 280 | self.unk_token = unk_token 281 | self.max_input_chars_per_word = max_input_chars_per_word 282 | 283 | def tokenize(self, text): 284 | """Tokenizes a piece of text into its word pieces. 285 | 286 | This uses a greedy longest-match-first algorithm to perform tokenization 287 | using the given vocabulary. 288 | 289 | For example: 290 | input = "unaffable" 291 | output = ["un", "##aff", "##able"] 292 | 293 | Args: 294 | text: A single token or whitespace-separated tokens. This should have 295 | already been passed through `BasicTokenizer`. 296 | 297 | Returns: 298 | A list of wordpiece tokens.
299 | """ 300 | 301 | output_tokens = [] 302 | for token in whitespace_tokenize(text): 303 | chars = list(token) 304 | if len(chars) > self.max_input_chars_per_word: 305 | output_tokens.append(self.unk_token) 306 | continue 307 | 308 | is_bad = False 309 | start = 0 310 | sub_tokens = [] 311 | while start < len(chars): 312 | end = len(chars) 313 | cur_substr = None 314 | while start < end: 315 | substr = "".join(chars[start:end]) 316 | if start > 0: 317 | substr = "##" + substr 318 | if substr in self.vocab: 319 | cur_substr = substr 320 | break 321 | end -= 1 322 | if cur_substr is None: 323 | is_bad = True 324 | break 325 | sub_tokens.append(cur_substr) 326 | start = end 327 | 328 | if is_bad: 329 | output_tokens.append(self.unk_token) 330 | else: 331 | output_tokens.extend(sub_tokens) 332 | return output_tokens 333 | 334 | 335 | def _is_whitespace(char): 336 | """Checks whether `char` is a whitespace character.""" 337 | # \t, \n, and \r are technically control characters but we treat them 338 | # as whitespace since they are generally considered as such. 339 | if char == " " or char == "\t" or char == "\n" or char == "\r": 340 | return True 341 | cat = unicodedata.category(char) 342 | if cat == "Zs": 343 | return True 344 | return False 345 | 346 | 347 | def _is_control(char): 348 | """Checks whether `char` is a control character.""" 349 | # These are technically control characters but we count them as whitespace 350 | # characters. 351 | if char == "\t" or char == "\n" or char == "\r": 352 | return False 353 | cat = unicodedata.category(char) 354 | if cat.startswith("C"): 355 | return True 356 | return False 357 | 358 | 359 | def _is_punctuation(char): 360 | """Checks whether `char` is a punctuation character.""" 361 | cp = ord(char) 362 | # We treat all non-letter/number ASCII as punctuation. 363 | # Characters such as "^", "$", and "`" are not in the Unicode 364 | # Punctuation class but we treat them as punctuation anyway, for 365 | # consistency.
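`BertTokenizer.tokenize` first runs `BasicTokenizer` (text cleanup, spacing around CJK ideographs, lower-casing, accent stripping, punctuation splitting) and then runs `WordpieceTokenizer` (greedy longest-match-first against the vocabulary) on each resulting token. The sketch below condenses that pipeline into plain functions for illustration. It is not the repo's code: `is_punct`, `basic_tokenize`, `wordpiece`, `bert_tokenize`, and the toy vocabulary are hypothetical, `_clean_text` and `never_split` are skipped, and only the main CJK block (U+4E00 to U+9FFF) is checked.

```python
import unicodedata

def is_punct(ch):
    # Non-alphanumeric printable ASCII counts as punctuation, plus any Unicode "P*" category.
    cp = ord(ch)
    if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
        return True
    return unicodedata.category(ch).startswith("P")

def basic_tokenize(text, do_lower_case=True):
    # Surround CJK ideographs (main block only, for brevity) with spaces.
    text = "".join(f" {c} " if 0x4E00 <= ord(c) <= 0x9FFF else c for c in text)
    out = []
    for tok in text.split():
        if do_lower_case:
            # Lower-case, then strip accents: NFD-decompose and drop combining marks (Mn).
            tok = "".join(c for c in unicodedata.normalize("NFD", tok.lower())
                          if unicodedata.category(c) != "Mn")
        word = ""
        for c in tok:
            if is_punct(c):
                # Each punctuation character becomes its own token.
                if word:
                    out.append(word)
                    word = ""
                out.append(c)
            else:
                word += c
        if word:
            out.append(word)
    return out

def wordpiece(token, vocab, unk="[UNK]", max_chars=100):
    # Greedy longest-match-first: repeatedly take the longest prefix found in the
    # vocabulary; non-initial pieces carry the "##" continuation marker.
    if len(token) > max_chars:
        return [unk]
    pieces, start = [], 0
    while start < len(token):
        end, cur = len(token), None
        while start < end:
            sub = token[start:end] if start == 0 else "##" + token[start:end]
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:
            return [unk]  # an unmatchable remainder maps the whole word to [UNK]
        pieces.append(cur)
        start = end
    return pieces

def bert_tokenize(text, vocab):
    return [p for tok in basic_tokenize(text) for p in wordpiece(tok, vocab)]

vocab = {"un", "##aff", "##able", "is", ",", "!", "hello"}
print(bert_tokenize("Héllo, UNAFFABLE is!", vocab))
# ['hello', ',', 'un', '##aff', '##able', 'is', '!']
```

As in the original `WordpieceTokenizer`, a word with any unmatchable part becomes a single `[UNK]` rather than a partial list of pieces.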
366 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 367 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 368 | return True 369 | cat = unicodedata.category(char) 370 | if cat.startswith("P"): 371 | return True 372 | return False 373 | -------------------------------------------------------------------------------- /ner_evaluate.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import numpy as np 3 | 4 | def get_chunks(seq, tags): 5 | """ 6 | Extracts entity chunks from a sequence of label indices. 7 | Args: 8 | seq: [4, 4, 0, 0, ...] sequence of labels 9 | tags: dict mapping tag name to index, e.g. dict["O"] = 0 10 | Returns: 11 | list of (chunk_type, chunk_start, chunk_end), with chunk_end exclusive 12 | 13 | Example: 14 | seq = [4, 5, 0, 3] 15 | tags = {"O": 0, "B-PER": 4, "I-PER": 5, "B-LOC": 3} 16 | result = [("PER", 0, 2), ("LOC", 3, 4)] 17 | """ 18 | default = tags['O'] 19 | idx_to_tag = {idx: tag for tag, idx in tags.items()} 20 | chunks = [] 21 | chunk_type, chunk_start = None, None 22 | for i, tok in enumerate(seq): 23 | # End of a chunk 24 | if tok == default and chunk_type is not None: 25 | # Add a chunk. 26 | chunk = (chunk_type, chunk_start, i) 27 | chunks.append(chunk) 28 | chunk_type, chunk_start = None, None 29 | 30 | # End of a chunk + start of a chunk!
31 | elif tok != default: 32 | tok_chunk_class, tok_chunk_type = get_chunk_type(tok, idx_to_tag) 33 | if chunk_type is None: 34 | chunk_type, chunk_start = tok_chunk_type, i 35 | elif tok_chunk_type != chunk_type or tok_chunk_class == "B": 36 | chunk = (chunk_type, chunk_start, i) 37 | chunks.append(chunk) 38 | chunk_type, chunk_start = tok_chunk_type, i 39 | else: 40 | pass 41 | # end condition 42 | if chunk_type is not None: 43 | chunk = (chunk_type, chunk_start, len(seq)) 44 | chunks.append(chunk) 45 | 46 | return chunks 47 | 48 | def get_chunk_type(tok, idx_to_tag): 49 | """ 50 | Args: 51 | tok: id of token, such as 4 52 | idx_to_tag: dictionary {4: "B-PER", ...} 53 | Returns: 54 | tuple: "B", "PER" 55 | """ 56 | tag_name = idx_to_tag[tok] 57 | tag_class = tag_name.split('-')[0] 58 | tag_type = tag_name.split('-')[-1] 59 | return tag_class, tag_type 60 | 61 | # def run_evaluate(self, sess, test, tags): 62 | def evaluate(labels_pred, labels,words,tags): 63 | 64 | """ 65 | labels_pred, labels, words: parallel lists of sequences (label or word indices). 66 | Evaluates chunk-level performance on the test set. 67 | Args: 68 | labels_pred: predicted label sequences 69 | labels: gold label sequences; words: tokens (not used in scoring) 70 | tags: {tag: index} dictionary 71 | Returns: 72 | token-level accuracy 73 | chunk-level f1 score 74 | chunk-level precision and recall 75 | """ 76 | 77 | #file_write = open('./test_results.txt','w') 78 | 79 | 80 | index = 0 81 | sents_length = [] 82 | 83 | accs = [] 84 | correct_preds, total_correct, total_preds = 0., 0., 0.
85 | 86 | 87 | for lab, lab_pred, word_sent in zip(labels, labels_pred, words): 88 | word_st = word_sent 89 | lab = lab 90 | lab_pred = lab_pred 91 | accs += [a==b for (a, b) in zip(lab, lab_pred)] 92 | lab_chunks = set(get_chunks(lab, tags)) 93 | lab_pred_chunks = set(get_chunks(lab_pred, tags)) 94 | correct_preds += len(lab_chunks & lab_pred_chunks) 95 | total_preds += len(lab_pred_chunks) 96 | total_correct += len(lab_chunks) 97 | 98 | #for i in range(len(word_st)): 99 | #file_write.write('%s\t%s\t%s\n'%(word_st[i],lab[i],lab_pred[i])) 100 | #file_write.write('\n') 101 | 102 | p = correct_preds / total_preds if correct_preds > 0 else 0 103 | r = correct_preds / total_correct if correct_preds > 0 else 0 104 | f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0 105 | acc = np.mean(accs) 106 | 107 | #file_write.close() 108 | return acc, f1,p,r 109 | 110 | def evaluate_each_class(labels_pred, labels,words,tags, class_type): 111 | #class_type:PER or LOC or ORG 112 | index = 0 113 | 114 | accs = [] 115 | correct_preds, total_correct, total_preds = 0., 0., 0. 116 | correct_preds_cla_type, total_preds_cla_type, total_correct_cla_type = 0., 0., 0. 
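`evaluate` scores at the chunk level. A predicted span counts as correct only if its type, start, and end all match a gold span. Precision, recall, and F1 are micro-averaged over the whole test set, and all three are set to 0 when nothing is correct. The sketch below reproduces that logic under stated assumptions. It takes string BIO tags such as `B-POS`/`I-POS`/`O` instead of the index sequences and `tags` dictionary the repo passes in, and it does not model the repo's `X` sub-token label. `bio_chunks` and `chunk_prf` are hypothetical names.

```python
def bio_chunks(tags):
    """Return the set of (type, start, end) spans in a BIO tag list; end is exclusive."""
    chunks, ctype, cstart = [], None, None
    for i, tag in enumerate(tags):
        if tag == "O":
            if ctype is not None:          # an O closes the open chunk
                chunks.append((ctype, cstart, i))
                ctype = None
            continue
        cls, _, typ = tag.partition("-")   # "B-POS" -> ("B", "-", "POS")
        if ctype is None:                  # B- or stray I- opens a chunk
            ctype, cstart = typ, i
        elif typ != ctype or cls == "B":   # type change or new B- splits chunks
            chunks.append((ctype, cstart, i))
            ctype, cstart = typ, i
    if ctype is not None:                  # close a chunk running to the end
        chunks.append((ctype, cstart, len(tags)))
    return set(chunks)

def chunk_prf(gold_seqs, pred_seqs):
    """Micro-averaged chunk precision, recall, and F1 over all sentences."""
    correct = total_pred = total_gold = 0
    for gold, pred in zip(gold_seqs, pred_seqs):
        g, p = bio_chunks(gold), bio_chunks(pred)
        correct += len(g & p)
        total_pred += len(p)
        total_gold += len(g)
    prec = correct / total_pred if correct else 0.0
    rec = correct / total_gold if correct else 0.0
    f1 = 2 * prec * rec / (prec + rec) if correct else 0.0
    return prec, rec, f1

gold = [["B-POS", "I-POS", "O", "B-NEG"]]
pred = [["B-POS", "I-POS", "O", "O"]]
print(chunk_prf(gold, pred))  # (1.0, 0.5, 0.6666666666666666)
```

Under this exact-match rule, a prediction with the right span but the wrong sentiment type (for example `B-NEU` for a gold `B-POS` aspect) earns no credit.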
117 | 118 | for lab, lab_pred, word_sent in zip(labels, labels_pred, words): 119 | lab_pre_class_type = [] 120 | lab_class_type=[] 121 | 122 | word_st = word_sent 123 | lab = lab 124 | lab_pred = lab_pred 125 | lab_chunks = get_chunks(lab, tags) 126 | lab_pred_chunks = get_chunks(lab_pred, tags) 127 | for i in range(len(lab_pred_chunks)): 128 | if lab_pred_chunks[i][0] == class_type: 129 | lab_pre_class_type.append(lab_pred_chunks[i]) 130 | lab_pre_class_type_c = set(lab_pre_class_type) 131 | 132 | for i in range(len(lab_chunks)): 133 | if lab_chunks[i][0] ==class_type: 134 | lab_class_type.append(lab_chunks[i]) 135 | lab_class_type_c = set(lab_class_type) 136 | 137 | lab_chunksss = set(lab_chunks) 138 | correct_preds_cla_type +=len(lab_pre_class_type_c & lab_chunksss) 139 | total_preds_cla_type +=len(lab_pre_class_type_c) 140 | total_correct_cla_type += len(lab_class_type_c) 141 | 142 | p = correct_preds_cla_type / total_preds_cla_type if correct_preds_cla_type > 0 else 0 143 | r = correct_preds_cla_type / total_correct_cla_type if correct_preds_cla_type > 0 else 0 144 | f1 = 2 * p * r / (p + r) if correct_preds_cla_type > 0 else 0 145 | 146 | return f1,p,r 147 | 148 | 149 | if __name__ == '__main__': 150 | max_sent=10 151 | tags = {'0':0, 152 | 'B-PER':1, 'I-PER':2, 153 | 'B-LOC':3, 'I-LOC':4, 154 | 'B-ORG':5, 'I-ORG':6, 155 | 'B-OTHER':7, 'I-OTHER':8, 156 | 'O':9} 157 | labels_pred=[ 158 | [9,9,9,1,3,1,2,2,0,0], 159 | [9,9,9,1,3,1,2,0,0,0] 160 | ] 161 | labels = [ 162 | [9,9,9,9,3,1,2,2,0,0], 163 | [9,9,9,9,3,1,2,2,0,0] 164 | ] 165 | words = [ 166 | [0,0,0,0,0,3,6,8,5,7], 167 | [0,0,0,4,5,6,7,9,1,7] 168 | ] 169 | id_to_vocb = {0:'a',1:'b',2:'c',3:'d',4:'e',5:'f',6:'g',7:'h',8:'i',9:'j'} 170 | new_words = [] 171 | for i in range(len(words)): 172 | sent = [] 173 | for j in range(len(words[i])): 174 | sent.append(id_to_vocb[words[i][j]]) 175 | new_words.append(sent) 176 | class_type = 'PER' 177 | acc, f1,p,r = evaluate(labels_pred, labels,new_words,tags) 178 | 
print(p,r,f1) 179 | f1,p,r = evaluate_each_class(labels_pred, labels,new_words,tags, class_type) 180 | print(p,r,f1) 181 | 182 | -------------------------------------------------------------------------------- /run_cmmt_crf.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import csv 5 | import logging 6 | import os 7 | import random 8 | import json 9 | import sys 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | from my_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 15 | from my_bert.mner_modeling import (CONFIG_NAME, WEIGHTS_NAME, 16 | BertConfig, MTCCMBertForMMTokenClassificationCRF) 17 | from my_bert.optimization import BertAdam, warmup_linear 18 | from my_bert.tokenization import BertTokenizer 19 | from seqeval.metrics import classification_report 20 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 21 | TensorDataset) 22 | from torch.utils.data.distributed import DistributedSampler 23 | from tqdm import tqdm, trange 24 | 25 | import resnet.resnet as resnet 26 | #from resnet.resnet import resnet 27 | from resnet.resnet_utils import myResnet 28 | 29 | from torchvision import transforms 30 | from PIL import Image 31 | 32 | from sklearn.metrics import precision_recall_fscore_support 33 | 34 | from ner_evaluate import evaluate_each_class 35 | from ner_evaluate import evaluate 36 | from transformers import RobertaTokenizer, RobertaModel 37 | 38 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 39 | datefmt = '%m/%d/%Y %H:%M:%S', 40 | level = logging.INFO) 41 | logger = logging.getLogger(__name__) 42 | 43 | 44 | def image_process(image_path, transform): 45 | image = Image.open(image_path).convert('RGB') 46 | image = transform(image) 47 | return image 48 | 49 | 50 | class InputExample(object): 51 | """A single training/test 
example for simple sequence classification.""" 52 | 53 | def __init__(self, guid, text_a, text_b=None, label=None): 54 | """Constructs a InputExample. 55 | 56 | Args: 57 | guid: Unique id for the example. 58 | text_a: string. The untokenized text of the first sequence. For single 59 | sequence tasks, only this sequence must be specified. 60 | text_b: (Optional) string. The untokenized text of the second sequence. 61 | Only must be specified for sequence pair tasks. 62 | label: (Optional) string. The label of the example. This should be 63 | specified for train and dev examples, but not for test examples. 64 | """ 65 | self.guid = guid 66 | self.text_a = text_a 67 | self.text_b = text_b 68 | self.label = label 69 | 70 | 71 | class MMInputExample(object): 72 | """A single training/test example for simple sequence classification.""" 73 | 74 | def __init__(self, guid, text_a, text_b, img_id, label=None, auxlabel=None,imagelabel= None): # yl add 75 | """Constructs a InputExample. 76 | 77 | Args: 78 | guid: Unique id for the example. 79 | text_a: string. The untokenized text of the first sequence. For single 80 | sequence tasks, only this sequence must be specified. 81 | text_b: (Optional) string. The untokenized text of the second sequence. 82 | Only must be specified for sequence pair tasks. 83 | label: (Optional) string. The label of the example. This should be 84 | specified for train and dev examples, but not for test examples. 
85 | """ 86 | self.guid = guid 87 | self.text_a = text_a 88 | self.text_b = text_b 89 | self.img_id = img_id 90 | self.label = label 91 | self.auxlabel = auxlabel 92 | self.imagelabel = imagelabel 93 | 94 | 95 | class InputFeatures(object): 96 | """A single set of features of data.""" 97 | 98 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 99 | self.input_ids = input_ids 100 | self.input_mask = input_mask 101 | self.segment_ids = segment_ids 102 | self.label_id = label_id 103 | 104 | 105 | class MMInputFeatures(object): 106 | """A single set of features of data.""" 107 | 108 | def __init__(self, input_ids, input_mask, added_input_mask, segment_ids, img_feat, label_id, auxlabel_id, imagelabel): 109 | self.input_ids = input_ids 110 | self.input_mask = input_mask 111 | self.added_input_mask = added_input_mask 112 | self.segment_ids = segment_ids 113 | self.img_feat = img_feat 114 | self.label_id = label_id 115 | self.auxlabel_id = auxlabel_id 116 | self.imagelabel = imagelabel 117 | 118 | 119 | def readfile(filename): 120 | ''' 121 | read file 122 | return format : 123 | [ ['EU', 'B-ORG'], ['rejects', 'O'], ['German', 'B-MISC'], ['call', 'O'], ['to', 'O'], ['boycott', 'O'], ['British', 'B-MISC'], ['lamb', 'O'], ['.', 'O'] ] 124 | ''' 125 | f = open(filename) 126 | data = [] 127 | sentence = [] 128 | label= [] 129 | for line in f: 130 | if len(line)==0 or line.startswith('-DOCSTART') or line[0]=="\n": 131 | if len(sentence) > 0: 132 | data.append((sentence,label)) 133 | sentence = [] 134 | label = [] 135 | continue 136 | splits = line.split(' ') 137 | sentence.append(splits[0]) 138 | label.append(splits[-1][:-1]) 139 | 140 | if len(sentence) >0: 141 | data.append((sentence,label)) 142 | sentence = [] 143 | label = [] 144 | 145 | print("The number of samples: "+ str(len(data))) 146 | return data 147 | 148 | 149 | def mmreadfile(filename, image_filename, path_img): 150 | ''' 151 | read file 152 | return format : 153 | [ ['EU', 'B-ORG'], ['rejects', 
'O'], ['German', 'B-MISC'], ['call', 'O'], ['to', 'O'], ['boycott', 'O'], ['British', 'B-MISC'], ['lamb', 'O'], ['.', 'O'] ] 154 | ''' 155 | transform = transforms.Compose([ 156 | transforms.RandomCrop(224), # args.crop_size, by default it is set to be 224 157 | transforms.RandomHorizontalFlip(), 158 | transforms.ToTensor(), 159 | transforms.Normalize((0.485, 0.456, 0.406), 160 | (0.229, 0.224, 0.225))]) 161 | with open(image_filename, 'r') as f: 162 | image_data = json.load(f) 163 | f = open(filename, encoding='utf-8') 164 | data = [] 165 | imgs = [] 166 | auxlabels = [] 167 | sentence = [] 168 | label= [] 169 | auxlabel = [] 170 | imagelabels = [] 171 | imgid = '' 172 | count = 0 173 | # print(image_data.keys()) 174 | for line in f: 175 | if line.startswith('IMGID:'): 176 | imgid = line.strip().split('IMGID:')[1]+'.jpg' 177 | continue 178 | if line[0]=="\n": 179 | if len(sentence) > 0: 180 | data.append((sentence,label)) 181 | imgs.append(imgid) 182 | image_path = os.path.join(path_img, imgid) 183 | if not os.path.exists(image_path): 184 | print(image_path) 185 | try: 186 | image = image_process(image_path, transform) 187 | except: 188 | # print('image has problem!') 189 | imgid = '17_06_4705.jpg' 190 | 191 | image_label = image_data.get(imgid) 192 | if image_label == None: 193 | count += 1 194 | #print(sentence) 195 | #print(label) 196 | #print(imgid) 197 | auxlabels.append(auxlabel) 198 | imagelabels.append(image_label) 199 | sentence = [] 200 | label = [] 201 | imgid = '' 202 | auxlabel = [] 203 | continue 204 | splits = line.split('\t') 205 | sentence.append(splits[0]) 206 | cur_label = splits[1] #splits[-1][:-1] # yl add 207 | if cur_label == 'B-OTHER': 208 | cur_label = 'B-MISC' 209 | elif cur_label == 'I-OTHER': 210 | cur_label = 'I-MISC' 211 | label.append(cur_label) 212 | auxlabel.append(cur_label) 213 | #auxlabel.append(splits[2][:-1]) # yl add 214 | 215 | print("The number of samples with NULL image labels: "+ str(count)) 216 | if len(sentence) >0: 217 
| data.append((sentence,label)) 218 | imgs.append(imgid) 219 | auxlabels.append(auxlabel) 220 | imagelabels.append(image_label) 221 | sentence = [] 222 | label = [] 223 | auxlabel = [] 224 | 225 | print("The number of samples: "+ str(len(data))) 226 | print("The number of images: "+ str(len(imgs))) 227 | return data, imgs, auxlabels, imagelabels 228 | 229 | 230 | class DataProcessor(object): 231 | """Base class for data converters for sequence classification data sets.""" 232 | 233 | def get_train_examples(self, data_dir,image_filename,path_img): #yl add 234 | """Gets a collection of `InputExample`s for the train set.""" 235 | raise NotImplementedError() 236 | 237 | def get_dev_examples(self, data_dir,image_filename,path_img): #yl add 238 | """Gets a collection of `InputExample`s for the dev set.""" 239 | raise NotImplementedError() 240 | 241 | def get_labels(self): 242 | """Gets the list of labels for this data set.""" 243 | raise NotImplementedError() 244 | 245 | @classmethod 246 | def _read_tsv(cls, input_file, quotechar=None): 247 | """Reads a tab separated value file.""" 248 | return readfile(input_file) 249 | 250 | def _read_mmtsv(cls, input_file, image_filename, path_img, quotechar=None): 251 | """Reads a tab separated value file.""" 252 | return mmreadfile(input_file, image_filename, path_img) # yl add 253 | 254 | 255 | class MNERProcessor(DataProcessor): 256 | """Processor for the CoNLL-2003 data set.""" 257 | 258 | def get_train_examples(self, data_dir, image_filename, path_img): # yl add 259 | """See base class.""" 260 | data, imgs, auxlabels,imagelabels = self._read_mmtsv(os.path.join(data_dir, "train.txt"), image_filename, path_img) # yl add 261 | return self._create_examples(data, imgs, auxlabels, imagelabels, "train") 262 | 263 | def get_dev_examples(self, data_dir,image_filename,path_img): # yl add 264 | """See base class.""" 265 | data, imgs, auxlabels,imagelabels = self._read_mmtsv(os.path.join(data_dir, "valid.txt"), image_filename, path_img) # 
yl add 266 | return self._create_examples(data, imgs, auxlabels, imagelabels, "dev") 267 | 268 | def get_test_examples(self, data_dir,image_filename,path_img): # yl add 269 | """See base class.""" 270 | data, imgs, auxlabels, imagelabels = self._read_mmtsv(os.path.join(data_dir, "test.txt"), image_filename, path_img) # yl add 271 | return self._create_examples(data, imgs, auxlabels, imagelabels, "test") 272 | 273 | 274 | def get_labels(self): 275 | return ["O", "B-NEU", "I-NEU", "B-POS", "I-POS", "B-NEG", "I-NEG","X","",""] 276 | 277 | ### modify 278 | def get_auxlabels(self): 279 | return ["O", "B-NEU", "I-NEU", "B-POS", "I-POS", "B-NEG", "I-NEG", "X", "", ""] #yl add ["O", "B-AE", "I-AE", "B-OE", "I-OE", "X", "", ""] 280 | #def get_auxlabels(self): 281 | #return ["O", "B", "I", "X", "[CLS]", "[SEP]"] 282 | 283 | ### modify 284 | def get_start_label_id(self): 285 | label_list = self.get_labels() 286 | label_map = {label: i for i, label in enumerate(label_list, 1)} 287 | return label_map[''] 288 | 289 | def get_stop_label_id(self): 290 | label_list = self.get_labels() 291 | label_map = {label: i for i, label in enumerate(label_list, 1)} 292 | return label_map[''] 293 | 294 | def _create_examples(self, lines, imgs, auxlabels, imagelabels, set_type): # yl add 295 | examples = [] 296 | for i, (sentence, label) in enumerate(lines): 297 | guid = "%s-%s" % (set_type, i) 298 | text_a = ' '.join(sentence) 299 | text_b = None 300 | img_id = imgs[i] 301 | label = label 302 | auxlabel = auxlabels[i] 303 | imagelabel = imagelabels[i] 304 | examples.append(MMInputExample(guid=guid, text_a=text_a, text_b=text_b, img_id=img_id, label=label, auxlabel=auxlabel, imagelabel = imagelabel)) 305 | return examples 306 | 307 | 308 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 309 | """Loads a data file into a list of `InputBatch`s.""" 310 | 311 | label_map = {label : i for i, label in enumerate(label_list,1)} 312 | 313 | features = [] 314 | for 
(ex_index,example) in enumerate(examples): 315 | textlist = example.text_a.split(' ') 316 | labellist = example.label 317 | tokens = [] 318 | labels = [] 319 | for i, word in enumerate(textlist): 320 | token = tokenizer.tokenize(word) 321 | tokens.extend(token) 322 | label_1 = labellist[i] 323 | for m in range(len(token)): 324 | if m == 0: 325 | labels.append(label_1) 326 | else: 327 | labels.append("X") 328 | if len(tokens) >= max_seq_length - 1: 329 | tokens = tokens[0:(max_seq_length - 2)] 330 | labels = labels[0:(max_seq_length - 2)] 331 | ntokens = [] 332 | segment_ids = [] 333 | label_ids = [] 334 | ntokens.append("") 335 | segment_ids.append(0) 336 | label_ids.append(label_map[""]) 337 | for i, token in enumerate(tokens): 338 | ntokens.append(token) 339 | segment_ids.append(0) 340 | label_ids.append(label_map[labels[i]]) 341 | ntokens.append("") 342 | segment_ids.append(0) 343 | label_ids.append(label_map[""]) 344 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 345 | input_mask = [1] * len(input_ids) 346 | while len(input_ids) < max_seq_length: 347 | input_ids.append(0) 348 | input_mask.append(0) 349 | segment_ids.append(0) 350 | label_ids.append(0) 351 | assert len(input_ids) == max_seq_length 352 | assert len(input_mask) == max_seq_length 353 | assert len(segment_ids) == max_seq_length 354 | assert len(label_ids) == max_seq_length 355 | 356 | if ex_index < 2: 357 | logger.info("*** Example ***") 358 | logger.info("guid: %s" % (example.guid)) 359 | logger.info("tokens: %s" % " ".join( 360 | [str(x) for x in tokens])) 361 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 362 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 363 | logger.info( 364 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 365 | logger.info("label: %s" % " ".join([str(x) for x in label_ids])) 366 | 367 | features.append( 368 | InputFeatures(input_ids=input_ids, 369 | input_mask=input_mask, 370 | segment_ids=segment_ids, 
371 | label_id=label_ids)) 372 | return features 373 | 374 | 375 | def convert_mm_examples_to_features(examples, label_list, auxlabel_list, max_seq_length, tokenizer, crop_size, path_img): 376 | """Loads a data file into a list of `InputBatch`s.""" 377 | 378 | label_map = {label: i for i, label in enumerate(label_list, 1)} 379 | auxlabel_map = {label: i for i, label in enumerate(auxlabel_list, 1)} 380 | 381 | features = [] 382 | count = 0 383 | 384 | transform = transforms.Compose([ 385 | transforms.RandomCrop(crop_size), # args.crop_size, 224 by default 386 | transforms.RandomHorizontalFlip(), 387 | transforms.ToTensor(), 388 | transforms.Normalize((0.485, 0.456, 0.406), 389 | (0.229, 0.224, 0.225))]) 390 | 391 | for (ex_index, example) in enumerate(examples): 392 | textlist = example.text_a.split(' ') 393 | labellist = example.label 394 | auxlabellist = example.auxlabel 395 | imagelabellist = example.imagelabel 396 | imagelabellist = dict(sorted(imagelabellist.items())) 397 | imagelabel_value =[0]* len(imagelabellist) 398 | for i, (k, v) in enumerate(imagelabellist.items()): 399 | imagelabel_value[i]= v 400 | tokens = [] 401 | labels = [] 402 | auxlabels = [] 403 | for i, word in enumerate(textlist): 404 | word = " "+ word 405 | token = tokenizer.tokenize(word) 406 | tokens.extend(token) 407 | label_1 = labellist[i] 408 | auxlabel_1 = auxlabellist[i] 409 | for m in range(len(token)): 410 | if m == 0: 411 | labels.append(label_1) 412 | auxlabels.append(auxlabel_1) 413 | else: 414 | labels.append("X") 415 | auxlabels.append("X") 416 | if len(tokens) >= max_seq_length - 1: 417 | tokens = tokens[0:(max_seq_length - 2)] 418 | labels = labels[0:(max_seq_length - 2)] 419 | auxlabels = auxlabels[0:(max_seq_length - 2)] 420 | ntokens = [] 421 | segment_ids = [] 422 | label_ids = [] 423 | auxlabel_ids = [] 424 | ntokens.append("<s>") 425 | segment_ids.append(0) 426 | label_ids.append(label_map["<s>"]) 427 | auxlabel_ids.append(auxlabel_map["<s>"]) 428 | for i, token
in enumerate(tokens): 429 | ntokens.append(token) 430 | segment_ids.append(0) 431 | label_ids.append(label_map[labels[i]]) 432 | auxlabel_ids.append(auxlabel_map[auxlabels[i]]) 433 | ntokens.append("</s>") 434 | segment_ids.append(0) 435 | label_ids.append(label_map["</s>"]) 436 | auxlabel_ids.append(auxlabel_map["</s>"]) 437 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 438 | input_mask = [1] * len(input_ids) 439 | added_input_mask = [1] * (len(input_ids) + 49) # the extra 49 (or 1) positions are reserved for the regional image representations (49 = 7x7 ResNet grid at crop size 224) 440 | 441 | while len(input_ids) < max_seq_length: 442 | input_ids.append(0) 443 | input_mask.append(0) 444 | added_input_mask.append(0) 445 | segment_ids.append(0) 446 | label_ids.append(0) 447 | auxlabel_ids.append(0) 448 | 449 | assert len(input_ids) == max_seq_length 450 | assert len(input_mask) == max_seq_length 451 | assert len(segment_ids) == max_seq_length 452 | assert len(label_ids) == max_seq_length 453 | assert len(auxlabel_ids) == max_seq_length 454 | 455 | image_name = example.img_id 456 | image_path = os.path.join(path_img, image_name) 457 | 458 | if not os.path.exists(image_path): 459 | print(image_path) 460 | try: 461 | image = image_process(image_path, transform) 462 | except Exception: 463 | count += 1 464 | # print('image has problem!') 465 | image_path_fail = os.path.join(path_img, '17_06_4705.jpg') 466 | image = image_process(image_path_fail, transform) 467 | 468 | if ex_index < 2: 469 | logger.info("*** Example ***") 470 | logger.info("guid: %s" % (example.guid)) 471 | logger.info("tokens: %s" % " ".join( 472 | [str(x) for x in tokens])) 473 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 474 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 475 | logger.info( 476 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 477 | logger.info("label: %s" % " ".join([str(x) for x in label_ids])) 478 | logger.info("auxlabel: %s" % " ".join([str(x) for x in auxlabel_ids])) 479 | 480 |
features.append( 481 | MMInputFeatures(input_ids=input_ids, input_mask=input_mask, added_input_mask=added_input_mask, 482 | segment_ids=segment_ids, img_feat=image, label_id=label_ids, auxlabel_id=auxlabel_ids, imagelabel= imagelabel_value)) 483 | 484 | print('the number of problematic samples: ' + str(count)) 485 | 486 | return features 487 | 488 | 489 | def macro_f1(y_true, y_pred): 490 | p_macro, r_macro, f_macro, support_macro \ 491 | = precision_recall_fscore_support(y_true, y_pred, average='macro') 492 | f_macro = 2*p_macro*r_macro/(p_macro+r_macro) 493 | return p_macro, r_macro, f_macro 494 | 495 | import datetime 496 | def main(): 497 | start_time = datetime.datetime.now().strftime('%m-%d-%Y-%H-%M-%S_') 498 | parser = argparse.ArgumentParser() 499 | 500 | parser.add_argument('--cuda_id', type=str, default='0', 501 | help='Choose which GPUs to run') 502 | parser.add_argument("--bert_model", default="./model/roberta-base-cased", type=str, 503 | help="pre-trained model selected in the list: roberta-base-cased, " 504 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 505 | "bert-base-multilingual-cased, bert-base-chinese.") 506 | 507 | parser.add_argument("--task_name", 508 | default= "twitter2015", #twitter2017 509 | type=str, 510 | required=True, 511 | help="The name of the task to train.") 512 | parser.add_argument("--cache_dir", 513 | default="", 514 | type=str, 515 | help="Where do you want to store the pre-trained models downloaded from s3") 516 | parser.add_argument("--max_seq_length", 517 | default=128, 518 | type=int, 519 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") 522 | parser.add_argument("--do_train", 523 | action='store_true', 524 | help="Whether to run training.") 525 | parser.add_argument("--do_eval", 526 | action='store_true', 527 | help="Whether to run eval on the dev set.") 528 | parser.add_argument("--do_lower_case", 529 | action='store_true', 530 | help="Set this flag if you are using an uncased model.") 531 | parser.add_argument("--train_batch_size", 532 | default=32, 533 | type=int, 534 | help="Total batch size for training.") 535 | parser.add_argument("--eval_batch_size", 536 | default=16, 537 | type=int, 538 | help="Total batch size for eval.") 539 | parser.add_argument("--learning_rate", 540 | default=3e-5, 541 | type=float, 542 | help="The initial learning rate for Adam.") 543 | parser.add_argument("--num_train_epochs", 544 | default=25.0, 545 | type=float, 546 | help="Total number of training epochs to perform.") 547 | parser.add_argument("--warmup_proportion", 548 | default=0.1, 549 | type=float, 550 | help="Proportion of training to perform linear learning rate warmup for. 
" "E.g., 0.1 = 10%% of training.") 552 | parser.add_argument("--no_cuda", 553 | action='store_true', 554 | help="Do not use CUDA even when it is available") 555 | parser.add_argument("--local_rank", 556 | type=int, 557 | default=-1, 558 | help="local_rank for distributed training on gpus") 559 | parser.add_argument('--seed', 560 | type=int, 561 | default=64, 562 | help="random seed for initialization") 563 | parser.add_argument('--gradient_accumulation_steps', 564 | type=int, 565 | default=1, 566 | help="Number of update steps to accumulate before performing a backward/update pass.") 567 | parser.add_argument('--fp16', 568 | action='store_true', 569 | help="Whether to use 16-bit float precision instead of 32-bit") 570 | parser.add_argument('--loss_scale', 571 | type=float, default=0, 572 | help="Loss scaling to improve fp16 numeric stability. Only used when --fp16 is set.\n" 573 | "0 (default value): dynamic loss scaling.\n" 574 | "Positive power of 2: static loss scaling value.\n") 575 | parser.add_argument('use_roberta', default=True, action = 'store_true') 576 | 577 | parser.add_argument('--mm_model', default='MTCCMBert', help='model name') 578 | parser.add_argument('--layer_num1', type=int, default=1, help='number of txt2img layers') 579 | parser.add_argument('--layer_num2', type=int, default=1, help='number of img2txt layers') 580 | parser.add_argument('--layer_num3', type=int, default=1, help='number of txt2txt layers') 581 | parser.add_argument('--fine_tune_cnn', action='store_true', help='fine tune pre-trained CNN if True') 582 | parser.add_argument('--resnet_root', default='./model/resnet', help='path to the pre-trained CNN models') 583 | parser.add_argument('--crop_size', type=int, default=224, help='crop size of image') 584 | parser.add_argument('--path_image', default='./pytorch-pretrained-BERT/twitter_subimages/', help='path to images') 585 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 586 |
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 587 | parser.add_argument('--alpha', type=float, default=1) 588 | parser.add_argument('--beta', type=float, default=1) 589 | parser.add_argument('--dropout_rate', type=float, default=0.2) 590 | args = parser.parse_args() 591 | os.environ["CUDA_VISIBLE_DEVICES"] =args.cuda_id 592 | 593 | 594 | if args.task_name == "twitter2017": # this refers to twitter-2017 dataset 595 | args.path_image = "./image_data/twitter2017/" 596 | args.data_dir = "./data/twitter2017" 597 | args.image_filename = "./ANP_data/image_output2017.json" 598 | args.output_dir = start_time + "_twitter2017_output/" 599 | elif args.task_name == "twitter2015": # this refers to twitter-2015 dataset 600 | args.path_image = "./image_data/twitter2015/" 601 | args.data_dir = "./data/twitter2015" 602 | args.image_filename = "./ANP_data/image_output2015.json" 603 | args.output_dir = start_time + "_twitter2015_output/" 604 | 605 | if args.server_ip and args.server_port: 606 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 607 | import ptvsd 608 | print("Waiting for debugger attach") 609 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 610 | ptvsd.wait_for_attach() 611 | 612 | processors = { 613 | "twitter2015": MNERProcessor, 614 | "twitter2017": MNERProcessor 615 | } 616 | 617 | if args.local_rank == -1 or args.no_cuda: 618 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 619 | n_gpu = torch.cuda.device_count() 620 | else: 621 | torch.cuda.set_device(args.local_rank) 622 | device = torch.device("cuda", args.local_rank) 623 | n_gpu = 1 624 | # Initializes the
distributed backend which will take care of synchronizing nodes/GPUs 625 | torch.distributed.init_process_group(backend='nccl') 626 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 627 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 628 | 629 | if args.gradient_accumulation_steps < 1: 630 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 631 | args.gradient_accumulation_steps)) 632 | 633 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 634 | 635 | random.seed(args.seed) 636 | np.random.seed(args.seed) 637 | torch.manual_seed(args.seed) 638 | 639 | args.do_train = True 640 | args.do_eval = True 641 | if not args.do_train and not args.do_eval: 642 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 643 | 644 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 645 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 646 | if not os.path.exists(args.output_dir): 647 | os.makedirs(args.output_dir) 648 | 649 | task_name = args.task_name.lower() 650 | 651 | if task_name not in processors: 652 | raise ValueError("Task not found: %s" % (task_name)) 653 | 654 | processor = processors[task_name]() 655 | label_list = processor.get_labels() 656 | auxlabel_list = processor.get_auxlabels() 657 | num_labels = len(label_list)+1 # label 0 corresponds to padding; labels in label_list are indexed from 1 658 | auxnum_labels = len(auxlabel_list)+1 # label 0 corresponds to padding; labels in auxlabel_list are indexed from 1 659 | 660 | start_label_id = processor.get_start_label_id() 661 | stop_label_id = processor.get_stop_label_id() 662 | 663 | trans_matrix = np.zeros((auxnum_labels,num_labels), dtype=float) 664 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 665 | 666 | train_examples = None 667 | num_train_optimization_steps = None 668 | if
args.do_train: 669 | train_examples = processor.get_train_examples(args.data_dir, args.image_filename, args.path_image) # yl add 670 | num_train_optimization_steps = int( 671 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs 672 | if args.local_rank != -1: 673 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 674 | 675 | # Prepare model 676 | cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)) 677 | 678 | if args.mm_model == 'MTCCMBert': 679 | model = MTCCMBertForMMTokenClassificationCRF.from_pretrained(args.bert_model,args.use_roberta, 680 | cache_dir=cache_dir, layer_num1=args.layer_num1, layer_num2=args.layer_num2, layer_num3=args.layer_num3, 681 | num_labels = num_labels, auxnum_labels = auxnum_labels,dropout_rate=args.dropout_rate) 682 | if args.use_roberta: 683 | roberta_dict = torch.load('./model/roberta-base-cased/pytorch_model.bin') #changed to large or base 684 | new_state_dict = model.state_dict() 685 | miss_keys = [] 686 | for k in new_state_dict.keys(): 687 | if k in roberta_dict.keys(): 688 | new_state_dict[k] = roberta_dict[k] 689 | else: 690 | miss_keys.append(k) 691 | if len(miss_keys) > 0: 692 | logger.info('miss keys: {}'.format(miss_keys)) 693 | model.load_state_dict(new_state_dict) 694 | 695 | else: 696 | print('please define your MNER Model') 697 | 698 | net = getattr(resnet, 'resnet152')() 699 | net.load_state_dict(torch.load(os.path.join(args.resnet_root, 'resnet152.pth'))) 700 | encoder = myResnet(net, args.fine_tune_cnn, device) 701 | if args.fp16: 702 | model.half() 703 | encoder.half() 704 | model.to(device) 705 | encoder.to(device) 706 | if args.local_rank != -1: 707 | try: 708 | from apex.parallel import DistributedDataParallel as DDP 709 | except ImportError: 710 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex 
to use distributed and fp16 training.") 711 | 712 | model = DDP(model) 713 | encoder = DDP(encoder) 714 | elif n_gpu > 1: 715 | model = torch.nn.DataParallel(model) 716 | encoder = torch.nn.DataParallel(encoder) 717 | 718 | param_optimizer = list(model.named_parameters()) 719 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 720 | optimizer_grouped_parameters = [ 721 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 722 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 723 | ] 724 | if args.fp16: 725 | try: 726 | from apex.optimizers import FP16_Optimizer 727 | from apex.optimizers import FusedAdam 728 | except ImportError: 729 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 730 | 731 | optimizer = FusedAdam(optimizer_grouped_parameters, 732 | lr=args.learning_rate, 733 | bias_correction=False, 734 | max_grad_norm=1.0) 735 | if args.loss_scale == 0: 736 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 737 | else: 738 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 739 | 740 | else: 741 | optimizer = BertAdam(optimizer_grouped_parameters, 742 | lr=args.learning_rate, 743 | warmup=args.warmup_proportion, 744 | t_total=num_train_optimization_steps) 745 | 746 | global_step = 0 747 | nb_tr_steps = 0 748 | tr_loss = 0 749 | 750 | output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) 751 | output_config_file = os.path.join(args.output_dir, CONFIG_NAME) 752 | output_encoder_file = os.path.join(args.output_dir, "pytorch_encoder.bin") 753 | 754 | if args.do_train: 755 | train_features = convert_mm_examples_to_features( 756 | train_examples, label_list, auxlabel_list, args.max_seq_length, tokenizer, args.crop_size, args.path_image) 757 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 758 | 
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 759 | all_added_input_mask = torch.tensor([f.added_input_mask for f in train_features], dtype=torch.long) 760 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 761 | all_img_feats = torch.stack([f.img_feat for f in train_features]) 762 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 763 | all_auxlabel_ids = torch.tensor([f.auxlabel_id for f in train_features], dtype=torch.long) 764 | all_imagelabel = torch.tensor([f.imagelabel for f in train_features], dtype=torch.float) 765 | 766 | train_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, all_img_feats, 767 | all_label_ids, all_auxlabel_ids,all_imagelabel) 768 | 769 | if args.local_rank == -1: 770 | train_sampler = RandomSampler(train_data) 771 | else: 772 | train_sampler = DistributedSampler(train_data) 773 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 774 | 775 | eval_examples = processor.get_dev_examples(args.data_dir, args.image_filename, args.path_image) # yl add 776 | eval_features = convert_mm_examples_to_features( 777 | eval_examples, label_list, auxlabel_list, args.max_seq_length, tokenizer, args.crop_size, args.path_image) 778 | 779 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 780 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 781 | all_added_input_mask = torch.tensor([f.added_input_mask for f in eval_features], dtype=torch.long) 782 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 783 | all_img_feats = torch.stack([f.img_feat for f in eval_features]) 784 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 785 | all_auxlabel_ids = torch.tensor([f.auxlabel_id for f in eval_features], 
dtype=torch.long) 786 | all_imagelabel = torch.tensor([f.imagelabel for f in eval_features], dtype=torch.float) 787 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, all_img_feats, 788 | all_label_ids, all_auxlabel_ids,all_imagelabel) 789 | 790 | # Run prediction for full data 791 | eval_sampler = SequentialSampler(eval_data) 792 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 793 | 794 | test_eval_examples = processor.get_test_examples(args.data_dir, args.image_filename, args.path_image) # yl add 795 | test_eval_features = convert_mm_examples_to_features( 796 | test_eval_examples, label_list, auxlabel_list, args.max_seq_length, tokenizer, args.crop_size, args.path_image) 797 | all_input_ids = torch.tensor([f.input_ids for f in test_eval_features], dtype=torch.long) 798 | all_input_mask = torch.tensor([f.input_mask for f in test_eval_features], dtype=torch.long) 799 | all_added_input_mask = torch.tensor([f.added_input_mask for f in test_eval_features], dtype=torch.long) 800 | all_segment_ids = torch.tensor([f.segment_ids for f in test_eval_features], dtype=torch.long) 801 | all_img_feats = torch.stack([f.img_feat for f in test_eval_features]) 802 | all_label_ids = torch.tensor([f.label_id for f in test_eval_features], dtype=torch.long) 803 | all_auxlabel_ids = torch.tensor([f.auxlabel_id for f in test_eval_features], dtype=torch.long) 804 | all_imagelabel = torch.tensor([f.imagelabel for f in test_eval_features], dtype=torch.float) 805 | 806 | test_eval_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, all_img_feats, 807 | all_label_ids, all_auxlabel_ids,all_imagelabel) 808 | # Run prediction for full data 809 | test_eval_sampler = SequentialSampler(test_eval_data) 810 | test_eval_dataloader = DataLoader(test_eval_data, sampler=test_eval_sampler, batch_size=args.eval_batch_size) 811 | 812 | max_dev_f1 = 0.0 813 | max_test_f1 = 
0.0 814 | best_dev_epoch = 0 815 | best_test_epoch = 0 816 | logger.info("***** Running training *****") 817 | for train_idx in trange(int(args.num_train_epochs), desc="Epoch"): 818 | logger.info("********** Epoch: " + str(train_idx) + " **********") 819 | logger.info(" Num examples = %d", len(train_examples)) 820 | logger.info(" Batch size = %d", args.train_batch_size) 821 | logger.info(" Num steps = %d", num_train_optimization_steps) 822 | model.train() 823 | encoder.train() 824 | encoder.zero_grad() 825 | tr_loss = 0 826 | nb_tr_examples, nb_tr_steps = 0, 0 827 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 828 | batch = tuple(t.to(device) for t in batch) 829 | input_ids, input_mask, added_input_mask, segment_ids, img_feats, label_ids, auxlabel_ids, imagelabel = batch 830 | with torch.no_grad(): 831 | imgs_f, img_mean, img_att = encoder(img_feats) 832 | trans_matrix = torch.tensor(trans_matrix).to(device) 833 | neg_log_likelihood = model(input_ids, segment_ids, input_mask, added_input_mask, 834 | img_att, trans_matrix, imagelabel,args.alpha, args.beta,label_ids, auxlabel_ids) 835 | if n_gpu > 1: 836 | neg_log_likelihood = neg_log_likelihood.mean() # mean() to average on multi-gpu. 
837 | if args.gradient_accumulation_steps > 1: 838 | neg_log_likelihood = neg_log_likelihood / args.gradient_accumulation_steps 839 | 840 | if args.fp16: 841 | optimizer.backward(neg_log_likelihood) 842 | else: 843 | neg_log_likelihood.backward() 844 | 845 | tr_loss += neg_log_likelihood.item() 846 | nb_tr_examples += input_ids.size(0) 847 | nb_tr_steps += 1 848 | if (step + 1) % args.gradient_accumulation_steps == 0: 849 | if args.fp16: 850 | # modify learning rate with special warm up BERT uses 851 | # if args.fp16 is False, BertAdam is used that handles this automatically 852 | lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion) 853 | for param_group in optimizer.param_groups: 854 | param_group['lr'] = lr_this_step 855 | optimizer.step() 856 | optimizer.zero_grad() 857 | global_step += 1 858 | 859 | logger.info("***** Running evaluation on Dev Set*****") 860 | logger.info(" Num examples = %d", len(eval_examples)) 861 | logger.info(" Batch size = %d", args.eval_batch_size) 862 | 863 | model.eval() 864 | encoder.eval() 865 | eval_loss, eval_accuracy = 0, 0 866 | nb_eval_steps, nb_eval_examples = 0, 0 867 | 868 | y_true = [] 869 | y_pred = [] 870 | y_true_idx = [] 871 | y_pred_idx = [] 872 | label_map = {i: label for i, label in enumerate(label_list, 1)} 873 | label_map[0] = "PAD" 874 | for input_ids, input_mask, added_input_mask, segment_ids, img_feats, label_ids, auxlabel_ids,imagelabel in tqdm( 875 | eval_dataloader, 876 | desc="Evaluating"): 877 | input_ids = input_ids.to(device) 878 | input_mask = input_mask.to(device) 879 | added_input_mask = added_input_mask.to(device) 880 | segment_ids = segment_ids.to(device) 881 | img_feats = img_feats.to(device) 882 | label_ids = label_ids.to(device) 883 | auxlabel_ids = auxlabel_ids.to(device) 884 | 885 | with torch.no_grad(): 886 | imgs_f, img_mean, img_att = encoder(img_feats) 887 | predicted_label_seq_ids = model(input_ids, segment_ids, input_mask, 
added_input_mask, img_att, 888 | trans_matrix, imagelabel,args.alpha, args.beta) 889 | 890 | logits = predicted_label_seq_ids 891 | label_ids = label_ids.to('cpu').numpy() 892 | input_mask = input_mask.to('cpu').numpy() 893 | for i, mask in enumerate(input_mask): 894 | temp_1 = [] 895 | temp_2 = [] 896 | tmp1_idx = [] 897 | tmp2_idx = [] 898 | for j, m in enumerate(mask): 899 | if j == 0: 900 | continue 901 | if m: 902 | if label_map[label_ids[i][j]] != "X" and label_map[ 903 | label_ids[i][j]] != "</s>": 904 | temp_1.append(label_map[label_ids[i][j]]) 905 | tmp1_idx.append(label_ids[i][j]) 906 | temp_2.append(label_map[logits[i][j]]) 907 | tmp2_idx.append(logits[i][j]) 908 | else: 909 | break 910 | y_true.append(temp_1) 911 | y_pred.append(temp_2) 912 | y_true_idx.append(tmp1_idx) 913 | y_pred_idx.append(tmp2_idx) 914 | 915 | # report = classification_report(y_true, y_pred, digits=4) 916 | sentence_list = [] 917 | dev_data, imgs, _ ,_ = processor._read_mmtsv(os.path.join(args.data_dir, "valid.txt"),args.image_filename, args.path_image) 918 | for i in range(len(y_pred)): 919 | sentence = dev_data[i][0] 920 | sentence_list.append(sentence) 921 | reverse_label_map = {label: i for i, label in enumerate(label_list, 1)} 922 | acc, f1, p, r = evaluate(y_pred_idx, y_true_idx, sentence_list, reverse_label_map) 923 | logger.info("***** Dev Eval results *****") 924 | print("Overall: ", p, r, f1) 925 | per_f1, per_p, per_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'POS') 926 | print("Positive: ", per_p, per_r, per_f1) 927 | loc_f1, loc_p, loc_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEU') 928 | print("Neutral: ", loc_p, loc_r, loc_f1) 929 | org_f1, org_p, org_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEG') 930 | print("Negative: ", org_p, org_r, org_f1) 931 | F_score_dev = f1 932 | 933 | logger.info("***** Running Test evaluation *****") 934 | logger.info(" Num
examples = %d", len(test_eval_examples)) 935 | logger.info(" Batch size = %d", args.eval_batch_size) 936 | y_true = [] 937 | y_pred = [] 938 | y_true_idx = [] 939 | y_pred_idx = [] 940 | label_map = {i: label for i, label in enumerate(label_list, 1)} 941 | label_map[0] = "PAD" 942 | for input_ids, input_mask, added_input_mask, segment_ids, img_feats, label_ids, auxlabel_ids,imagelabel in tqdm(test_eval_dataloader, 943 | desc="Evaluating"): 944 | input_ids = input_ids.to(device) 945 | input_mask = input_mask.to(device) 946 | added_input_mask = added_input_mask.to(device) 947 | segment_ids = segment_ids.to(device) 948 | img_feats = img_feats.to(device) 949 | label_ids = label_ids.to(device) 950 | auxlabel_ids = auxlabel_ids.to(device) 951 | 952 | with torch.no_grad(): 953 | imgs_f, img_mean, img_att = encoder(img_feats) 954 | predicted_label_seq_ids = model(input_ids, segment_ids, input_mask, added_input_mask, img_att,trans_matrix,imagelabel,args.alpha, args.beta) 955 | 956 | logits = predicted_label_seq_ids 957 | label_ids = label_ids.to('cpu').numpy() 958 | input_mask = input_mask.to('cpu').numpy() 959 | for i, mask in enumerate(input_mask): 960 | temp_1 = [] 961 | temp_2 = [] 962 | tmp1_idx = [] 963 | tmp2_idx = [] 964 | for j, m in enumerate(mask): 965 | if j == 0: 966 | continue 967 | if m: 968 | if label_map[label_ids[i][j]] != "X" and label_map[ 969 | label_ids[i][j]] != "</s>": 970 | temp_1.append(label_map[label_ids[i][j]]) 971 | tmp1_idx.append(label_ids[i][j]) 972 | temp_2.append(label_map[logits[i][j]]) 973 | tmp2_idx.append(logits[i][j]) 974 | else: 975 | break 976 | y_true.append(temp_1) 977 | y_pred.append(temp_2) 978 | y_true_idx.append(tmp1_idx) 979 | y_pred_idx.append(tmp2_idx) 980 | 981 | #report = classification_report(y_true, y_pred, digits=4) 982 | sentence_list = [] 983 | test_data, imgs, _,_ = processor._read_mmtsv(os.path.join(args.data_dir, "test.txt"),args.image_filename, args.path_image) 984 | for i in range(len(y_pred)): 985 | sentence =
test_data[i][0] 986 | sentence_list.append(sentence) 987 | 988 | reverse_label_map = {label: i for i, label in enumerate(label_list, 1)} 989 | acc, f1, p, r = evaluate(y_pred_idx, y_true_idx, sentence_list, reverse_label_map) 990 | logger.info("***** Test Eval results *****") 991 | print("Overall: ", p, r, f1) 992 | per_f1, per_p, per_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'POS') 993 | print("Positive: ", per_p, per_r, per_f1) 994 | loc_f1, loc_p, loc_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEU') 995 | print("Neutral: ", loc_p, loc_r, loc_f1) 996 | org_f1, org_p, org_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEG') 997 | print("Negative: ", org_p, org_r, org_f1) 998 | F_score_test = f1 999 | 1000 | if F_score_dev > max_dev_f1: 1001 | # Save a trained model and the associated configuration 1002 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 1003 | encoder_to_save = encoder.module if hasattr(encoder, 1004 | 'module') else encoder # Only save the model it-self 1005 | torch.save(model_to_save.state_dict(), output_model_file) 1006 | torch.save(encoder_to_save.state_dict(), output_encoder_file) 1007 | with open(output_config_file, 'w') as f: 1008 | f.write(model_to_save.config.to_json_string()) 1009 | label_map = {i: label for i, label in enumerate(label_list, 1)} 1010 | model_config = {"bert_model": args.bert_model, "do_lower": args.do_lower_case, 1011 | "max_seq_length": args.max_seq_length, "num_labels": len(label_list) + 1, 1012 | "label_map": label_map} 1013 | json.dump(model_config, open(os.path.join(args.output_dir, "model_config.json"), "w")) 1014 | max_dev_f1 = F_score_dev 1015 | best_dev_epoch = train_idx 1016 | if F_score_test > max_test_f1: 1017 | max_test_f1 = F_score_test 1018 | best_test_epoch = train_idx 1019 | 1020 | print("**************************************************") 
        print("The best epoch on the dev set: ", best_dev_epoch)
        print("The best Micro-F1 score on the dev set: ", max_dev_f1)
        print("The best epoch on the test set: ", best_test_epoch)
        print("The best Micro-F1 score on the test set: ", max_test_f1)
        print('\n')

    # Rebuild the model from the saved config and reload the checkpoint selected on the dev set.
    config = BertConfig(output_config_file)
    if args.mm_model == 'MTCCMBert':
        model = MTCCMBertForMMTokenClassificationCRF(config, args.use_roberta, layer_num1=args.layer_num1, layer_num2=args.layer_num2,
                                                     layer_num3=args.layer_num3, num_labels=num_labels, auxnum_labels=auxnum_labels)
    else:
        # Fail fast: continuing would raise a NameError on `model` below.
        raise ValueError("Unknown mm_model '%s'; please define your MNER model" % args.mm_model)

    model.load_state_dict(torch.load(output_model_file))
    model.to(device)
    encoder_state_dict = torch.load(output_encoder_file)
    encoder.load_state_dict(encoder_state_dict)
    encoder.to(device)

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_test_examples(args.data_dir, args.image_filename, args.path_image)
        eval_features = convert_mm_examples_to_features(
            eval_examples, label_list, auxlabel_list, args.max_seq_length, tokenizer, args.crop_size, args.path_image)
        logger.info("***** Running Test Evaluation with the Best Model on the Dev Set *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_added_input_mask = torch.tensor([f.added_input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_img_feats = torch.stack([f.img_feat for f in eval_features])
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        all_auxlabel_ids = torch.tensor([f.auxlabel_id for f in eval_features], dtype=torch.long)
        all_imagelabel = torch.tensor([f.imagelabel for f in eval_features], dtype=torch.float)

        eval_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, all_img_feats,
                                  all_label_ids, all_auxlabel_ids, all_imagelabel)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
        model.eval()
        encoder.eval()
        y_true = []
        y_pred = []
        y_true_idx = []
        y_pred_idx = []
        label_map = {i: label for i, label in enumerate(label_list, 1)}
        label_map[0] = "PAD"
        # Convert the CRF transition matrix once, not on every batch; as_tensor also accepts an existing tensor.
        trans_matrix = torch.as_tensor(trans_matrix).to(device)
        for input_ids, input_mask, added_input_mask, segment_ids, img_feats, label_ids, auxlabel_ids, imagelabel in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            added_input_mask = added_input_mask.to(device)
            segment_ids = segment_ids.to(device)
            img_feats = img_feats.to(device)
            label_ids = label_ids.to(device)
            auxlabel_ids = auxlabel_ids.to(device)

            with torch.no_grad():
                imgs_f, img_mean, img_att = encoder(img_feats)
                predicted_label_seq_ids = model(input_ids, segment_ids, input_mask, added_input_mask, img_att,
                                                trans_matrix, imagelabel, args.alpha, args.beta)

            logits = predicted_label_seq_ids
            label_ids = label_ids.to('cpu').numpy()
            input_mask = input_mask.to('cpu').numpy()
            for i, mask in enumerate(input_mask):
                temp_1 = []
                temp_2 = []
                tmp1_idx = []
                tmp2_idx = []
                for j, m in enumerate(mask):
                    if j == 0:
                        continue  # skip the sentence-start token
                    if m:
                        # keep only positions whose gold label is a real tag (sub-word pieces are labelled "X")
                        if label_map[label_ids[i][j]] != "X" and label_map[label_ids[i][j]] != "":
                            temp_1.append(label_map[label_ids[i][j]])
                            tmp1_idx.append(label_ids[i][j])
                            temp_2.append(label_map[logits[i][j]])
                            tmp2_idx.append(logits[i][j])
                    else:
                        break  # padding reached
                y_true.append(temp_1)
                y_pred.append(temp_2)
                y_true_idx.append(tmp1_idx)
                y_pred_idx.append(tmp2_idx)

        sentence_list = []
        test_data, imgs, _, _ = processor._read_mmtsv(os.path.join(args.data_dir, "test.txt"), args.image_filename, args.path_image)
        output_pred_file = os.path.join(args.output_dir, "mtmner_pred.txt")
        # Each record: image id, tokens, predicted tags, gold tags, then a blank line.
        with open(output_pred_file, 'w', encoding='UTF-8') as fout:
            for i in range(len(y_pred)):
                sentence = test_data[i][0]
                sentence_list.append(sentence)
                fout.write(imgs[i] + '\n')
                fout.write(' '.join(sentence) + '\n')
                fout.write(' '.join(y_pred[i]) + '\n')
                fout.write(' '.join(y_true[i]) + '\n' + '\n')
        logger.info("***** Test Eval results *****")

        reverse_label_map = {label: i for i, label in enumerate(label_list, 1)}
        acc, f1, p, r = evaluate(y_pred_idx, y_true_idx, sentence_list, reverse_label_map)
        print("Overall: ", p, r, f1)
        pos_f1, pos_p, pos_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'POS')
        print("Positive: ", pos_p, pos_r, pos_f1)
        neu_f1, neu_p, neu_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEU')
        print("Neutral: ", neu_p, neu_r, neu_f1)
        neg_f1, neg_p, neg_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEG')
        print("Negative: ", neg_p, neg_r, neg_f1)

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            # Each line: precision recall F1
            writer.write("Overall: " + str(p) + ' ' + str(r) + ' ' + str(f1) + '\n')
            writer.write("Positive: " + str(pos_p) + ' ' + str(pos_r) + ' ' + str(pos_f1) + '\n')
            writer.write("Neutral: " + str(neu_p) + ' ' + str(neu_r) + ' ' + str(neu_f1) + '\n')
            writer.write("Negative: " + str(neg_p) + ' ' + str(neg_r) + ' ' + str(neg_f1) + '\n')


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/run_cmmt_crf.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
for i in 'twitter2015' 'twitter2017'
do
    echo 'run_cmmt_crf.py'
    echo "${i}"
    PYTHONIOENCODING=utf-8 CUDA_VISIBLE_DEVICES=0 python run_cmmt_crf.py --task_name="${i}"
done
--------------------------------------------------------------------------------
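The inner loop of the test evaluation in `run_cmmt_crf.py` skips position 0, stops at the first padded position, and drops tokens whose gold label is `X` (sub-word continuation pieces), so gold and predicted tags stay aligned word by word. Below is a minimal standalone sketch of that alignment step. The function name `align_predictions`, its `skip_labels` parameter and the toy label set (`O`, `B-POS`, `I-POS`, `X`) are hypothetical illustrations rather than the repository's API, and the sketch assumes predictions arrive as nested lists of label ids, which is how a CRF decoder typically returns them. Other special labels that the script also skips can be added to `skip_labels`.

```python
from typing import Dict, List, Sequence, Tuple


def align_predictions(
    label_ids: Sequence[Sequence[int]],
    pred_ids: Sequence[Sequence[int]],
    input_mask: Sequence[Sequence[int]],
    label_map: Dict[int, str],
    skip_labels: Tuple[str, ...] = ("X",),
) -> Tuple[List[List[str]], List[List[str]]]:
    """Return (y_true, y_pred) word-level tag sequences, one list per sentence (hypothetical helper)."""
    y_true, y_pred = [], []
    for gold_row, pred_row, mask_row in zip(label_ids, pred_ids, input_mask):
        gold_tags, pred_tags = [], []
        for j, m in enumerate(mask_row):
            if j == 0:
                continue  # position 0 is the sentence-start special token
            if not m:
                break  # padding starts here; ignore the rest of the row
            gold = label_map[gold_row[j]]
            if gold in skip_labels:
                continue  # sub-word continuation pieces carry no tag of their own
            gold_tags.append(gold)
            pred_tags.append(label_map[pred_row[j]])
        y_true.append(gold_tags)
        y_pred.append(pred_tags)
    return y_true, y_pred


# Toy label set (assumption): ids start at 1, and 0 is reserved for padding.
label_map = {i: label for i, label in enumerate(["O", "B-POS", "I-POS", "X"], 1)}
label_map[0] = "PAD"

gold = [[1, 2, 4, 3, 1, 0, 0]]  # start token, B-POS, X, I-POS, O, pad, pad
pred = [[1, 2, 2, 3, 2, 0, 0]]
mask = [[1, 1, 1, 1, 1, 0, 0]]
y_true, y_pred = align_predictions(gold, pred, mask, label_map)
print(y_true)  # [['B-POS', 'I-POS', 'O']]
print(y_pred)  # [['B-POS', 'I-POS', 'B-POS']]
```

The prediction at an `X` position is discarded along with the gold label, so only the first sub-word of each word is scored.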
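The script then scores the aligned sequences overall and per sentiment (`POS`, `NEU`, `NEG`) through `evaluate` and `evaluate_each_class` from `ner_evaluate.py`, whose code is not shown here. As a hedged illustration only, the sketch below computes exact-match span precision, recall and F1 from BIO tag sequences, a common convention for end-to-end aspect-sentiment extraction. It assumes tags of the form `B-<sentiment>`, `I-<sentiment>` and `O`; the helpers `extract_spans` and `span_prf` are hypothetical, and the repository's own scorer may treat edge cases (such as an `I-` tag with no preceding `B-`) differently.

```python
from typing import List, Optional, Sequence, Set, Tuple

# (sentence index, start token, end token exclusive, sentiment)
Span = Tuple[int, int, int, str]


def extract_spans(sent_idx: int, tags: Sequence[str]) -> Set[Span]:
    """Collect B-/I- spans; an I- tag that does not continue an open span of the same sentiment is ignored."""
    spans: Set[Span] = set()
    start: Optional[int] = None
    kind: Optional[str] = None
    for k, tag in enumerate(list(tags) + ["O"]):  # trailing "O" closes a span that reaches the end
        if start is not None and tag == f"I-{kind}":
            continue  # the open span extends over this token
        if start is not None:
            spans.add((sent_idx, start, k, kind))
        if tag.startswith("B-"):
            start, kind = k, tag[2:]
        else:
            start, kind = None, None
    return spans


def span_prf(y_true: List[List[str]], y_pred: List[List[str]],
             sentiment: Optional[str] = None) -> Tuple[float, float, float]:
    """Exact-match span precision, recall and F1, optionally restricted to one sentiment."""
    gold: Set[Span] = set()
    pred: Set[Span] = set()
    for i, (t, p) in enumerate(zip(y_true, y_pred)):
        gold |= extract_spans(i, t)
        pred |= extract_spans(i, p)
    if sentiment is not None:
        gold = {s for s in gold if s[3] == sentiment}
        pred = {s for s in pred if s[3] == sentiment}
    correct = len(gold & pred)  # boundaries and sentiment must both match
    precision = correct / len(pred) if pred else 0.0
    recall = correct / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


y_true = [["B-POS", "I-POS", "O", "B-NEG"]]
y_pred = [["B-POS", "I-POS", "O", "B-POS"]]
print("Overall: ", *span_prf(y_true, y_pred))          # 0.5 0.5 0.5
print("Positive: ", *span_prf(y_true, y_pred, "POS"))  # 0.5 1.0 0.6666666666666666
```

Under this convention a span with the right boundaries but the wrong sentiment counts as both a false positive and a false negative, which is why the overall score above is 0.5 even though every aspect boundary was found.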