├── ANP_data
│   └── image label_folder
├── README.md
├── data
│   ├── politic_data
│   │   ├── a.txt
│   │   ├── test_opinion.txt
│   │   ├── train_opinion.txt
│   │   └── valid_opinion.txt
│   ├── twitter2015
│   │   ├── test.txt
│   │   ├── train.txt
│   │   └── valid.txt
│   └── twitter2017
│       ├── test.txt
│       ├── train.txt
│       └── valid.txt
├── image_data
│   └── image data folder.txt
├── log files
│   ├── twitter2015_output
│   │   ├── config.json
│   │   ├── eval_results.txt
│   │   └── model_config.json
│   ├── twitter2017_output
│   │   ├── config.json
│   │   ├── eval_results.txt
│   │   └── model_config.json
│   └── twitterpolitic_output
│       ├── config.json
│       ├── eval_results.txt
│       └── model_config.json
├── model
│   ├── resnet
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   ├── resnet_model_folder.txt
│   │   └── resnet_utils.py
│   └── roberta-base-cased
│       └── Roberta_model_folder.txt
├── my_bert
│   ├── __init__.py
│   ├── __main__.py
│   ├── bichannel_modeling.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── file_utils.py
│   ├── mner_modeling.py
│   ├── optimization.py
│   └── tokenization.py
├── ner_evaluate.py
├── run_cmmt_crf.py
└── run_cmmt_crf.sh
/ANP_data/image label_folder:
--------------------------------------------------------------------------------
1 | Download the image label file via this link (https://drive.google.com/drive/folders/1i-EiyLS0RwsOw8cQIMMAKRRqy9VgXOhT?usp=sharing), and then put the associated image label files into the folder "./ANP_data/"
2 |
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-Modal Multitask Transformer for End-to-End Multimodal Aspect-Based Sentiment Analysis
2 |
3 | #### Author: Li YANG, yang0666@e.ntu.edu.sg
4 |
5 | #### The Corresponding Paper:
6 | ##### Cross-modal multitask transformer for end-to-end multimodal aspect-based sentiment analysis
7 | ##### [https://www.sciencedirect.com/science/article/pii/S0306457322001479](https://www.sciencedirect.com/science/article/pii/S0306457322001479)
8 |
9 | ##### The framework of the CMMT model:
10 | ![alt text]
11 |
12 |
13 |
14 | ## Data
15 | - The preprocessed CoNLL format files are provided in this repo. For each tweet, the first line is its image id, and the following lines are its textual contents.
16 | - Step 1: Download each tweet's associated images via this link (https://drive.google.com/file/d/1PpvvncnQkgDNeBMKVgG2zFYuRhbL873g/view), and then put the associated images into the folders "./image_data/twitter2015/" and "./image_data/twitter2017/"
17 | - The politician dataset can be obtained via: https://drive.google.com/file/d/1oa029MLk8I_J99pxBs7X9RaIbUHhhTNG/view?usp=sharing
18 | - Step 2: Download the image label file via this link (https://drive.google.com/drive/folders/1UaeSYJQCQzszRmBWdhA11LqXnWousn4G?usp=sharing), and then put the associated image label files into the folder "./ANP_data/"
19 | - Step 3: Download the pre-trained ResNet-152 via this link (https://download.pytorch.org/models/resnet152-b121ed2d.pth), and put the pre-trained ResNet-152 model under the folder "./model/resnet/"
20 |
21 | - Step 4: Download the pre-trained roberta-base-cased model from huggingface and put it under the folder "./model/roberta-base-cased/"
22 | - Step 5: ANP information can be downloaded via https://drive.google.com/drive/folders/1UaeSYJQCQzszRmBWdhA11LqXnWousn4G?usp=share_link
23 |
24 | ## Requirement
25 | * PyTorch 1.0.0
26 | * Python 3.7
27 | * pytorch-crf 0.7.2
28 |
29 | ## Code Usage
30 |
31 | ### Training for CMMT
32 | - This is the training code: it tunes parameters on the dev set and tests on the test set. Note that you can change "CUDA_VISIBLE_DEVICES=2" based on your available GPUs.
33 |
34 | ```sh
35 | sh run_cmmt_crf.sh
36 | ```
37 |
38 | - We show our running logs on twitter-2015, twitter-2017 and political twitter in the folder "log files". Note that the results are slightly lower than those reported in our paper, since the experiments were run on different servers.
39 |
40 |
41 | ## Acknowledgements
42 | - By using these datasets, you confirm that you have read and accepted the copyright terms set by Twitter and the dataset providers.
43 | - Most of the code is based on the code provided by huggingface: https://github.com/huggingface/transformers.
44 |
45 | ## Citation Information:
46 | Yang, L., Na, J. C., & Yu, J. (2022). Cross-modal multitask transformer for end-to-end multimodal aspect-based sentiment analysis. Information Processing & Management, 59(5), 103038.
47 |
--------------------------------------------------------------------------------
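As a quick sanity check for the README's Steps 1-5, here is a minimal sketch (not part of the original repo; the script is hypothetical, the paths are taken from the README):

```python
# check_layout.py -- hypothetical helper, not part of the repo.
# Verifies that the artifacts from the README's Steps 1-5 are in place.
import os

expected = [
    "image_data/twitter2015",                # Step 1: tweet images
    "image_data/twitter2017",                # Step 1: tweet images
    "ANP_data",                              # Steps 2/5: image label (ANP) files
    "model/resnet/resnet152-b121ed2d.pth",   # Step 3: pre-trained ResNet-152
    "model/roberta-base-cased",              # Step 4: pre-trained RoBERTa files
]

for path in expected:
    status = "ok" if os.path.exists(path) else "MISSING"
    print(f"{status:7s} {path}")
```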
/data/politic_data/a.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/image_data/image data folder.txt:
--------------------------------------------------------------------------------
1 | Download each tweet's associated images via this link (https://drive.google.com/file/d/1PpvvncnQkgDNeBMKVgG2zFYuRhbL873g/view), and then put the associated images into the folders "./image_data/twitter2015/" and "./image_data/twitter2017/"
2 |
--------------------------------------------------------------------------------
/log files/twitter2015_output/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "RobertaForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "bos_token_id": 0,
7 | "eos_token_id": 2,
8 | "hidden_act": "gelu",
9 | "hidden_dropout_prob": 0.1,
10 | "hidden_size": 768,
11 | "initializer_range": 0.02,
12 | "intermediate_size": 3072,
13 | "layer_norm_eps": 1e-05,
14 | "max_position_embeddings": 514,
15 | "model_type": "roberta",
16 | "num_attention_heads": 12,
17 | "num_hidden_layers": 12,
18 | "pad_token_id": 1,
19 | "type_vocab_size": 1,
20 | "vocab_size": 50265
21 | }
22 |
--------------------------------------------------------------------------------
/log files/twitter2015_output/eval_results.txt:
--------------------------------------------------------------------------------
1 | Overall: 0.6457765667574932 0.6856316297010607 0.6651075771749297
2 | Positive: 0.6655290102389079 0.6151419558359621 0.6393442622950819
3 | Neutral: 0.65 0.7495881383855024 0.6962509563886764
4 | Negative: 0.5648148148148148 0.5398230088495575 0.5520361990950227
5 |
--------------------------------------------------------------------------------
/log files/twitter2015_output/model_config.json:
--------------------------------------------------------------------------------
1 | {"bert_model": "/mnt/nfs-storage-titan/yoli_projects/roberta-base-cased", "do_lower": false, "max_seq_length": 128, "num_labels": 11, "label_map": {"1": "O", "2": "B-NEU", "3": "I-NEU", "4": "B-POS", "5": "I-POS", "6": "B-NEG", "7": "I-NEG", "8": "X", "9": "", "10": ""}}
--------------------------------------------------------------------------------
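To make the `label_map` above concrete: tags follow a BIO scheme per sentiment (`B-POS`/`I-POS`, etc.), with `X` marking subword pieces. A hedged sketch of how such a tag sequence decodes into (aspect, sentiment) spans follows; the `decode_spans` helper and the sample sentence are illustrative, not repo code.

```python
# Hypothetical sketch: decode BIO-sentiment tags (as in label_map above)
# into (aspect span, sentiment) pairs.
def decode_spans(tokens, tags):
    """Collect (aspect, sentiment) pairs from a BIO-tagged token sequence."""
    spans, current, sentiment = [], [], None
    for token, tag in zip(tokens, tags):
        if tag.startswith("B-"):                 # a new aspect span begins
            if current:
                spans.append((" ".join(current), sentiment))
            current, sentiment = [token], tag[2:]
        elif (tag.startswith("I-") or tag == "X") and current:
            current.append(token)                # continue the open span
        else:                                    # "O" closes any open span
            if current:
                spans.append((" ".join(current), sentiment))
            current, sentiment = [], None
    if current:
        spans.append((" ".join(current), sentiment))
    return spans

print(decode_spans(["I", "love", "New", "York"],
                   ["O", "O", "B-POS", "I-POS"]))   # [('New York', 'POS')]
```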
/log files/twitter2017_output/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "RobertaForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "bos_token_id": 0,
7 | "eos_token_id": 2,
8 | "hidden_act": "gelu",
9 | "hidden_dropout_prob": 0.1,
10 | "hidden_size": 768,
11 | "initializer_range": 0.02,
12 | "intermediate_size": 3072,
13 | "layer_norm_eps": 1e-05,
14 | "max_position_embeddings": 514,
15 | "model_type": "roberta",
16 | "num_attention_heads": 12,
17 | "num_hidden_layers": 12,
18 | "pad_token_id": 1,
19 | "type_vocab_size": 1,
20 | "vocab_size": 50265
21 | }
22 |
--------------------------------------------------------------------------------
/log files/twitter2017_output/eval_results.txt:
--------------------------------------------------------------------------------
1 | Overall: 0.6758675078864353 0.6944894651539708 0.6850519584332534
2 | Positive: 0.6893203883495146 0.7200811359026369 0.7043650793650794
3 | Neutral: 0.678082191780822 0.6910994764397905 0.6845289541918756
4 | Negative: 0.6272189349112426 0.6309523809523809 0.629080118694362
5 |
--------------------------------------------------------------------------------
/log files/twitter2017_output/model_config.json:
--------------------------------------------------------------------------------
1 | {"bert_model": "/mnt/nfs-storage-titan/yoli_projects/roberta-base-cased", "do_lower": false, "max_seq_length": 128, "num_labels": 11, "label_map": {"1": "O", "2": "B-NEU", "3": "I-NEU", "4": "B-POS", "5": "I-POS", "6": "B-NEG", "7": "I-NEG", "8": "X", "9": "", "10": ""}}
--------------------------------------------------------------------------------
/log files/twitterpolitic_output/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "RobertaForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "bos_token_id": 0,
7 | "eos_token_id": 2,
8 | "hidden_act": "gelu",
9 | "hidden_dropout_prob": 0.1,
10 | "hidden_size": 768,
11 | "initializer_range": 0.02,
12 | "intermediate_size": 3072,
13 | "layer_norm_eps": 1e-05,
14 | "max_position_embeddings": 514,
15 | "model_type": "roberta",
16 | "num_attention_heads": 12,
17 | "num_hidden_layers": 12,
18 | "pad_token_id": 1,
19 | "type_vocab_size": 1,
20 | "vocab_size": 50265
21 | }
22 |
--------------------------------------------------------------------------------
/log files/twitterpolitic_output/eval_results.txt:
--------------------------------------------------------------------------------
1 | Overall: 0.6526315789473685 0.657243816254417 0.6549295774647887
2 | Positive: 0.6197916666666666 0.6761363636363636 0.6467391304347826
3 | Neutral: 0.6260162601626016 0.6277173913043478 0.6268656716417911
4 | Negative: 0.7074829931972789 0.6819672131147541 0.6944908180300502
5 |
--------------------------------------------------------------------------------
/log files/twitterpolitic_output/model_config.json:
--------------------------------------------------------------------------------
1 | {"bert_model": "/mnt/nfs-storage-titan/yoli_projects/roberta-base-cased", "do_lower": false, "max_seq_length": 128, "num_labels": 11, "label_map": {"1": "O", "2": "B-NEU", "3": "I-NEU", "4": "B-POS", "5": "I-POS", "6": "B-NEG", "7": "I-NEG", "8": "X", "9": "", "10": ""}}
--------------------------------------------------------------------------------
/model/resnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangli-hub/CMMT-Code/06d38658b0c0e3c585ca8fa17d354c353c676e84/model/resnet/__init__.py
--------------------------------------------------------------------------------
/model/resnet/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152']
8 |
9 |
10 | model_urls = {
11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | }
17 |
18 |
19 | def conv3x3(in_planes, out_planes, stride=1):
20 | "3x3 convolution with padding"
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=1, bias=False)
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None):
29 | super(BasicBlock, self).__init__()
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = nn.BatchNorm2d(planes)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = nn.BatchNorm2d(planes)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | class Bottleneck(nn.Module):
58 | expansion = 4
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None):
61 | super(Bottleneck, self).__init__()
62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
63 | self.bn1 = nn.BatchNorm2d(planes)
64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
65 | padding=1, bias=False)
66 | self.bn2 = nn.BatchNorm2d(planes)
67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
68 | self.bn3 = nn.BatchNorm2d(planes * 4)
69 | self.relu = nn.ReLU(inplace=True)
70 | self.downsample = downsample
71 | self.stride = stride
72 |
73 | def forward(self, x):
74 | residual = x
75 |
76 | out = self.conv1(x)
77 | out = self.bn1(out)
78 | out = self.relu(out)
79 |
80 | out = self.conv2(out)
81 | out = self.bn2(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv3(out)
85 | out = self.bn3(out)
86 |
87 | if self.downsample is not None:
88 | residual = self.downsample(x)
89 |
90 | out += residual
91 | out = self.relu(out)
92 |
93 | return out
94 |
95 |
96 | class ResNet(nn.Module):
97 |
98 | def __init__(self, block, layers, num_classes=1000):
99 | self.inplanes = 64
100 | super(ResNet, self).__init__()
101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
102 | bias=False)
103 | self.bn1 = nn.BatchNorm2d(64)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
106 | self.layer1 = self._make_layer(block, 64, layers[0])
107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
110 | self.avgpool = nn.AvgPool2d(7, stride=1)
111 | self.fc = nn.Linear(512 * block.expansion, num_classes)
112 |
113 | for m in self.modules():
114 | if isinstance(m, nn.Conv2d):
115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
116 | m.weight.data.normal_(0, math.sqrt(2. / n))
117 | elif isinstance(m, nn.BatchNorm2d):
118 | m.weight.data.fill_(1)
119 | m.bias.data.zero_()
120 |
121 | def _make_layer(self, block, planes, blocks, stride=1):
122 | downsample = None
123 | if stride != 1 or self.inplanes != planes * block.expansion:
124 | downsample = nn.Sequential(
125 | nn.Conv2d(self.inplanes, planes * block.expansion,
126 | kernel_size=1, stride=stride, bias=False),
127 | nn.BatchNorm2d(planes * block.expansion),
128 | )
129 |
130 | layers = []
131 | layers.append(block(self.inplanes, planes, stride, downsample))
132 | self.inplanes = planes * block.expansion
133 | for i in range(1, blocks):
134 | layers.append(block(self.inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | x = self.conv1(x)
140 | x = self.bn1(x)
141 | x = self.relu(x)
142 | x = self.maxpool(x)
143 |
144 | x = self.layer1(x)
145 | x = self.layer2(x)
146 | x = self.layer3(x)
147 | x = self.layer4(x)
148 |
149 | x = self.avgpool(x)
150 | x = x.view(x.size(0), -1)
151 | x = self.fc(x)
152 |
153 | return x
154 |
155 |
156 | def resnet18(pretrained=False, **kwargs):
157 | """Constructs a ResNet-18 model.
158 |
159 | Args:
160 | pretrained (bool): If True, returns a model pre-trained on ImageNet
161 | """
162 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
163 | if pretrained:
164 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
165 | return model
166 |
167 |
168 | def resnet34(pretrained=False, **kwargs):
169 | """Constructs a ResNet-34 model.
170 |
171 | Args:
172 | pretrained (bool): If True, returns a model pre-trained on ImageNet
173 | """
174 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
175 | if pretrained:
176 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
177 | return model
178 |
179 |
180 | def resnet50(pretrained=False, **kwargs):
181 | """Constructs a ResNet-50 model.
182 |
183 | Args:
184 | pretrained (bool): If True, returns a model pre-trained on ImageNet
185 | """
186 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
187 | if pretrained:
188 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
189 | return model
190 |
191 |
192 | def resnet101(pretrained=False, **kwargs):
193 | """Constructs a ResNet-101 model.
194 |
195 | Args:
196 | pretrained (bool): If True, returns a model pre-trained on ImageNet
197 | """
198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
199 | if pretrained:
200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
201 | return model
202 |
203 |
204 | def resnet152(pretrained=False, **kwargs):
205 | """Constructs a ResNet-152 model.
206 |
207 | Args:
208 | pretrained (bool): If True, returns a model pre-trained on ImageNet
209 | """
210 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
211 | if pretrained:
212 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
213 | return model
214 |
--------------------------------------------------------------------------------
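Since the README places the ResNet-152 checkpoint under "./model/resnet/" rather than relying on the download cache, a minimal local-loading sketch might look like this (assumption: the script is run from the repo root):

```python
# Hypothetical sketch: build ResNet-152 and load the locally downloaded
# checkpoint (./model/resnet/resnet152-b121ed2d.pth per the README)
# instead of fetching it through model_zoo.
import torch
from model.resnet.resnet import resnet152

net = resnet152(pretrained=False)
state_dict = torch.load("./model/resnet/resnet152-b121ed2d.pth", map_location="cpu")
net.load_state_dict(state_dict)
net.eval()
```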
/model/resnet/resnet_model_folder.txt:
--------------------------------------------------------------------------------
1 |
2 | Download the pre-trained ResNet-152 via this link (https://download.pytorch.org/models/resnet152-b121ed2d.pth), and put the pre-trained ResNet-152 model under the folder "./model/resnet/"
3 |
--------------------------------------------------------------------------------
/model/resnet/resnet_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 |
6 | class myResnet(nn.Module):
7 | def __init__(self, resnet, if_fine_tune, device):
8 | super(myResnet, self).__init__()
9 | self.resnet = resnet
10 | self.if_fine_tune = if_fine_tune
11 | self.device = device
12 |
13 | def forward(self, x, att_size=7):
14 | x = self.resnet.conv1(x)
15 | x = self.resnet.bn1(x)
16 | x = self.resnet.relu(x)
17 | x = self.resnet.maxpool(x)
18 |
19 | x = self.resnet.layer1(x)
20 | x = self.resnet.layer2(x)
21 | x = self.resnet.layer3(x)
22 | x = self.resnet.layer4(x)
23 |
24 | fc = x.mean(3).mean(2)
25 |         att = F.adaptive_avg_pool2d(x, [att_size, att_size])
26 |
27 | x = self.resnet.avgpool(x)
28 | x = x.view(x.size(0), -1)
29 |
30 | if not self.if_fine_tune:
31 |
32 |             x = Variable(x.data)
33 | fc = Variable(fc.data)
34 | att = Variable(att.data)
35 |
36 | return x, fc, att
37 |
38 |
39 |
--------------------------------------------------------------------------------
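A hedged usage sketch for `myResnet`: its forward pass returns the pooled vector `x`, a mean-pooled feature `fc`, and an `att_size` x `att_size` spatial map `att` that can serve as visual context. The dummy batch below is an assumption.

```python
# Hypothetical sketch: extract visual features with myResnet.
import torch
from model.resnet.resnet import resnet152
from model.resnet.resnet_utils import myResnet

device = torch.device("cpu")
encoder = myResnet(resnet152(pretrained=False), if_fine_tune=False, device=device)

images = torch.randn(2, 3, 224, 224)  # dummy batch of 224x224 RGB images
x, fc, att = encoder(images)
print(x.shape)    # torch.Size([2, 2048])        pooled vector
print(fc.shape)   # torch.Size([2, 2048])        mean-pooled feature
print(att.shape)  # torch.Size([2, 2048, 7, 7])  spatial attention map
```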
/model/roberta-base-cased/Roberta_model_folder.txt:
--------------------------------------------------------------------------------
1 | Download the pre-trained roberta-base-cased model and put it under the folder "./model/roberta-base-cased/"
2 |
--------------------------------------------------------------------------------
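One way to populate this folder is sketched below. Note an assumption: the repo folder is named `roberta-base-cased`, but the standard checkpoint on the Hugging Face hub is `roberta-base` (RoBERTa's tokenizer is case-sensitive by default), so that hub id is a guess, not something the repo specifies.

```python
# Hypothetical sketch: fetch RoBERTa weights into the folder the repo expects.
# The hub id "roberta-base" is an assumption; the repo only names the folder.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="roberta-base", local_dir="./model/roberta-base-cased")
```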
/my_bert/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.4.0"
2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
3 | from .optimization import BertAdam
4 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
--------------------------------------------------------------------------------
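A hedged sketch of what these exports are for: the training scripts pull the tokenizer, the warmup-aware `BertAdam` optimizer, and the cache path from this vendored package. The model name and sentence below are illustrative.

```python
# Hypothetical sketch: using the vendored my_bert exports (BertAdam, also
# exported above, is the warmup-aware optimizer built over model parameters).
from my_bert import BertTokenizer, PYTORCH_PRETRAINED_BERT_CACHE

tokenizer = BertTokenizer.from_pretrained("bert-base-cased", do_lower_case=False)
print(tokenizer.tokenize("Multimodal aspect-based sentiment analysis"))
print(PYTORCH_PRETRAINED_BERT_CACHE)  # default cache dir for downloaded weights
```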
/my_bert/__main__.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | def main():
3 | import sys
4 | try:
5 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
6 | except ModuleNotFoundError:
7 |         print("pytorch_pretrained_bert can only be used from the command line to convert TensorFlow models to PyTorch. "
8 | "In that case, it requires TensorFlow to be installed. Please see "
9 | "https://www.tensorflow.org/install/ for installation instructions.")
10 | raise
11 |
12 | if len(sys.argv) != 5:
13 | # pylint: disable=line-too-long
14 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
15 | else:
16 | PYTORCH_DUMP_OUTPUT = sys.argv.pop()
17 | TF_CONFIG = sys.argv.pop()
18 | TF_CHECKPOINT = sys.argv.pop()
19 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
20 |
21 | if __name__ == '__main__':
22 | main()
23 |
--------------------------------------------------------------------------------
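The same converter can be called programmatically; a hedged sketch follows (all three paths are placeholders, mirroring the positional arguments parsed above):

```python
# Hypothetical sketch: calling the TF->PyTorch converter directly rather than
# through `python -m my_bert ...`; all three paths are placeholders.
from my_bert.convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

convert_tf_checkpoint_to_pytorch(
    "bert_model.ckpt",    # TF_CHECKPOINT
    "bert_config.json",   # TF_CONFIG
    "pytorch_model.bin",  # PYTORCH_DUMP_OUTPUT
)
```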
/my_bert/bichannel_modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch BERT model."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | import copy
24 | import json
25 | import math
26 | import logging
27 | import tarfile
28 | import tempfile
29 | import shutil
30 |
31 | import torch
32 | from torch import nn
33 | from torch.nn import CrossEntropyLoss
34 |
35 | from .file_utils import cached_path
36 |
37 | logger = logging.getLogger(__name__)
38 |
39 | PRETRAINED_MODEL_ARCHIVE_MAP = {
40 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
41 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
42 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
43 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
44 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
45 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
46 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
47 | }
48 | CONFIG_NAME = 'bert_config.json'
49 | WEIGHTS_NAME = 'pytorch_model.bin'
50 |
51 | def gelu(x):
52 | """Implementation of the gelu activation function.
53 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
54 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
55 | """
56 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
57 |
58 |
59 | def swish(x):
60 | return x * torch.sigmoid(x)
61 |
62 |
63 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
64 |
65 |
66 | class BertConfig(object):
67 | """Configuration class to store the configuration of a `BertModel`.
68 | """
69 | def __init__(self,
70 | vocab_size_or_config_json_file,
71 | hidden_size=768,
72 | num_hidden_layers=12,
73 | num_attention_heads=12,
74 | intermediate_size=3072,
75 | hidden_act="gelu",
76 | hidden_dropout_prob=0.1,
77 | attention_probs_dropout_prob=0.1,
78 | max_position_embeddings=512,
79 | type_vocab_size=2,
80 | initializer_range=0.02):
81 | """Constructs BertConfig.
82 |
83 | Args:
84 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
85 | hidden_size: Size of the encoder layers and the pooler layer.
86 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
87 | num_attention_heads: Number of attention heads for each attention layer in
88 | the Transformer encoder.
89 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
90 | layer in the Transformer encoder.
91 | hidden_act: The non-linear activation function (function or string) in the
92 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
93 |             hidden_dropout_prob: The dropout probability for all fully connected
94 | layers in the embeddings, encoder, and pooler.
95 | attention_probs_dropout_prob: The dropout ratio for the attention
96 | probabilities.
97 | max_position_embeddings: The maximum sequence length that this model might
98 | ever be used with. Typically set this to something large just in case
99 | (e.g., 512 or 1024 or 2048).
100 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
101 | `BertModel`.
102 |             initializer_range: The stdev of the truncated_normal_initializer for
103 | initializing all weight matrices.
104 | """
105 | if isinstance(vocab_size_or_config_json_file, str):
106 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
107 | json_config = json.loads(reader.read())
108 | for key, value in json_config.items():
109 | self.__dict__[key] = value
110 | elif isinstance(vocab_size_or_config_json_file, int):
111 | self.vocab_size = vocab_size_or_config_json_file
112 | self.hidden_size = hidden_size
113 | self.num_hidden_layers = num_hidden_layers
114 | self.num_attention_heads = num_attention_heads
115 | self.hidden_act = hidden_act
116 | self.intermediate_size = intermediate_size
117 | self.hidden_dropout_prob = hidden_dropout_prob
118 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
119 | self.max_position_embeddings = max_position_embeddings
120 | self.type_vocab_size = type_vocab_size
121 | self.initializer_range = initializer_range
122 | else:
123 | raise ValueError("First argument must be either a vocabulary size (int)"
124 | "or the path to a pretrained model config file (str)")
125 |
126 | @classmethod
127 | def from_dict(cls, json_object):
128 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
129 | config = BertConfig(vocab_size_or_config_json_file=-1)
130 | for key, value in json_object.items():
131 | config.__dict__[key] = value
132 | return config
133 |
134 | @classmethod
135 | def from_json_file(cls, json_file):
136 | """Constructs a `BertConfig` from a json file of parameters."""
137 | with open(json_file, "r", encoding='utf-8') as reader:
138 | text = reader.read()
139 | return cls.from_dict(json.loads(text))
140 |
141 | def __repr__(self):
142 | return str(self.to_json_string())
143 |
144 | def to_dict(self):
145 | """Serializes this instance to a Python dictionary."""
146 | output = copy.deepcopy(self.__dict__)
147 | return output
148 |
149 | def to_json_string(self):
150 | """Serializes this instance to a JSON string."""
151 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
152 |
153 | try:
154 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
155 | except ImportError:
156 | print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
157 | class BertLayerNorm(nn.Module):
158 | def __init__(self, hidden_size, eps=1e-12):
159 | """Construct a layernorm module in the TF style (epsilon inside the square root).
160 | """
161 | super(BertLayerNorm, self).__init__()
162 | self.weight = nn.Parameter(torch.ones(hidden_size))
163 | self.bias = nn.Parameter(torch.zeros(hidden_size))
164 | self.variance_epsilon = eps
165 |
166 | def forward(self, x):
167 | u = x.mean(-1, keepdim=True)
168 | s = (x - u).pow(2).mean(-1, keepdim=True)
169 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
170 | return self.weight * x + self.bias
171 |
172 | class BertEmbeddings(nn.Module):
173 | """Construct the embeddings from word, position and token_type embeddings.
174 | """
175 | def __init__(self, config):
176 | super(BertEmbeddings, self).__init__()
177 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
178 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
179 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
180 |
181 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
182 | # any TensorFlow checkpoint file
183 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
184 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
185 |
186 | def forward(self, input_ids, token_type_ids=None):
187 | seq_length = input_ids.size(1)
188 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
189 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
190 | if token_type_ids is None:
191 | token_type_ids = torch.zeros_like(input_ids)
192 |
193 | words_embeddings = self.word_embeddings(input_ids)
194 | position_embeddings = self.position_embeddings(position_ids)
195 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
196 |
197 | embeddings = words_embeddings + position_embeddings + token_type_embeddings
198 | embeddings = self.LayerNorm(embeddings)
199 | embeddings = self.dropout(embeddings)
200 | return embeddings
201 |
202 |
203 | class BertSelfAttention(nn.Module):
204 | def __init__(self, config):
205 | super(BertSelfAttention, self).__init__()
206 | if config.hidden_size % config.num_attention_heads != 0:
207 | raise ValueError(
208 | "The hidden size (%d) is not a multiple of the number of attention "
209 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
210 | self.num_attention_heads = config.num_attention_heads
211 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
212 | self.all_head_size = self.num_attention_heads * self.attention_head_size
213 |
214 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
215 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
216 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
217 |
218 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
219 |
220 | def transpose_for_scores(self, x):
221 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
222 | x = x.view(*new_x_shape)
223 | return x.permute(0, 2, 1, 3)
224 |
225 | def forward(self, hidden_states, attention_mask):
226 | mixed_query_layer = self.query(hidden_states)
227 | mixed_key_layer = self.key(hidden_states)
228 | mixed_value_layer = self.value(hidden_states)
229 |
230 | query_layer = self.transpose_for_scores(mixed_query_layer)
231 | key_layer = self.transpose_for_scores(mixed_key_layer)
232 | value_layer = self.transpose_for_scores(mixed_value_layer)
233 |
234 | # Take the dot product between "query" and "key" to get the raw attention scores.
235 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
236 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
237 |         # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
238 | attention_scores = attention_scores + attention_mask
239 |
240 | # Normalize the attention scores to probabilities.
241 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
242 |
243 | # This is actually dropping out entire tokens to attend to, which might
244 | # seem a bit unusual, but is taken from the original Transformer paper.
245 | attention_probs = self.dropout(attention_probs)
246 |
247 | context_layer = torch.matmul(attention_probs, value_layer)
248 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
249 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
250 | context_layer = context_layer.view(*new_context_layer_shape)
251 | return context_layer
252 |
253 | class BertCoAttention(nn.Module):
254 | def __init__(self, config):
255 | super(BertCoAttention, self).__init__()
256 | if config.hidden_size % config.num_attention_heads != 0:
257 | raise ValueError(
258 | "The hidden size (%d) is not a multiple of the number of attention "
259 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
260 | self.num_attention_heads = config.num_attention_heads
261 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
262 | self.all_head_size = self.num_attention_heads * self.attention_head_size
263 |
264 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
265 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
266 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
267 |
268 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
269 |
270 | def transpose_for_scores(self, x):
271 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
272 | x = x.view(*new_x_shape)
273 | return x.permute(0, 2, 1, 3)
274 |
275 | def forward(self, s1_hidden_states, s2_hidden_states, s2_attention_mask):
276 | mixed_query_layer = self.query(s1_hidden_states)
277 | mixed_key_layer = self.key(s2_hidden_states)
278 | mixed_value_layer = self.value(s2_hidden_states)
279 |
280 | query_layer = self.transpose_for_scores(mixed_query_layer)
281 | key_layer = self.transpose_for_scores(mixed_key_layer)
282 | value_layer = self.transpose_for_scores(mixed_value_layer)
283 |
284 | # Take the dot product between "query" and "key" to get the raw attention scores.
285 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
286 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
287 |         # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
288 | attention_scores = attention_scores + s2_attention_mask
289 |
290 | # Normalize the attention scores to probabilities.
291 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
292 |
293 | # This is actually dropping out entire tokens to attend to, which might
294 | # seem a bit unusual, but is taken from the original Transformer paper.
295 | attention_probs = self.dropout(attention_probs)
296 |
297 | context_layer = torch.matmul(attention_probs, value_layer)
298 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
299 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
300 | context_layer = context_layer.view(*new_context_layer_shape)
301 | return context_layer
302 |
303 |
304 | class BertSelfOutput(nn.Module):
305 | def __init__(self, config):
306 | super(BertSelfOutput, self).__init__()
307 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
308 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
309 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
310 |
311 | def forward(self, hidden_states, input_tensor):
312 | hidden_states = self.dense(hidden_states)
313 | hidden_states = self.dropout(hidden_states)
314 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
315 | return hidden_states
316 |
317 |
318 | class BertAttention(nn.Module):
319 | def __init__(self, config):
320 | super(BertAttention, self).__init__()
321 | self.self = BertSelfAttention(config)
322 | self.output = BertSelfOutput(config)
323 |
324 | def forward(self, input_tensor, attention_mask):
325 | self_output = self.self(input_tensor, attention_mask)
326 | attention_output = self.output(self_output, input_tensor)
327 | return attention_output
328 |
329 |
330 | class BertCrossAttention(nn.Module):
331 | def __init__(self, config):
332 | super(BertCrossAttention, self).__init__()
333 | self.self = BertCoAttention(config)
334 | self.output = BertSelfOutput(config)
335 |
336 | def forward(self, s1_input_tensor, s2_input_tensor, s2_attention_mask):
337 | self_output = self.self(s1_input_tensor, s2_input_tensor, s2_attention_mask)
338 | attention_output = self.output(self_output, s1_input_tensor)
339 | return attention_output
340 |
341 |
342 | class BertIntermediate(nn.Module):
343 | def __init__(self, config):
344 | super(BertIntermediate, self).__init__()
345 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
346 | self.intermediate_act_fn = ACT2FN[config.hidden_act] \
347 | if isinstance(config.hidden_act, str) else config.hidden_act
348 |
349 | def forward(self, hidden_states):
350 | hidden_states = self.dense(hidden_states)
351 | hidden_states = self.intermediate_act_fn(hidden_states)
352 | return hidden_states
353 |
354 |
355 | class BertOutput(nn.Module):
356 | def __init__(self, config):
357 | super(BertOutput, self).__init__()
358 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
359 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
360 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
361 |
362 | def forward(self, hidden_states, input_tensor):
363 | hidden_states = self.dense(hidden_states)
364 | hidden_states = self.dropout(hidden_states)
365 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
366 | return hidden_states
367 |
368 |
369 | class BertLayer(nn.Module):
370 | def __init__(self, config):
371 | super(BertLayer, self).__init__()
372 | self.attention = BertAttention(config)
373 | self.intermediate = BertIntermediate(config)
374 | self.output = BertOutput(config)
375 |
376 | def forward(self, hidden_states, attention_mask):
377 | attention_output = self.attention(hidden_states, attention_mask)
378 | intermediate_output = self.intermediate(attention_output)
379 | layer_output = self.output(intermediate_output, attention_output)
380 | return layer_output
381 |
382 | class BertCrossAttentionLayer(nn.Module):
383 | def __init__(self, config):
384 | super(BertCrossAttentionLayer, self).__init__()
385 | self.attention = BertCrossAttention(config)
386 | self.intermediate = BertIntermediate(config)
387 | self.output = BertOutput(config)
388 |
389 | def forward(self, s1_hidden_states, s2_hidden_states, s2_attention_mask):
390 | attention_output = self.attention(s1_hidden_states, s2_hidden_states, s2_attention_mask)
391 | intermediate_output = self.intermediate(attention_output)
392 | layer_output = self.output(intermediate_output, attention_output)
393 | return layer_output
394 |
395 | class BertEncoder(nn.Module):
396 | def __init__(self, config):
397 | super(BertEncoder, self).__init__()
398 | layer = BertLayer(config)
399 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
400 |
401 | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
402 | all_encoder_layers = []
403 | for layer_module in self.layer:
404 | hidden_states = layer_module(hidden_states, attention_mask)
405 | if output_all_encoded_layers:
406 | all_encoder_layers.append(hidden_states)
407 | if not output_all_encoded_layers:
408 | all_encoder_layers.append(hidden_states)
409 | return all_encoder_layers
410 |
411 |
412 | class BertPooler(nn.Module):
413 | def __init__(self, config):
414 | super(BertPooler, self).__init__()
415 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
416 | self.activation = nn.Tanh()
417 |
418 | def forward(self, hidden_states):
419 | # We "pool" the model by simply taking the hidden state corresponding
420 | # to the first token.
421 | first_token_tensor = hidden_states[:, 0]
422 | pooled_output = self.dense(first_token_tensor)
423 | pooled_output = self.activation(pooled_output)
424 | return pooled_output
425 |
426 |
427 | class BertPredictionHeadTransform(nn.Module):
428 | def __init__(self, config):
429 | super(BertPredictionHeadTransform, self).__init__()
430 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
431 | self.transform_act_fn = ACT2FN[config.hidden_act] \
432 | if isinstance(config.hidden_act, str) else config.hidden_act
433 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
434 |
435 | def forward(self, hidden_states):
436 | hidden_states = self.dense(hidden_states)
437 | hidden_states = self.transform_act_fn(hidden_states)
438 | hidden_states = self.LayerNorm(hidden_states)
439 | return hidden_states
440 |
441 |
442 | class BertLMPredictionHead(nn.Module):
443 | def __init__(self, config, bert_model_embedding_weights):
444 | super(BertLMPredictionHead, self).__init__()
445 | self.transform = BertPredictionHeadTransform(config)
446 |
447 | # The output weights are the same as the input embeddings, but there is
448 | # an output-only bias for each token.
449 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
450 | bert_model_embedding_weights.size(0),
451 | bias=False)
452 | self.decoder.weight = bert_model_embedding_weights
453 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
454 |
455 | def forward(self, hidden_states):
456 | hidden_states = self.transform(hidden_states)
457 | hidden_states = self.decoder(hidden_states) + self.bias
458 | return hidden_states
459 |
460 |
461 | class BertOnlyMLMHead(nn.Module):
462 | def __init__(self, config, bert_model_embedding_weights):
463 | super(BertOnlyMLMHead, self).__init__()
464 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
465 |
466 | def forward(self, sequence_output):
467 | prediction_scores = self.predictions(sequence_output)
468 | return prediction_scores
469 |
470 |
471 | class BertOnlyNSPHead(nn.Module):
472 | def __init__(self, config):
473 | super(BertOnlyNSPHead, self).__init__()
474 | self.seq_relationship = nn.Linear(config.hidden_size, 2)
475 |
476 | def forward(self, pooled_output):
477 | seq_relationship_score = self.seq_relationship(pooled_output)
478 | return seq_relationship_score
479 |
480 |
481 | class BertPreTrainingHeads(nn.Module):
482 | def __init__(self, config, bert_model_embedding_weights):
483 | super(BertPreTrainingHeads, self).__init__()
484 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
485 | self.seq_relationship = nn.Linear(config.hidden_size, 2)
486 |
487 | def forward(self, sequence_output, pooled_output):
488 | prediction_scores = self.predictions(sequence_output)
489 | seq_relationship_score = self.seq_relationship(pooled_output)
490 | return prediction_scores, seq_relationship_score
491 |
492 |
493 | class PreTrainedBertModel(nn.Module):
494 | """ An abstract class to handle weights initialization and
495 |         a simple interface for downloading and loading pretrained models.
496 | """
497 | def __init__(self, config, *inputs, **kwargs):
498 | super(PreTrainedBertModel, self).__init__()
499 | if not isinstance(config, BertConfig):
500 | raise ValueError(
501 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
502 | "To create a model from a Google pretrained model use "
503 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
504 | self.__class__.__name__, self.__class__.__name__
505 | ))
506 | self.config = config
507 |
508 | def init_bert_weights(self, module):
509 | """ Initialize the weights.
510 | """
511 | if isinstance(module, (nn.Linear, nn.Embedding)):
512 | # Slightly different from the TF version which uses truncated_normal for initialization
513 | # cf https://github.com/pytorch/pytorch/pull/5617
514 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
515 | elif isinstance(module, BertLayerNorm):
516 | module.bias.data.zero_()
517 | module.weight.data.fill_(1.0)
518 | if isinstance(module, nn.Linear) and module.bias is not None:
519 | module.bias.data.zero_()
520 |
521 | @classmethod
522 | def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
523 | """
524 | Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
525 | Download and cache the pre-trained model file if needed.
526 |
527 | Params:
528 | pretrained_model_name: either:
529 | - a str with the name of a pre-trained model to load selected in the list of:
530 | . `bert-base-uncased`
531 | . `bert-large-uncased`
532 | . `bert-base-cased`
533 | . `bert-large-cased`
534 | . `bert-base-multilingual-uncased`
535 | . `bert-base-multilingual-cased`
536 | . `bert-base-chinese`
537 | - a path or url to a pretrained model archive containing:
538 | . `bert_config.json` a configuration file for the model
539 | . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
540 | cache_dir: an optional path to a folder in which the pre-trained models will be cached.
541 |             state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
542 | *inputs, **kwargs: additional input for the specific Bert class
543 | (ex: num_labels for BertForSequenceClassification)
544 | """
545 | if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
546 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
547 | else:
548 | archive_file = pretrained_model_name
549 | # redirect to the cache, if necessary
550 | try:
551 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
552 | except FileNotFoundError:
553 | logger.error(
554 | "Model name '{}' was not found in model name list ({}). "
555 | "We assumed '{}' was a path or url but couldn't find any file "
556 | "associated to this path or url.".format(
557 | pretrained_model_name,
558 | ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
559 | archive_file))
560 | return None
561 | if resolved_archive_file == archive_file:
562 | logger.info("loading archive file {}".format(archive_file))
563 | else:
564 | logger.info("loading archive file {} from cache at {}".format(
565 | archive_file, resolved_archive_file))
566 | tempdir = None
567 | if os.path.isdir(resolved_archive_file):
568 | serialization_dir = resolved_archive_file
569 | else:
570 | # Extract archive to temp dir
571 | tempdir = tempfile.mkdtemp()
572 | logger.info("extracting archive file {} to temp dir {}".format(
573 | resolved_archive_file, tempdir))
574 | with tarfile.open(resolved_archive_file, 'r:gz') as archive:
575 | archive.extractall(tempdir)
576 | serialization_dir = tempdir
577 | # Load config
578 | config_file = os.path.join(serialization_dir, CONFIG_NAME)
579 | config = BertConfig.from_json_file(config_file)
580 | logger.info("Model config {}".format(config))
581 | # Instantiate model.
582 | model = cls(config, *inputs, **kwargs)
583 | if state_dict is None:
584 | weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
585 | state_dict = torch.load(weights_path)
586 |
587 | old_keys = []
588 | new_keys = []
589 | for key in state_dict.keys():
590 | new_key = None
591 | if 'gamma' in key:
592 | new_key = key.replace('gamma', 'weight')
593 | if 'beta' in key:
594 | new_key = key.replace('beta', 'bias')
595 | if new_key:
596 | old_keys.append(key)
597 | new_keys.append(new_key)
598 | for old_key, new_key in zip(old_keys, new_keys):
599 | state_dict[new_key] = state_dict.pop(old_key)
600 |
601 | missing_keys = []
602 | unexpected_keys = []
603 | error_msgs = []
604 | # copy state_dict so _load_from_state_dict can modify it
605 | metadata = getattr(state_dict, '_metadata', None)
606 | state_dict = state_dict.copy()
607 | if metadata is not None:
608 | state_dict._metadata = metadata
609 |
610 | def load(module, prefix=''):
611 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
612 | module._load_from_state_dict(
613 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
614 | for name, child in module._modules.items():
615 | if child is not None:
616 | load(child, prefix + name + '.')
617 | load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
618 | if len(missing_keys) > 0:
619 | logger.info("Weights of {} not initialized from pretrained model: {}".format(
620 | model.__class__.__name__, missing_keys))
621 | if len(unexpected_keys) > 0:
622 | logger.info("Weights from pretrained model not used in {}: {}".format(
623 | model.__class__.__name__, unexpected_keys))
624 | if tempdir:
625 | # Clean up temp dir
626 | shutil.rmtree(tempdir)
627 | return model
628 |
629 |
630 | class BertModel(PreTrainedBertModel):
631 | """BERT model ("Bidirectional Encoder Representations from Transformers").
632 |
633 | Params:
634 | config: a BertConfig class instance with the configuration to build a new model
635 |
636 | Inputs:
637 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
638 |         with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
639 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
640 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
641 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
642 | a `sentence B` token (see BERT paper for more details).
643 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
644 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
645 | input sequence length in the current batch. It's the mask that we typically use for attention when
646 | a batch has varying length sentences.
647 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
648 |
649 | Outputs: Tuple of (encoded_layers, pooled_output)
650 | `encoded_layers`: controled by `output_all_encoded_layers` argument:
651 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
652 | of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
653 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
654 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
655 | to the last attention block of shape [batch_size, sequence_length, hidden_size],
656 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
657 |         classifier pretrained on top of the hidden state associated with the first token of the
658 |         input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
659 |
660 | Example usage:
661 | ```python
662 | # Already been converted into WordPiece token ids
663 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
664 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
665 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
666 |
667 | config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
668 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
669 |
670 | model = modeling.BertModel(config=config)
671 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
672 | ```
673 | """
674 | def __init__(self, config):
675 | super(BertModel, self).__init__(config)
676 | self.embeddings = BertEmbeddings(config)
677 | self.encoder = BertEncoder(config)
678 | self.pooler = BertPooler(config)
679 | self.apply(self.init_bert_weights)
680 |
681 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
682 | if attention_mask is None:
683 | attention_mask = torch.ones_like(input_ids)
684 | if token_type_ids is None:
685 | token_type_ids = torch.zeros_like(input_ids)
686 |
687 | # We create a 3D attention mask from a 2D tensor mask.
688 | # Sizes are [batch_size, 1, 1, to_seq_length]
689 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
690 |         # This attention mask is simpler than the triangular masking of causal attention
691 |         # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
692 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
693 |
694 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
695 | # masked positions, this operation will create a tensor which is 0.0 for
696 | # positions we want to attend and -10000.0 for masked positions.
697 | # Since we are adding it to the raw scores before the softmax, this is
698 | # effectively the same as removing these entirely.
699 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
700 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
701 |
702 | embedding_output = self.embeddings(input_ids, token_type_ids)
703 | encoded_layers = self.encoder(embedding_output,
704 | extended_attention_mask,
705 | output_all_encoded_layers=output_all_encoded_layers)
706 | sequence_output = encoded_layers[-1]
707 | pooled_output = self.pooler(sequence_output)
708 | if not output_all_encoded_layers:
709 | encoded_layers = encoded_layers[-1]
710 | return encoded_layers, pooled_output
711 |
712 |
713 | class BertForPreTraining(PreTrainedBertModel):
714 | """BERT model with pre-training heads.
715 | This module comprises the BERT model followed by the two pre-training heads:
716 | - the masked language modeling head, and
717 | - the next sentence classification head.
718 |
719 | Params:
720 | config: a BertConfig class instance with the configuration to build a new model.
721 |
722 | Inputs:
723 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
724 |         with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
725 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
726 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
727 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
728 | a `sentence B` token (see BERT paper for more details).
729 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
730 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
731 | input sequence length in the current batch. It's the mask that we typically use for attention when
732 | a batch has varying length sentences.
733 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
734 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
735 | is only computed for the labels set in [0, ..., vocab_size]
736 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
737 | with indices selected in [0, 1].
738 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
739 |
740 | Outputs:
741 | if `masked_lm_labels` and `next_sentence_label` are not `None`:
742 | Outputs the total_loss which is the sum of the masked language modeling loss and the next
743 | sentence classification loss.
744 | if `masked_lm_labels` or `next_sentence_label` is `None`:
745 | Outputs a tuple comprising
746 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
747 | - the next sentence classification logits of shape [batch_size, 2].
748 |
749 | Example usage:
750 | ```python
751 | # Already been converted into WordPiece token ids
752 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
753 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
754 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
755 |
756 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
757 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
758 |
759 | model = BertForPreTraining(config)
760 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
761 | ```
762 | """
763 | def __init__(self, config):
764 | super(BertForPreTraining, self).__init__(config)
765 | self.bert = BertModel(config)
766 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
767 | self.apply(self.init_bert_weights)
768 |
769 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
770 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
771 | output_all_encoded_layers=False)
772 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
773 |
774 | if masked_lm_labels is not None and next_sentence_label is not None:
775 | loss_fct = CrossEntropyLoss(ignore_index=-1)
776 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
777 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
778 | total_loss = masked_lm_loss + next_sentence_loss
779 | return total_loss
780 | else:
781 | return prediction_scores, seq_relationship_score
782 |
783 |
784 | class BertForMaskedLM(PreTrainedBertModel):
785 | """BERT model with the masked language modeling head.
786 | This module comprises the BERT model followed by the masked language modeling head.
787 |
788 | Params:
789 | config: a BertConfig class instance with the configuration to build a new model.
790 |
791 | Inputs:
792 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
793 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
794 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
795 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
796 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
797 | a `sentence B` token (see BERT paper for more details).
798 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
799 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
800 | input sequence length in the current batch. It's the mask that we typically use for attention when
801 | a batch has varying length sentences.
802 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
803 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
804 | is only computed for the labels set in [0, ..., vocab_size]
805 |
806 | Outputs:
807 | if `masked_lm_labels` is not `None`:
808 | Outputs the masked language modeling loss.
809 | if `masked_lm_labels` is `None`:
810 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
811 |
812 | Example usage:
813 | ```python
814 | # Already been converted into WordPiece token ids
815 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
816 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
817 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
818 |
819 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
820 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
821 |
822 | model = BertForMaskedLM(config)
823 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
824 | ```
825 | """
826 | def __init__(self, config):
827 | super(BertForMaskedLM, self).__init__(config)
828 | self.bert = BertModel(config)
829 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
830 | self.apply(self.init_bert_weights)
831 |
832 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
833 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
834 | output_all_encoded_layers=False)
835 | prediction_scores = self.cls(sequence_output)
836 |
837 | if masked_lm_labels is not None:
838 | loss_fct = CrossEntropyLoss(ignore_index=-1)
839 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
840 | return masked_lm_loss
841 | else:
842 | return prediction_scores
843 |
844 |
845 | class BertForNextSentencePrediction(PreTrainedBertModel):
846 | """BERT model with next sentence prediction head.
847 | This module comprises the BERT model followed by the next sentence classification head.
848 |
849 | Params:
850 | config: a BertConfig class instance with the configuration to build a new model.
851 |
852 | Inputs:
853 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
854 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
855 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
856 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
857 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
858 | a `sentence B` token (see BERT paper for more details).
859 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
860 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
861 | input sequence length in the current batch. It's the mask that we typically use for attention when
862 | a batch has varying length sentences.
863 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
864 | with indices selected in [0, 1].
865 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
866 |
867 | Outputs:
868 | if `next_sentence_label` is not `None`:
869 | Outputs the next sentence classification loss, i.e. the CrossEntropy loss over the
870 | `is continuation` / `is random` next-sentence classes.
871 | if `next_sentence_label` is `None`:
872 | Outputs the next sentence classification logits of shape [batch_size, 2].
873 |
874 | Example usage:
875 | ```python
876 | # Already been converted into WordPiece token ids
877 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
878 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
879 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
880 |
881 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
882 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
883 |
884 | model = BertForNextSentencePrediction(config)
885 | seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
886 | ```
887 | """
888 | def __init__(self, config):
889 | super(BertForNextSentencePrediction, self).__init__(config)
890 | self.bert = BertModel(config)
891 | self.cls = BertOnlyNSPHead(config)
892 | self.apply(self.init_bert_weights)
893 |
894 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
895 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
896 | output_all_encoded_layers=False)
897 | seq_relationship_score = self.cls(pooled_output)
898 |
899 | if next_sentence_label is not None:
900 | loss_fct = CrossEntropyLoss(ignore_index=-1)
901 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
902 | return next_sentence_loss
903 | else:
904 | return seq_relationship_score
905 |
906 | '''
907 | class BertForSequenceClassification(PreTrainedBertModel):
908 | """BERT model for classification with one additional layer.
909 | """
910 | def __init__(self, config, num_labels=2):
911 | super(BertForSequenceClassification, self).__init__(config)
912 | self.num_labels = num_labels
913 | self.bert = BertModel(config)
914 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
915 | self.img_attention = BertLayer(config)
916 | self.img_pooler = BertPooler(config)
917 | self.classifier = nn.Linear(config.hidden_size, num_labels)
918 | self.apply(self.init_bert_weights)
919 |
920 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
921 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
922 |
923 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
924 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
925 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
926 |
927 | img_att_text_output_layer = self.img_attention(sequence_output, extended_attention_mask)
928 | img_att_text_output = self.img_pooler(img_att_text_output_layer)
929 |
930 | pooled_output = self.dropout(img_att_text_output)
931 | logits = self.classifier(pooled_output)
932 |
933 | if labels is not None:
934 | loss_fct = CrossEntropyLoss()
935 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
936 | return loss
937 | else:
938 | return logits
939 | '''
940 |
941 | class BertForSequenceClassification(PreTrainedBertModel):
942 | """BERT model for classification with one additional layer.
943 | """
944 | def __init__(self, config, num_labels=2):
945 | super(BertForSequenceClassification, self).__init__(config)
946 | self.num_labels = num_labels
947 | self.bert = BertModel(config)
948 | self.s1_bert = BertModel(config)
949 | self.s2_bert = BertModel(config)
950 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
951 | self.s1_attention = BertCrossAttentionLayer(config)
952 | self.s1_pooler = BertPooler(config)
953 | self.s2_attention = BertCrossAttentionLayer(config)
954 | self.s2_pooler = BertPooler(config)
955 | self.classifier = nn.Linear(config.hidden_size*3, num_labels)
956 | self.apply(self.init_bert_weights)
957 |
958 | def forward(self, input_ids, s1_input_ids, s2_input_ids, token_type_ids=None, s1_type_ids=None, s2_type_ids=None, \
959 | attention_mask=None, s1_mask=None, s2_mask=None, labels=None, copy_flag=False):
960 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
961 | if copy_flag:
962 | self.s1_bert = copy.deepcopy(self.bert)
963 | self.s2_bert = copy.deepcopy(self.bert)
964 | s1_output, s1_pooled_output = self.s1_bert(s1_input_ids, s1_type_ids, s1_mask, output_all_encoded_layers=False)
965 | s2_output, s2_pooled_output = self.s2_bert(s2_input_ids, s2_type_ids, s2_mask, output_all_encoded_layers=False)
966 |
967 | extended_s1_mask = s1_mask.unsqueeze(1).unsqueeze(2)
968 | extended_s1_mask = extended_s1_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
969 | extended_s1_mask = (1.0 - extended_s1_mask) * -10000.0
970 |
971 | extended_s2_mask = s2_mask.unsqueeze(1).unsqueeze(2)
972 | extended_s2_mask = extended_s2_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
973 | extended_s2_mask = (1.0 - extended_s2_mask) * -10000.0
974 |
975 | s1_cross_output_layer = self.s1_attention(s1_output, s2_output, extended_s2_mask)
976 | s1_cross_output = self.s1_pooler(s1_cross_output_layer)
977 |
978 | s2_cross_output_layer = self.s2_attention(s2_output, s1_output, extended_s1_mask)
979 | s2_cross_output = self.s2_pooler(s2_cross_output_layer)
980 |
981 | final_output = torch.cat((pooled_output, s1_cross_output, s2_cross_output), dim=-1)
982 | pooled_output = self.dropout(final_output)
983 | logits = self.classifier(pooled_output)
984 |
985 | if labels is not None:
986 | loss_fct = CrossEntropyLoss()
987 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
988 | return loss
989 | else:
990 | return logits
991 |
992 | class BertForMultipleChoice(PreTrainedBertModel):
993 | """BERT model for multiple choice tasks.
994 | This module is composed of the BERT model with a linear layer on top of
995 | the pooled output.
996 |
997 | Params:
998 | `config`: a BertConfig class instance with the configuration to build a new model.
999 | `num_choices`: the number of classes for the classifier. Default = 2.
1000 |
1001 | Inputs:
1002 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
1003 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
1004 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
1005 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
1006 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
1007 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
1008 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
1009 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1010 | input sequence length in the current batch. It's the mask that we typically use for attention when
1011 | a batch has varying length sentences.
1012 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
1013 | with indices selected in [0, ..., num_choices - 1].
1014 |
1015 | Outputs:
1016 | if `labels` is not `None`:
1017 | Outputs the CrossEntropy classification loss of the output with the labels.
1018 | if `labels` is `None`:
1019 | Outputs the classification logits of shape [batch_size, num_choices].
1020 |
1021 | Example usage:
1022 | ```python
1023 | # Already been converted into WordPiece token ids
1024 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
1025 | input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
1026 | token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
1027 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
1028 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
1029 |
1030 | num_choices = 2
1031 |
1032 | model = BertForMultipleChoice(config, num_choices)
1033 | logits = model(input_ids, token_type_ids, input_mask)
1034 | ```
1035 | """
1036 | def __init__(self, config, num_choices=2):
1037 | super(BertForMultipleChoice, self).__init__(config)
1038 | self.num_choices = num_choices
1039 | self.bert = BertModel(config)
1040 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
1041 | self.classifier = nn.Linear(config.hidden_size, 1)
1042 | self.apply(self.init_bert_weights)
1043 |
1044 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
1045 | flat_input_ids = input_ids.view(-1, input_ids.size(-1))
1046 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
1047 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
1048 | _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
1049 | pooled_output = self.dropout(pooled_output)
1050 | logits = self.classifier(pooled_output)
1051 | reshaped_logits = logits.view(-1, self.num_choices)
1052 |
1053 | if labels is not None:
1054 | loss_fct = CrossEntropyLoss()
1055 | loss = loss_fct(reshaped_logits, labels)
1056 | return loss
1057 | else:
1058 | return reshaped_logits
1059 |
1060 |
1061 | class BertForTokenClassification(PreTrainedBertModel):
1062 | """BERT model for token-level classification.
1063 | This module is composed of the BERT model with a linear layer on top of
1064 | the full hidden state of the last layer.
1065 |
1066 | Params:
1067 | `config`: a BertConfig class instance with the configuration to build a new model.
1068 | `num_labels`: the number of classes for the classifier. Default = 2.
1069 |
1070 | Inputs:
1071 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
1072 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
1073 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
1074 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
1075 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
1076 | a `sentence B` token (see BERT paper for more details).
1077 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
1078 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1079 | input sequence length in the current batch. It's the mask that we typically use for attention when
1080 | a batch has varying length sentences.
1081 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
1082 | with indices selected in [0, ..., num_labels - 1].
1083 |
1084 | Outputs:
1085 | if `labels` is not `None`:
1086 | Outputs the CrossEntropy classification loss of the output with the labels.
1087 | if `labels` is `None`:
1088 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
1089 |
1090 | Example usage:
1091 | ```python
1092 | # Already been converted into WordPiece token ids
1093 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
1094 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
1095 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
1096 |
1097 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
1098 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
1099 |
1100 | num_labels = 2
1101 |
1102 | model = BertForTokenClassification(config, num_labels)
1103 | logits = model(input_ids, token_type_ids, input_mask)
1104 | ```
1105 | """
1106 | def __init__(self, config, num_labels=2):
1107 | super(BertForTokenClassification, self).__init__(config)
1108 | self.num_labels = num_labels
1109 | self.bert = BertModel(config)
1110 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
1111 | self.classifier = nn.Linear(config.hidden_size, num_labels)
1112 | self.apply(self.init_bert_weights)
1113 |
1114 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
1115 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
1116 | sequence_output = self.dropout(sequence_output)
1117 | logits = self.classifier(sequence_output)
1118 |
1119 | if labels is not None:
1120 | loss_fct = CrossEntropyLoss()
1121 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1122 | return loss
1123 | else:
1124 | return logits
1125 |
1126 |
1127 | class BertForQuestionAnswering(PreTrainedBertModel):
1128 | """BERT model for Question Answering (span extraction).
1129 | This module is composed of the BERT model with a linear layer on top of
1130 | the sequence output that computes start_logits and end_logits.
1131 |
1132 | Params:
1133 | `config`: a BertConfig class instance with the configuration to build a new model.
1134 |
1135 | Inputs:
1136 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
1137 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
1138 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
1139 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
1140 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
1141 | a `sentence B` token (see BERT paper for more details).
1142 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
1143 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1144 | input sequence length in the current batch. It's the mask that we typically use for attention when
1145 | a batch has varying length sentences.
1146 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
1147 | Positions are clamped to the length of the sequence and positions outside of the sequence are not taken
1148 | into account for computing the loss.
1149 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
1150 | Positions are clamped to the length of the sequence and positions outside of the sequence are not taken
1151 | into account for computing the loss.
1152 |
1153 | Outputs:
1154 | if `start_positions` and `end_positions` are not `None`:
1155 | Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
1156 | if `start_positions` or `end_positions` is `None`:
1157 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
1158 | position tokens of shape [batch_size, sequence_length].
1159 |
1160 | Example usage:
1161 | ```python
1162 | # Already been converted into WordPiece token ids
1163 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
1164 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
1165 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
1166 |
1167 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
1168 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
1169 |
1170 | model = BertForQuestionAnswering(config)
1171 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
1172 | ```
1173 | """
1174 | def __init__(self, config):
1175 | super(BertForQuestionAnswering, self).__init__(config)
1176 | self.bert = BertModel(config)
1177 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
1178 | # self.dropout = nn.Dropout(config.hidden_dropout_prob)
1179 | self.qa_outputs = nn.Linear(config.hidden_size, 2)
1180 | self.apply(self.init_bert_weights)
1181 |
1182 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
1183 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
1184 | logits = self.qa_outputs(sequence_output)
1185 | start_logits, end_logits = logits.split(1, dim=-1)
1186 | start_logits = start_logits.squeeze(-1)
1187 | end_logits = end_logits.squeeze(-1)
1188 |
1189 | if start_positions is not None and end_positions is not None:
1190 | # If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it away
1191 | if len(start_positions.size()) > 1:
1192 | start_positions = start_positions.squeeze(-1)
1193 | if len(end_positions.size()) > 1:
1194 | end_positions = end_positions.squeeze(-1)
1195 | # sometimes the start/end positions are outside our model inputs; we ignore these terms
1196 | ignored_index = start_logits.size(1)
1197 | start_positions.clamp_(0, ignored_index)
1198 | end_positions.clamp_(0, ignored_index)
1199 |
1200 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1201 | start_loss = loss_fct(start_logits, start_positions)
1202 | end_loss = loss_fct(end_logits, end_positions)
1203 | total_loss = (start_loss + end_loss) / 2
1204 | return total_loss
1205 | else:
1206 | return start_logits, end_logits
1207 |
--------------------------------------------------------------------------------
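Unlike the other heads in this file, the cross-attention `BertForSequenceClassification` above carries no usage example in its docstring. The following sketch, written in the same toy-tensor style as the other docstrings (so `BertConfig` and the class itself are assumed importable from this module), shows one plausible way to call it; the token ids, masks, and auxiliary sequences are illustrative, not taken from the repo.

```python
import torch

# Toy batch of 2: a main sequence plus two auxiliary sequences, length 3 each.
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
s1_input_ids = torch.LongTensor([[12, 16, 42], [14, 28, 57]])
s2_input_ids = torch.LongTensor([[7, 9, 11], [3, 8, 0]])
mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForSequenceClassification(config, num_labels=2)

# copy_flag=True deep-copies the main encoder into the two auxiliary encoders,
# so all three streams start from the same weights.
logits = model(input_ids, s1_input_ids, s2_input_ids,
               attention_mask=mask, s1_mask=mask, s2_mask=mask, copy_flag=True)
```

Note that `s1_mask` and `s2_mask` cannot be omitted here: the forward pass builds extended attention masks from them before the two cross-attention layers run.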
/my_bert/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import re
23 | import argparse
24 | import tensorflow as tf
25 | import torch
26 | import numpy as np
27 |
28 | from .modeling import BertConfig, BertForPreTraining
29 |
30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
31 | config_path = os.path.abspath(bert_config_file)
32 | tf_path = os.path.abspath(tf_checkpoint_path)
33 | print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
34 | # Load weights from TF model
35 | init_vars = tf.train.list_variables(tf_path)
36 | names = []
37 | arrays = []
38 | for name, shape in init_vars:
39 | print("Loading TF weight {} with shape {}".format(name, shape))
40 | array = tf.train.load_variable(tf_path, name)
41 | names.append(name)
42 | arrays.append(array)
43 |
44 | # Initialise PyTorch model
45 | config = BertConfig.from_json_file(bert_config_file)
46 | print("Building PyTorch model from configuration: {}".format(str(config)))
47 | model = BertForPreTraining(config)
48 |
49 | for name, array in zip(names, arrays):
50 | name = name.split('/')
51 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
52 | # which are not required for using the pretrained model
53 | if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
54 | print("Skipping {}".format("/".join(name)))
55 | continue
56 | pointer = model
57 | for m_name in name:
58 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
59 | l = re.split(r'_(\d+)', m_name)
60 | else:
61 | l = [m_name]
62 | if l[0] == 'kernel' or l[0] == 'gamma':
63 | pointer = getattr(pointer, 'weight')
64 | elif l[0] == 'output_bias' or l[0] == 'beta':
65 | pointer = getattr(pointer, 'bias')
66 | elif l[0] == 'output_weights':
67 | pointer = getattr(pointer, 'weight')
68 | else:
69 | pointer = getattr(pointer, l[0])
70 | if len(l) >= 2:
71 | num = int(l[1])
72 | pointer = pointer[num]
73 | if m_name[-11:] == '_embeddings':
74 | pointer = getattr(pointer, 'weight')
75 | elif m_name == 'kernel':
76 | array = np.transpose(array)
77 | try:
78 | assert pointer.shape == array.shape
79 | except AssertionError as e:
80 | e.args += (pointer.shape, array.shape)
81 | raise
82 | print("Initialize PyTorch weight {}".format(name))
83 | pointer.data = torch.from_numpy(array)
84 |
85 | # Save pytorch-model
86 | print("Save PyTorch model to {}".format(pytorch_dump_path))
87 | torch.save(model.state_dict(), pytorch_dump_path)
88 |
89 |
90 | if __name__ == "__main__":
91 | parser = argparse.ArgumentParser()
92 | ## Required parameters
93 | parser.add_argument("--tf_checkpoint_path",
94 | default = None,
95 | type = str,
96 | required = True,
97 | help = "Path the TensorFlow checkpoint path.")
98 | parser.add_argument("--bert_config_file",
99 | default = None,
100 | type = str,
101 | required = True,
102 | help = "The config json file corresponding to the pre-trained BERT model. \n"
103 | "This specifies the model architecture.")
104 | parser.add_argument("--pytorch_dump_path",
105 | default = None,
106 | type = str,
107 | required = True,
108 | help = "Path to the output PyTorch model.")
109 | args = parser.parse_args()
110 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
111 | args.bert_config_file,
112 | args.pytorch_dump_path)
113 |
--------------------------------------------------------------------------------
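The converter can also be driven from Python rather than the command line; a minimal sketch with placeholder paths (the checkpoint and file names below are illustrative). Because of the relative `from .modeling import ...`, the script itself has to be run as part of the `my_bert` package, e.g. via `python -m`.

```python
from my_bert.convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

# Placeholder paths: substitute the real TF checkpoint, config, and output file.
convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="uncased_L-12_H-768_A-12/bert_model.ckpt",
    bert_config_file="uncased_L-12_H-768_A-12/bert_config.json",
    pytorch_dump_path="uncased_L-12_H-768_A-12/pytorch_model.bin",
)
```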
/my_bert/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import sys
9 | import json
10 | import logging
11 | import os
12 | import shutil
13 | import tempfile
14 | import fnmatch
15 | from functools import wraps
16 | from hashlib import sha256
18 | from io import open
19 |
20 | import boto3
21 | import requests
22 | from botocore.exceptions import ClientError
23 | from tqdm import tqdm
24 |
25 | try:
26 | from torch.hub import _get_torch_home
27 | torch_cache_home = _get_torch_home()
28 | except ImportError:
29 | torch_cache_home = os.path.expanduser(
30 | os.getenv('TORCH_HOME', os.path.join(
31 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
32 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert')
33 |
34 | try:
35 | from urllib.parse import urlparse
36 | except ImportError:
37 | from urlparse import urlparse
38 |
39 | try:
40 | from pathlib import Path
41 | PYTORCH_PRETRAINED_BERT_CACHE = Path(
42 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
43 | except (AttributeError, ImportError):
44 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
45 | default_cache_path)
46 |
47 | CONFIG_NAME = "config.json"
48 | WEIGHTS_NAME = "pytorch_model.bin"
49 |
50 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
51 |
52 |
53 | def url_to_filename(url, etag=None):
54 | """
55 | Convert `url` into a hashed filename in a repeatable way.
56 | If `etag` is specified, append its hash to the url's, delimited
57 | by a period.
58 | """
59 | url_bytes = url.encode('utf-8')
60 | url_hash = sha256(url_bytes)
61 | filename = url_hash.hexdigest()
62 |
63 | if etag:
64 | etag_bytes = etag.encode('utf-8')
65 | etag_hash = sha256(etag_bytes)
66 | filename += '.' + etag_hash.hexdigest()
67 |
68 | return filename
69 |
70 |
71 | def filename_to_url(filename, cache_dir=None):
72 | """
73 | Return the url and etag (which may be ``None``) stored for `filename`.
74 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
75 | """
76 | if cache_dir is None:
77 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
78 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
79 | cache_dir = str(cache_dir)
80 |
81 | cache_path = os.path.join(cache_dir, filename)
82 | if not os.path.exists(cache_path):
83 | raise EnvironmentError("file {} not found".format(cache_path))
84 |
85 | meta_path = cache_path + '.json'
86 | if not os.path.exists(meta_path):
87 | raise EnvironmentError("file {} not found".format(meta_path))
88 |
89 | with open(meta_path, encoding="utf-8") as meta_file:
90 | metadata = json.load(meta_file)
91 | url = metadata['url']
92 | etag = metadata['etag']
93 |
94 | return url, etag
95 |
96 |
97 | def cached_path(url_or_filename, cache_dir=None):
98 | """
99 | Given something that might be a URL (or might be a local path),
100 | determine which. If it's a URL, download the file and cache it, and
101 | return the path to the cached file. If it's already a local path,
102 | make sure the file exists and then return the path.
103 | """
104 | if cache_dir is None:
105 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
106 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
107 | url_or_filename = str(url_or_filename)
108 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
109 | cache_dir = str(cache_dir)
110 |
111 | parsed = urlparse(url_or_filename)
112 |
113 | if parsed.scheme in ('http', 'https', 's3'):
114 | # URL, so get it from the cache (downloading if necessary)
115 | return get_from_cache(url_or_filename, cache_dir)
116 | elif os.path.exists(url_or_filename):
117 | # File, and it exists.
118 | return url_or_filename
119 | elif parsed.scheme == '':
120 | # File, but it doesn't exist.
121 | raise EnvironmentError("file {} not found".format(url_or_filename))
122 | else:
123 | # Something unknown
124 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
125 |
126 |
127 | def split_s3_path(url):
128 | """Split a full s3 path into the bucket name and path."""
129 | parsed = urlparse(url)
130 | if not parsed.netloc or not parsed.path:
131 | raise ValueError("bad s3 path {}".format(url))
132 | bucket_name = parsed.netloc
133 | s3_path = parsed.path
134 | # Remove '/' at beginning of path.
135 | if s3_path.startswith("/"):
136 | s3_path = s3_path[1:]
137 | return bucket_name, s3_path
138 |
139 |
140 | def s3_request(func):
141 | """
142 | Wrapper function for s3 requests in order to create more helpful error
143 | messages.
144 | """
145 |
146 | @wraps(func)
147 | def wrapper(url, *args, **kwargs):
148 | try:
149 | return func(url, *args, **kwargs)
150 | except ClientError as exc:
151 | if int(exc.response["Error"]["Code"]) == 404:
152 | raise EnvironmentError("file {} not found".format(url))
153 | else:
154 | raise
155 |
156 | return wrapper
157 |
158 |
159 | @s3_request
160 | def s3_etag(url):
161 | """Check ETag on S3 object."""
162 | s3_resource = boto3.resource("s3")
163 | bucket_name, s3_path = split_s3_path(url)
164 | s3_object = s3_resource.Object(bucket_name, s3_path)
165 | return s3_object.e_tag
166 |
167 |
168 | @s3_request
169 | def s3_get(url, temp_file):
170 | """Pull a file directly from S3."""
171 | s3_resource = boto3.resource("s3")
172 | bucket_name, s3_path = split_s3_path(url)
173 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
174 |
175 |
176 | def http_get(url, temp_file):
177 | req = requests.get(url, stream=True)
178 | content_length = req.headers.get('Content-Length')
179 | total = int(content_length) if content_length is not None else None
180 | progress = tqdm(unit="B", total=total)
181 | for chunk in req.iter_content(chunk_size=1024):
182 | if chunk: # filter out keep-alive new chunks
183 | progress.update(len(chunk))
184 | temp_file.write(chunk)
185 | progress.close()
186 |
187 |
188 | def get_from_cache(url, cache_dir=None):
189 | """
190 | Given a URL, look for the corresponding dataset in the local cache.
191 | If it's not there, download it. Then return the path to the cached file.
192 | """
193 | if cache_dir is None:
194 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
195 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
196 | cache_dir = str(cache_dir)
197 |
198 | if not os.path.exists(cache_dir):
199 | os.makedirs(cache_dir)
200 |
201 | # Get eTag to add to filename, if it exists.
202 | if url.startswith("s3://"):
203 | etag = s3_etag(url)
204 | else:
205 | try:
206 | response = requests.head(url, allow_redirects=True)
207 | if response.status_code != 200:
208 | etag = None
209 | else:
210 | etag = response.headers.get("ETag")
211 | except EnvironmentError:
212 | etag = None
213 |
214 | if sys.version_info[0] == 2 and etag is not None:
215 | etag = etag.decode('utf-8')
216 | filename = url_to_filename(url, etag)
217 |
218 | # get cache path to put the file
219 | cache_path = os.path.join(cache_dir, filename)
220 |
221 | # If we don't have a connection (etag is None) and can't identify the file
222 | # try to get the last downloaded one
223 | if not os.path.exists(cache_path) and etag is None:
224 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
225 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
226 | if matching_files:
227 | cache_path = os.path.join(cache_dir, matching_files[-1])
228 |
229 | if not os.path.exists(cache_path):
230 | # Download to temporary file, then copy to cache dir once finished.
231 | # Otherwise you get corrupt cache entries if the download gets interrupted.
232 | with tempfile.NamedTemporaryFile() as temp_file:
233 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
234 |
235 | # GET file object
236 | if url.startswith("s3://"):
237 | s3_get(url, temp_file)
238 | else:
239 | http_get(url, temp_file)
240 |
241 | # we are copying the file before closing it, so flush to avoid truncation
242 | temp_file.flush()
243 | # shutil.copyfileobj() starts at the current position, so go to the start
244 | temp_file.seek(0)
245 |
246 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
247 | with open(cache_path, 'wb') as cache_file:
248 | shutil.copyfileobj(temp_file, cache_file)
249 |
250 | logger.info("creating metadata file for %s", cache_path)
251 | meta = {'url': url, 'etag': etag}
252 | meta_path = cache_path + '.json'
253 | with open(meta_path, 'w') as meta_file:
254 | output_string = json.dumps(meta)
255 | if sys.version_info[0] == 2 and isinstance(output_string, str):
256 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2
257 | meta_file.write(output_string)
258 |
259 | logger.info("removing temp file %s", temp_file.name)
260 |
261 | return cache_path
262 |
263 |
264 | def read_set_from_file(filename):
265 | '''
266 | Extract a de-duped collection (set) of text from a file.
267 | Expected file format is one item per line.
268 | '''
269 | collection = set()
270 | with open(filename, 'r', encoding='utf-8') as file_:
271 | for line in file_:
272 | collection.add(line.rstrip())
273 | return collection
274 |
275 |
276 | def get_file_extension(path, dot=True, lower=True):
277 | ext = os.path.splitext(path)[1]
278 | ext = ext if dot else ext[1:]
279 | return ext.lower() if lower else ext
--------------------------------------------------------------------------------
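`cached_path` is the single entry point the loaders above go through: a local path is returned as-is after an existence check, while a URL is downloaded once into the cache and the cached file's path is returned. A small sketch, assuming network access; the paths below are illustrative.

```python
from my_bert.file_utils import cached_path, url_to_filename

# Local path: returned unchanged, provided the file exists.
weights = cached_path("./model/resnet/resnet152-b121ed2d.pth")

# URL: downloaded into PYTORCH_PRETRAINED_BERT_CACHE on first use, then reused.
vocab = cached_path(
    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt")

# Cache entries are named sha256(url), with '.' + sha256(etag) appended when known.
print(url_to_filename("https://example.com/vocab.txt", etag='"abc"'))
```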
/my_bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 |
23 | def warmup_cosine(x, warmup=0.002):
24 | if x < warmup:
25 | return x/warmup
26 | return 0.5 * (1.0 + math.cos(math.pi * x))  # x is a Python float, so use math.cos (torch.cos would reject it)
27 |
28 | def warmup_constant(x, warmup=0.002):
29 | if x < warmup:
30 | return x/warmup
31 | return 1.0
32 |
33 | def warmup_linear(x, warmup=0.002):
34 | if x < warmup:
35 | return x/warmup
36 | return 1.0 - x
37 |
38 | SCHEDULES = {
39 | 'warmup_cosine':warmup_cosine,
40 | 'warmup_constant':warmup_constant,
41 | 'warmup_linear':warmup_linear,
42 | }
43 |
44 |
45 | class BertAdam(Optimizer):
46 | """Implements BERT version of Adam algorithm with weight decay fix.
47 | Params:
48 | lr: learning rate
49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
50 | t_total: total number of training steps for the learning
51 | rate schedule, -1 means constant learning rate. Default: -1
52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
53 | b1: Adam's beta1 (first-moment decay). Default: 0.9
54 | b2: Adam's beta2 (second-moment decay). Default: 0.999
55 | e: Adam's epsilon. Default: 1e-6
56 | weight_decay: Weight decay. Default: 0.01
57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
58 | """
59 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
60 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
61 | max_grad_norm=1.0):
62 | if lr is not required and lr < 0.0:
63 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
64 | if schedule not in SCHEDULES:
65 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
66 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
67 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
68 | if not 0.0 <= b1 < 1.0:
69 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
70 | if not 0.0 <= b2 < 1.0:
71 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
72 | if not e >= 0.0:
73 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
74 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
75 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
76 | max_grad_norm=max_grad_norm)
77 | super(BertAdam, self).__init__(params, defaults)
78 |
79 | def get_lr(self):
80 | lr = []
81 | for group in self.param_groups:
82 | for p in group['params']:
83 | state = self.state[p]
84 | if len(state) == 0:
85 | return [0]
86 | if group['t_total'] != -1:
87 | schedule_fct = SCHEDULES[group['schedule']]
88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
89 | else:
90 | lr_scheduled = group['lr']
91 | lr.append(lr_scheduled)
92 | return lr
93 |
94 | def step(self, closure=None):
95 | """Performs a single optimization step.
96 |
97 | Arguments:
98 | closure (callable, optional): A closure that reevaluates the model
99 | and returns the loss.
100 | """
101 | loss = None
102 | if closure is not None:
103 | loss = closure()
104 |
105 | for group in self.param_groups:
106 | for p in group['params']:
107 | if p.grad is None:
108 | continue
109 | grad = p.grad.data
110 | if grad.is_sparse:
111 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
112 |
113 | state = self.state[p]
114 |
115 | # State initialization
116 | if len(state) == 0:
117 | state['step'] = 0
118 | # Exponential moving average of gradient values
119 | state['next_m'] = torch.zeros_like(p.data)
120 | # Exponential moving average of squared gradient values
121 | state['next_v'] = torch.zeros_like(p.data)
122 |
123 | next_m, next_v = state['next_m'], state['next_v']
124 | beta1, beta2 = group['b1'], group['b2']
125 |
126 | # Add grad clipping
127 | if group['max_grad_norm'] > 0:
128 | clip_grad_norm_(p, group['max_grad_norm'])
129 |
130 | # Decay the first and second moment running average coefficient
131 | # In-place operations to update the averages at the same time
132 | next_m.mul_(beta1).add_(1 - beta1, grad)
133 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
134 | update = next_m / (next_v.sqrt() + group['e'])
135 |
136 | # Just adding the square of the weights to the loss function is *not*
137 | # the correct way of using L2 regularization/weight decay with Adam,
138 | # since that will interact with the m and v parameters in strange ways.
139 | #
140 | # Instead we want to decay the weights in a manner that doesn't interact
141 | # with the m/v parameters. This is equivalent to adding the square
142 | # of the weights to the loss with plain (non-momentum) SGD.
143 | if group['weight_decay'] > 0.0:
144 | update += group['weight_decay'] * p.data
145 |
146 | if group['t_total'] != -1:
147 | schedule_fct = SCHEDULES[group['schedule']]
148 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
149 | else:
150 | lr_scheduled = group['lr']
151 |
152 | update_with_lr = lr_scheduled * update
153 | p.data.add_(-update_with_lr)
154 |
155 | state['step'] += 1
156 |
157 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
158 | # No bias correction
159 | # bias_correction1 = 1 - beta1 ** state['step']
160 | # bias_correction2 = 1 - beta2 ** state['step']
161 |
162 | return loss
163 |
--------------------------------------------------------------------------------
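The schedule functions are multipliers on the base learning rate, evaluated at the training progress x = step / t_total. A worked sketch of `warmup_linear` with a 10% warmup (the warmup fraction is illustrative; the printed values follow directly from the function):

```python
from my_bert.optimization import warmup_linear

# Linear ramp over the first 10% of steps, then linear decay towards zero.
for x in (0.0, 0.05, 0.1, 0.5, 1.0):
    print(x, warmup_linear(x, warmup=0.1))
# 0.00 -> 0.0   start of warmup
# 0.05 -> 0.5   halfway through warmup
# 0.10 -> 0.9   warmup finished, linear decay under way
# 0.50 -> 0.5
# 1.00 -> 0.0   end of training
```

In a training script this multiplier is consumed implicitly by constructing the optimizer as, for example, `BertAdam(model.parameters(), lr=5e-5, warmup=0.1, t_total=num_train_steps)`.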
/my_bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import unicodedata
23 | import os
24 | import logging
25 |
26 | from .file_utils import cached_path
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
38 | }
39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
40 | 'bert-base-uncased': 512,
41 | 'bert-large-uncased': 512,
42 | 'bert-base-cased': 512,
43 | 'bert-large-cased': 512,
44 | 'bert-base-multilingual-uncased': 512,
45 | 'bert-base-multilingual-cased': 512,
46 | 'bert-base-chinese': 512,
47 | }
48 | VOCAB_NAME = 'vocab.txt'
49 |
50 |
51 | def load_vocab(vocab_file):
52 | """Loads a vocabulary file into a dictionary."""
53 | vocab = collections.OrderedDict()
54 | index = 0
55 | with open(vocab_file, "r", encoding="utf-8") as reader:
56 | while True:
57 | token = reader.readline()
58 | if not token:
59 | break
60 | token = token.strip()
61 | vocab[token] = index
62 | index += 1
63 | return vocab
64 |
65 |
66 | def whitespace_tokenize(text):
67 | """Runs basic whitespace cleaning and splitting on a peice of text."""
68 | text = text.strip()
69 | if not text:
70 | return []
71 | tokens = text.split()
72 | return tokens
73 |
74 |
75 | class BertTokenizer(object):
76 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
77 |
78 | def __init__(self, vocab_file, do_lower_case=True, max_len=None,
79 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
80 | if not os.path.isfile(vocab_file):
81 | raise ValueError(
82 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
83 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
84 | self.vocab = load_vocab(vocab_file)
85 | self.ids_to_tokens = collections.OrderedDict(
86 | [(ids, tok) for tok, ids in self.vocab.items()])
87 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
88 | never_split=never_split)
89 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
90 | self.max_len = max_len if max_len is not None else int(1e12)
91 |
92 | def tokenize(self, text):
93 | split_tokens = []
94 | for token in self.basic_tokenizer.tokenize(text):
95 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
96 | split_tokens.append(sub_token)
97 | return split_tokens
98 |
99 | def convert_tokens_to_ids(self, tokens):
100 | """Converts a sequence of tokens into ids using the vocab."""
101 | ids = []
102 | for token in tokens:
103 | ids.append(self.vocab[token])
104 | if len(ids) > self.max_len:
105 | raise ValueError(
106 | "Token indices sequence length is longer than the specified maximum "
107 | " sequence length for this BERT model ({} > {}). Running this"
108 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
109 | )
110 | return ids
111 |
112 | def convert_ids_to_tokens(self, ids):
113 | """Converts a sequence of ids in wordpiece tokens using the vocab."""
114 | tokens = []
115 | for i in ids:
116 | tokens.append(self.ids_to_tokens[i])
117 | return tokens
118 |
119 | @classmethod
120 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
121 | """
122 | Instantiate a BertTokenizer from a pre-trained vocabulary file.
123 | Download and cache the vocabulary file if needed.
124 | """
125 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
126 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
127 | else:
128 | vocab_file = pretrained_model_name
129 | if os.path.isdir(vocab_file):
130 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
131 | # redirect to the cache, if necessary
132 | try:
133 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
134 | except FileNotFoundError:
135 | logger.error(
136 | "Model name '{}' was not found in model name list ({}). "
137 | "We assumed '{}' was a path or url but couldn't find any file "
138 | "associated to this path or url.".format(
139 | pretrained_model_name,
140 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
141 | vocab_file))
142 | return None
143 | if resolved_vocab_file == vocab_file:
144 | logger.info("loading vocabulary file {}".format(vocab_file))
145 | else:
146 | logger.info("loading vocabulary file {} from cache at {}".format(
147 | vocab_file, resolved_vocab_file))
148 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
149 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
150 | # than the number of positional embeddings
151 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
152 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
153 | # Instantiate tokenizer.
154 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
155 | return tokenizer
156 |
157 |
158 | class BasicTokenizer(object):
159 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
160 |
161 | def __init__(self,
162 | do_lower_case=True,
163 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
164 | """Constructs a BasicTokenizer.
165 |
166 | Args:
167 | do_lower_case: Whether to lower case the input.
168 | """
169 | self.do_lower_case = do_lower_case
170 | self.never_split = never_split
171 |
172 | def tokenize(self, text):
173 | """Tokenizes a piece of text."""
174 | text = self._clean_text(text)
175 | # This was added on November 1st, 2018 for the multilingual and Chinese
176 | # models. This is also applied to the English models now, but it doesn't
177 | # matter since the English models were not trained on any Chinese data
178 | # and generally don't have any Chinese data in them (there are Chinese
179 | # characters in the vocabulary because Wikipedia does have some Chinese
180 | # words in the English Wikipedia).
181 | text = self._tokenize_chinese_chars(text)
182 | orig_tokens = whitespace_tokenize(text)
183 | split_tokens = []
184 | for token in orig_tokens:
185 | if self.do_lower_case and token not in self.never_split:
186 | token = token.lower()
187 | token = self._run_strip_accents(token)
188 | split_tokens.extend(self._run_split_on_punc(token))
189 |
190 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
191 | return output_tokens
192 |
193 | def _run_strip_accents(self, text):
194 | """Strips accents from a piece of text."""
195 | text = unicodedata.normalize("NFD", text)
196 | output = []
197 | for char in text:
198 | cat = unicodedata.category(char)
199 | if cat == "Mn":
200 | continue
201 | output.append(char)
202 | return "".join(output)
203 |
204 | def _run_split_on_punc(self, text):
205 | """Splits punctuation on a piece of text."""
206 | if text in self.never_split:
207 | return [text]
208 | chars = list(text)
209 | i = 0
210 | start_new_word = True
211 | output = []
212 | while i < len(chars):
213 | char = chars[i]
214 | if _is_punctuation(char):
215 | output.append([char])
216 | start_new_word = True
217 | else:
218 | if start_new_word:
219 | output.append([])
220 | start_new_word = False
221 | output[-1].append(char)
222 | i += 1
223 |
224 | return ["".join(x) for x in output]
225 |
226 | def _tokenize_chinese_chars(self, text):
227 | """Adds whitespace around any CJK character."""
228 | output = []
229 | for char in text:
230 | cp = ord(char)
231 | if self._is_chinese_char(cp):
232 | output.append(" ")
233 | output.append(char)
234 | output.append(" ")
235 | else:
236 | output.append(char)
237 | return "".join(output)
238 |
239 | def _is_chinese_char(self, cp):
240 | """Checks whether CP is the codepoint of a CJK character."""
241 | # This defines a "chinese character" as anything in the CJK Unicode block:
242 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
243 | #
244 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
245 | # despite its name. The modern Korean Hangul alphabet is a different block,
246 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
247 | # space-separated words, so they are not treated specially and handled
248 | # like all of the other languages.
249 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
250 | (cp >= 0x3400 and cp <= 0x4DBF) or #
251 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
252 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
253 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
254 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
255 | (cp >= 0xF900 and cp <= 0xFAFF) or #
256 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
257 | return True
258 |
259 | return False
260 |
261 | def _clean_text(self, text):
262 | """Performs invalid character removal and whitespace cleanup on text."""
263 | output = []
264 | for char in text:
265 | cp = ord(char)
266 | if cp == 0 or cp == 0xfffd or _is_control(char):
267 | continue
268 | if _is_whitespace(char):
269 | output.append(" ")
270 | else:
271 | output.append(char)
272 | return "".join(output)
273 |
274 |
275 | class WordpieceTokenizer(object):
276 | """Runs WordPiece tokenization."""
277 |
278 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
279 | self.vocab = vocab
280 | self.unk_token = unk_token
281 | self.max_input_chars_per_word = max_input_chars_per_word
282 |
283 | def tokenize(self, text):
284 | """Tokenizes a piece of text into its word pieces.
285 |
286 | This uses a greedy longest-match-first algorithm to perform tokenization
287 | using the given vocabulary.
288 |
289 | For example:
290 | input = "unaffable"
291 | output = ["un", "##aff", "##able"]
292 |
293 | Args:
294 | text: A single token or whitespace separated tokens. This should have
295 | already been passed through `BasicTokenizer`.
296 |
297 | Returns:
298 | A list of wordpiece tokens.
299 | """
300 |
301 | output_tokens = []
302 | for token in whitespace_tokenize(text):
303 | chars = list(token)
304 | if len(chars) > self.max_input_chars_per_word:
305 | output_tokens.append(self.unk_token)
306 | continue
307 |
308 | is_bad = False
309 | start = 0
310 | sub_tokens = []
311 | while start < len(chars):
312 | end = len(chars)
313 | cur_substr = None
314 | while start < end:
315 | substr = "".join(chars[start:end])
316 | if start > 0:
317 | substr = "##" + substr
318 | if substr in self.vocab:
319 | cur_substr = substr
320 | break
321 | end -= 1
322 | if cur_substr is None:
323 | is_bad = True
324 | break
325 | sub_tokens.append(cur_substr)
326 | start = end
327 |
328 | if is_bad:
329 | output_tokens.append(self.unk_token)
330 | else:
331 | output_tokens.extend(sub_tokens)
332 | return output_tokens
333 |
334 |
335 | def _is_whitespace(char):
336 | """Checks whether `chars` is a whitespace character."""
337 | # \t, \n, and \r are technically contorl characters but we treat them
338 | # as whitespace since they are generally considered as such.
339 | if char == " " or char == "\t" or char == "\n" or char == "\r":
340 | return True
341 | cat = unicodedata.category(char)
342 | if cat == "Zs":
343 | return True
344 | return False
345 |
346 |
347 | def _is_control(char):
348 | """Checks whether `chars` is a control character."""
349 | # These are technically control characters but we count them as whitespace
350 | # characters.
351 | if char == "\t" or char == "\n" or char == "\r":
352 | return False
353 | cat = unicodedata.category(char)
354 | if cat.startswith("C"):
355 | return True
356 | return False
357 |
358 |
359 | def _is_punctuation(char):
360 | """Checks whether `chars` is a punctuation character."""
361 | cp = ord(char)
362 | # We treat all non-letter/number ASCII as punctuation.
363 | # Characters such as "^", "$", and "`" are not in the Unicode
364 | # Punctuation class but we treat them as punctuation anyways, for
365 | # consistency.
366 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
367 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
368 | return True
369 | cat = unicodedata.category(char)
370 | if cat.startswith("P"):
371 | return True
372 | return False
373 |
--------------------------------------------------------------------------------
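The greedy longest-match-first behaviour of `WordpieceTokenizer` can be observed in isolation with a toy vocabulary (the vocabulary below is illustrative and mirrors the docstring example):

```python
import collections
from my_bert.tokenization import WordpieceTokenizer

vocab = collections.OrderedDict(
    (tok, i) for i, tok in enumerate(["[UNK]", "un", "##aff", "##able"]))
wp = WordpieceTokenizer(vocab=vocab)

print(wp.tokenize("unaffable"))    # ['un', '##aff', '##able']
# If no sequence of pieces covers a token, the whole token falls back to [UNK]:
print(wp.tokenize("unbreakable"))  # ['[UNK]']
```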
/ner_evaluate.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import numpy as np
3 |
4 | def get_chunks(seq, tags):
5 |     """
6 |     Extracts typed chunks from a sequence of label indices.
7 |     Args:
8 |         seq: [4, 4, 0, 0, ...] sequence of label indices
9 |         tags: dict mapping tag name to index; must contain "O"
10 |     Returns:
11 |         list of (chunk_type, chunk_start, chunk_end)
12 |
13 |     Example:
14 |         seq = [4, 5, 0, 3]
15 |         tags = {"O": 0, "B-LOC": 3, "B-PER": 4, "I-PER": 5}
16 |         result = [("PER", 0, 2), ("LOC", 3, 4)]
17 |     """
18 | default = tags['O']
19 | idx_to_tag = {idx: tag for tag, idx in tags.items()}
20 | chunks = []
21 | chunk_type, chunk_start = None, None
22 | for i, tok in enumerate(seq):
23 | #End of a chunk 1
24 | if tok == default and chunk_type is not None:
25 | # Add a chunk.
26 | chunk = (chunk_type, chunk_start, i)
27 | chunks.append(chunk)
28 | chunk_type, chunk_start = None, None
29 |
30 | # End of a chunk + start of a chunk!
31 | elif tok != default:
32 | tok_chunk_class, tok_chunk_type = get_chunk_type(tok, idx_to_tag)
33 | if chunk_type is None:
34 | chunk_type, chunk_start = tok_chunk_type, i
35 | elif tok_chunk_type != chunk_type or tok_chunk_class == "B":
36 | chunk = (chunk_type, chunk_start, i)
37 | chunks.append(chunk)
38 | chunk_type, chunk_start = tok_chunk_type, i
39 | else:
40 | pass
41 | # end condition
42 | if chunk_type is not None:
43 | chunk = (chunk_type, chunk_start, len(seq))
44 | chunks.append(chunk)
45 |
46 | return chunks
47 |
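One subtlety of the loop above deserves a quick demo: a fresh B- tag of the same type closes the running chunk instead of extending it. A toy check, with an invented tag set (assumes get_chunks and get_chunk_type from this file are in scope):

```python
# Two adjacent PER mentions: the second B-PER closes the first chunk.
tags = {"O": 0, "B-PER": 1, "I-PER": 2}
seq = [1, 2, 1, 0]            # B-PER I-PER B-PER O
print(get_chunks(seq, tags))  # [('PER', 0, 2), ('PER', 2, 3)]
```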
48 | def get_chunk_type(tok, idx_to_tag):
49 | """
50 | Args:
51 | tok: id of token, such as 4
52 | idx_to_tag: dictionary {4: "B-PER", ...}
53 | Returns:
54 | tuple: "B", "PER"
55 | """
56 | tag_name = idx_to_tag[tok]
57 | tag_class = tag_name.split('-')[0]
58 | tag_type = tag_name.split('-')[-1]
59 | return tag_class, tag_type
60 |
61 | # def run_evaluate(self, sess, test, tags):
62 | def evaluate(labels_pred, labels, words, tags):
63 |
64 |     """
65 |     Evaluates chunk-level performance on a labeled set.
66 |
67 |     Args:
68 |         labels_pred: list of predicted label-index sequences
69 |         labels: list of gold label-index sequences
70 |         words: list of word sequences (kept for the optional result dump below)
71 |         tags: {tag: index} dictionary
72 |     Returns:
73 |         accuracy
74 |         f1 score, precision, recall
75 |     """
76 |
77 | #file_write = open('./test_results.txt','w')
78 |
79 |
80 | index = 0
81 | sents_length = []
82 |
83 | accs = []
84 | correct_preds, total_correct, total_preds = 0., 0., 0.
85 |
86 |
87 | for lab, lab_pred, word_sent in zip(labels, labels_pred, words):
88 | word_st = word_sent
89 | lab = lab
90 | lab_pred = lab_pred
91 | accs += [a==b for (a, b) in zip(lab, lab_pred)]
92 | lab_chunks = set(get_chunks(lab, tags))
93 | lab_pred_chunks = set(get_chunks(lab_pred, tags))
94 | correct_preds += len(lab_chunks & lab_pred_chunks)
95 | total_preds += len(lab_pred_chunks)
96 | total_correct += len(lab_chunks)
97 |
98 | #for i in range(len(word_st)):
99 | #file_write.write('%s\t%s\t%s\n'%(word_st[i],lab[i],lab_pred[i]))
100 | #file_write.write('\n')
101 |
102 | p = correct_preds / total_preds if correct_preds > 0 else 0
103 | r = correct_preds / total_correct if correct_preds > 0 else 0
104 | f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0
105 | acc = np.mean(accs)
106 |
107 | #file_write.close()
108 | return acc, f1,p,r
109 |
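Precision and recall above are counted over exact (type, start, end) chunk matches, so a single boundary error costs both a false positive and a false negative. A small worked example with made-up chunks:

```python
# Chunk-level scoring as in evaluate(): only exact matches count.
gold = {("PER", 0, 2), ("LOC", 3, 4)}
pred = {("PER", 0, 2), ("LOC", 2, 4)}  # boundary error on the LOC chunk

correct = len(gold & pred)             # 1
p = correct / len(pred)                # 0.5
r = correct / len(gold)                # 0.5
f1 = 2 * p * r / (p + r)               # 0.5
print(p, r, f1)
```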
110 | def evaluate_each_class(labels_pred, labels, words, tags, class_type):
111 |     # class_type: the chunk type to score; here 'POS', 'NEU' or 'NEG' (PER/LOC/ORG in plain NER)
112 | index = 0
113 |
114 | accs = []
115 | correct_preds, total_correct, total_preds = 0., 0., 0.
116 | correct_preds_cla_type, total_preds_cla_type, total_correct_cla_type = 0., 0., 0.
117 |
118 | for lab, lab_pred, word_sent in zip(labels, labels_pred, words):
119 | lab_pre_class_type = []
120 | lab_class_type=[]
121 |
122 | word_st = word_sent
123 | lab = lab
124 | lab_pred = lab_pred
125 | lab_chunks = get_chunks(lab, tags)
126 | lab_pred_chunks = get_chunks(lab_pred, tags)
127 | for i in range(len(lab_pred_chunks)):
128 | if lab_pred_chunks[i][0] == class_type:
129 | lab_pre_class_type.append(lab_pred_chunks[i])
130 | lab_pre_class_type_c = set(lab_pre_class_type)
131 |
132 | for i in range(len(lab_chunks)):
133 | if lab_chunks[i][0] ==class_type:
134 | lab_class_type.append(lab_chunks[i])
135 | lab_class_type_c = set(lab_class_type)
136 |
137 | lab_chunksss = set(lab_chunks)
138 | correct_preds_cla_type +=len(lab_pre_class_type_c & lab_chunksss)
139 | total_preds_cla_type +=len(lab_pre_class_type_c)
140 | total_correct_cla_type += len(lab_class_type_c)
141 |
142 | p = correct_preds_cla_type / total_preds_cla_type if correct_preds_cla_type > 0 else 0
143 | r = correct_preds_cla_type / total_correct_cla_type if correct_preds_cla_type > 0 else 0
144 | f1 = 2 * p * r / (p + r) if correct_preds_cla_type > 0 else 0
145 |
146 | return f1,p,r
147 |
148 |
149 | if __name__ == '__main__':
150 | max_sent=10
151 | tags = {'0':0,
152 | 'B-PER':1, 'I-PER':2,
153 | 'B-LOC':3, 'I-LOC':4,
154 | 'B-ORG':5, 'I-ORG':6,
155 | 'B-OTHER':7, 'I-OTHER':8,
156 | 'O':9}
157 | labels_pred=[
158 | [9,9,9,1,3,1,2,2,0,0],
159 | [9,9,9,1,3,1,2,0,0,0]
160 | ]
161 | labels = [
162 | [9,9,9,9,3,1,2,2,0,0],
163 | [9,9,9,9,3,1,2,2,0,0]
164 | ]
165 | words = [
166 | [0,0,0,0,0,3,6,8,5,7],
167 | [0,0,0,4,5,6,7,9,1,7]
168 | ]
169 | id_to_vocb = {0:'a',1:'b',2:'c',3:'d',4:'e',5:'f',6:'g',7:'h',8:'i',9:'j'}
170 | new_words = []
171 | for i in range(len(words)):
172 | sent = []
173 | for j in range(len(words[i])):
174 | sent.append(id_to_vocb[words[i][j]])
175 | new_words.append(sent)
176 | class_type = 'PER'
177 | acc, f1,p,r = evaluate(labels_pred, labels,new_words,tags)
178 | print(p,r,f1)
179 | f1,p,r = evaluate_each_class(labels_pred, labels,new_words,tags, class_type)
180 | print(p,r,f1)
181 |
182 |
--------------------------------------------------------------------------------
/run_cmmt_crf.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import argparse
4 | import csv
5 | import logging
6 | import os
7 | import random
8 | import json
9 | import sys
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn.functional as F
14 | from my_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
15 | from my_bert.mner_modeling import (CONFIG_NAME, WEIGHTS_NAME,
16 | BertConfig, MTCCMBertForMMTokenClassificationCRF)
17 | from my_bert.optimization import BertAdam, warmup_linear
18 | from my_bert.tokenization import BertTokenizer
19 | from seqeval.metrics import classification_report
20 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
21 | TensorDataset)
22 | from torch.utils.data.distributed import DistributedSampler
23 | from tqdm import tqdm, trange
24 |
25 | import resnet.resnet as resnet
26 | #from resnet.resnet import resnet
27 | from resnet.resnet_utils import myResnet
28 |
29 | from torchvision import transforms
30 | from PIL import Image
31 |
32 | from sklearn.metrics import precision_recall_fscore_support
33 |
34 | from ner_evaluate import evaluate_each_class
35 | from ner_evaluate import evaluate
36 | from transformers import RobertaTokenizer, RobertaModel
37 |
38 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
39 | datefmt = '%m/%d/%Y %H:%M:%S',
40 | level = logging.INFO)
41 | logger = logging.getLogger(__name__)
42 |
43 |
44 | def image_process(image_path, transform):
45 | image = Image.open(image_path).convert('RGB')
46 | image = transform(image)
47 | return image
48 |
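image_process is always paired with a torchvision pipeline like the ones built later in this file; a minimal usage sketch (the image path is hypothetical):

```python
from torchvision import transforms

# Same crop size and ImageNet normalization stats as the pipelines below.
demo_transform = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# Hypothetical path; any image Pillow can open works.
# img = image_process("./image_data/twitter2015/1234.jpg", demo_transform)
# img.shape -> torch.Size([3, 224, 224])
```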
49 |
50 | class InputExample(object):
51 | """A single training/test example for simple sequence classification."""
52 |
53 | def __init__(self, guid, text_a, text_b=None, label=None):
54 |         """Constructs an InputExample.
55 |
56 | Args:
57 | guid: Unique id for the example.
58 | text_a: string. The untokenized text of the first sequence. For single
59 | sequence tasks, only this sequence must be specified.
60 | text_b: (Optional) string. The untokenized text of the second sequence.
61 | Only must be specified for sequence pair tasks.
62 | label: (Optional) string. The label of the example. This should be
63 | specified for train and dev examples, but not for test examples.
64 | """
65 | self.guid = guid
66 | self.text_a = text_a
67 | self.text_b = text_b
68 | self.label = label
69 |
70 |
71 | class MMInputExample(object):
72 |     """A single training/test example for multimodal sequence labeling."""
73 |
74 |     def __init__(self, guid, text_a, text_b, img_id, label=None, auxlabel=None, imagelabel=None):  # yl add
75 |         """Constructs a MMInputExample.
76 |
77 |         Args:
78 |             guid: Unique id for the example.
79 |             text_a: string. The untokenized text of the first sequence.
80 |             text_b: (Optional) string. The untokenized text of the second sequence.
81 |             img_id: filename of the associated image.
82 |             label: (Optional) list of gold labels; given for train and dev examples.
83 |             auxlabel: (Optional) list of auxiliary sentiment labels.
84 |             imagelabel: (Optional) dict of image (ANP) labels.
85 |         """
86 | self.guid = guid
87 | self.text_a = text_a
88 | self.text_b = text_b
89 | self.img_id = img_id
90 | self.label = label
91 | self.auxlabel = auxlabel
92 | self.imagelabel = imagelabel
93 |
94 |
95 | class InputFeatures(object):
96 | """A single set of features of data."""
97 |
98 | def __init__(self, input_ids, input_mask, segment_ids, label_id):
99 | self.input_ids = input_ids
100 | self.input_mask = input_mask
101 | self.segment_ids = segment_ids
102 | self.label_id = label_id
103 |
104 |
105 | class MMInputFeatures(object):
106 | """A single set of features of data."""
107 |
108 | def __init__(self, input_ids, input_mask, added_input_mask, segment_ids, img_feat, label_id, auxlabel_id, imagelabel):
109 | self.input_ids = input_ids
110 | self.input_mask = input_mask
111 | self.added_input_mask = added_input_mask
112 | self.segment_ids = segment_ids
113 | self.img_feat = img_feat
114 | self.label_id = label_id
115 | self.auxlabel_id = auxlabel_id
116 | self.imagelabel = imagelabel
117 |
118 |
119 | def readfile(filename):
120 |     '''
121 |     Reads a CoNLL-style file.
122 |     Returns a list of (sentence, label) pairs, e.g.
123 |     [ (['EU', 'rejects', 'German', ...], ['B-ORG', 'O', 'B-MISC', ...]), ... ]
124 |     '''
125 | f = open(filename)
126 | data = []
127 | sentence = []
128 | label= []
129 | for line in f:
130 | if len(line)==0 or line.startswith('-DOCSTART') or line[0]=="\n":
131 | if len(sentence) > 0:
132 | data.append((sentence,label))
133 | sentence = []
134 | label = []
135 | continue
136 | splits = line.split(' ')
137 | sentence.append(splits[0])
138 | label.append(splits[-1][:-1])
139 |
140 | if len(sentence) >0:
141 | data.append((sentence,label))
142 | sentence = []
143 | label = []
144 |
145 | print("The number of samples: "+ str(len(data)))
146 | return data
147 |
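readfile expects one space-separated "token label" pair per line, with a blank line between sentences. A quick self-contained check of that layout (file contents invented):

```python
# Demonstrates the layout readfile() parses. Contents are invented.
import os
import tempfile

sample = "EU B-ORG\nrejects O\nGerman B-MISC\n\nlamb O\n"
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write(sample)

print(readfile(tmp.name))
# [(['EU', 'rejects', 'German'], ['B-ORG', 'O', 'B-MISC']), (['lamb'], ['O'])]
os.unlink(tmp.name)
```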
148 |
149 | def mmreadfile(filename, image_filename, path_img):
150 |     '''
151 |     Reads a CoNLL-style file together with its image-label json.
152 |     Returns:
153 |     (data, imgs, auxlabels, imagelabels), where data is a list of (sentence, label) pairs
154 |     '''
155 | transform = transforms.Compose([
156 | transforms.RandomCrop(224), # args.crop_size, by default it is set to be 224
157 | transforms.RandomHorizontalFlip(),
158 | transforms.ToTensor(),
159 | transforms.Normalize((0.485, 0.456, 0.406),
160 | (0.229, 0.224, 0.225))])
161 | with open(image_filename, 'r') as f:
162 | image_data = json.load(f)
163 | f = open(filename, encoding='utf-8')
164 | data = []
165 | imgs = []
166 | auxlabels = []
167 | sentence = []
168 | label= []
169 | auxlabel = []
170 | imagelabels = []
171 | imgid = ''
172 | count = 0
173 | # print(image_data.keys())
174 | for line in f:
175 | if line.startswith('IMGID:'):
176 | imgid = line.strip().split('IMGID:')[1]+'.jpg'
177 | continue
178 | if line[0]=="\n":
179 | if len(sentence) > 0:
180 | data.append((sentence,label))
181 | imgs.append(imgid)
182 | image_path = os.path.join(path_img, imgid)
183 | if not os.path.exists(image_path):
184 | print(image_path)
185 | try:
186 | image = image_process(image_path, transform)
187 |             except Exception:
188 |                 # fall back to a known-good image when loading fails
189 |                 imgid = '17_06_4705.jpg'
190 |
191 |             image_label = image_data.get(imgid)
192 |             if image_label is None:
193 | count += 1
194 | #print(sentence)
195 | #print(label)
196 | #print(imgid)
197 | auxlabels.append(auxlabel)
198 | imagelabels.append(image_label)
199 | sentence = []
200 | label = []
201 | imgid = ''
202 | auxlabel = []
203 | continue
204 | splits = line.split('\t')
205 | sentence.append(splits[0])
206 | cur_label = splits[1] #splits[-1][:-1] # yl add
207 | if cur_label == 'B-OTHER':
208 | cur_label = 'B-MISC'
209 | elif cur_label == 'I-OTHER':
210 | cur_label = 'I-MISC'
211 | label.append(cur_label)
212 | auxlabel.append(cur_label)
213 | #auxlabel.append(splits[2][:-1]) # yl add
214 |
215 | print("The number of samples with NULL image labels: "+ str(count))
216 | if len(sentence) >0:
217 | data.append((sentence,label))
218 | imgs.append(imgid)
219 | auxlabels.append(auxlabel)
220 | imagelabels.append(image_label)
221 | sentence = []
222 | label = []
223 | auxlabel = []
224 |
225 | print("The number of samples: "+ str(len(data)))
226 | print("The number of images: "+ str(len(imgs)))
227 | return data, imgs, auxlabels, imagelabels
228 |
229 |
230 | class DataProcessor(object):
231 | """Base class for data converters for sequence classification data sets."""
232 |
233 | def get_train_examples(self, data_dir,image_filename,path_img): #yl add
234 | """Gets a collection of `InputExample`s for the train set."""
235 | raise NotImplementedError()
236 |
237 | def get_dev_examples(self, data_dir,image_filename,path_img): #yl add
238 | """Gets a collection of `InputExample`s for the dev set."""
239 | raise NotImplementedError()
240 |
241 | def get_labels(self):
242 | """Gets the list of labels for this data set."""
243 | raise NotImplementedError()
244 |
245 | @classmethod
246 | def _read_tsv(cls, input_file, quotechar=None):
247 | """Reads a tab separated value file."""
248 | return readfile(input_file)
249 |
250 |     def _read_mmtsv(self, input_file, image_filename, path_img, quotechar=None):
251 | """Reads a tab separated value file."""
252 | return mmreadfile(input_file, image_filename, path_img) # yl add
253 |
254 |
255 | class MNERProcessor(DataProcessor):
256 | """Processor for the CoNLL-2003 data set."""
257 |
258 | def get_train_examples(self, data_dir, image_filename, path_img): # yl add
259 | """See base class."""
260 | data, imgs, auxlabels,imagelabels = self._read_mmtsv(os.path.join(data_dir, "train.txt"), image_filename, path_img) # yl add
261 | return self._create_examples(data, imgs, auxlabels, imagelabels, "train")
262 |
263 | def get_dev_examples(self, data_dir,image_filename,path_img): # yl add
264 | """See base class."""
265 | data, imgs, auxlabels,imagelabels = self._read_mmtsv(os.path.join(data_dir, "valid.txt"), image_filename, path_img) # yl add
266 | return self._create_examples(data, imgs, auxlabels, imagelabels, "dev")
267 |
268 | def get_test_examples(self, data_dir,image_filename,path_img): # yl add
269 | """See base class."""
270 | data, imgs, auxlabels, imagelabels = self._read_mmtsv(os.path.join(data_dir, "test.txt"), image_filename, path_img) # yl add
271 | return self._create_examples(data, imgs, auxlabels, imagelabels, "test")
272 |
273 |
274 | def get_labels(self):
275 |         return ["O", "B-NEU", "I-NEU", "B-POS", "I-POS", "B-NEG", "I-NEG", "X", "", ""]  # the two "" entries are the sequence start/stop placeholder labels
276 |
277 | ### modify
278 | def get_auxlabels(self):
279 | return ["O", "B-NEU", "I-NEU", "B-POS", "I-POS", "B-NEG", "I-NEG", "X", "", ""] #yl add ["O", "B-AE", "I-AE", "B-OE", "I-OE", "X", "", ""]
280 | #def get_auxlabels(self):
281 | #return ["O", "B", "I", "X", "[CLS]", "[SEP]"]
282 |
283 | ### modify
284 | def get_start_label_id(self):
285 | label_list = self.get_labels()
286 | label_map = {label: i for i, label in enumerate(label_list, 1)}
287 |         return label_map['']  # note: the duplicate '' keys collapse, so start and stop resolve to the same id
288 |
289 | def get_stop_label_id(self):
290 | label_list = self.get_labels()
291 | label_map = {label: i for i, label in enumerate(label_list, 1)}
292 | return label_map['']
293 |
294 | def _create_examples(self, lines, imgs, auxlabels, imagelabels, set_type): # yl add
295 | examples = []
296 | for i, (sentence, label) in enumerate(lines):
297 | guid = "%s-%s" % (set_type, i)
298 | text_a = ' '.join(sentence)
299 | text_b = None
300 | img_id = imgs[i]
301 | label = label
302 | auxlabel = auxlabels[i]
303 | imagelabel = imagelabels[i]
304 | examples.append(MMInputExample(guid=guid, text_a=text_a, text_b=text_b, img_id=img_id, label=label, auxlabel=auxlabel, imagelabel = imagelabel))
305 | return examples
306 |
307 |
308 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
309 | """Loads a data file into a list of `InputBatch`s."""
310 |
311 | label_map = {label : i for i, label in enumerate(label_list,1)}
312 |
313 | features = []
314 | for (ex_index,example) in enumerate(examples):
315 | textlist = example.text_a.split(' ')
316 | labellist = example.label
317 | tokens = []
318 | labels = []
319 | for i, word in enumerate(textlist):
320 | token = tokenizer.tokenize(word)
321 | tokens.extend(token)
322 | label_1 = labellist[i]
323 | for m in range(len(token)):
324 | if m == 0:
325 | labels.append(label_1)
326 | else:
327 | labels.append("X")
328 | if len(tokens) >= max_seq_length - 1:
329 | tokens = tokens[0:(max_seq_length - 2)]
330 | labels = labels[0:(max_seq_length - 2)]
331 | ntokens = []
332 | segment_ids = []
333 | label_ids = []
334 | ntokens.append("")
335 | segment_ids.append(0)
336 | label_ids.append(label_map[""])
337 | for i, token in enumerate(tokens):
338 | ntokens.append(token)
339 | segment_ids.append(0)
340 | label_ids.append(label_map[labels[i]])
341 | ntokens.append("")
342 | segment_ids.append(0)
343 | label_ids.append(label_map[""])
344 | input_ids = tokenizer.convert_tokens_to_ids(ntokens)
345 | input_mask = [1] * len(input_ids)
346 | while len(input_ids) < max_seq_length:
347 | input_ids.append(0)
348 | input_mask.append(0)
349 | segment_ids.append(0)
350 | label_ids.append(0)
351 | assert len(input_ids) == max_seq_length
352 | assert len(input_mask) == max_seq_length
353 | assert len(segment_ids) == max_seq_length
354 | assert len(label_ids) == max_seq_length
355 |
356 | if ex_index < 2:
357 | logger.info("*** Example ***")
358 | logger.info("guid: %s" % (example.guid))
359 | logger.info("tokens: %s" % " ".join(
360 | [str(x) for x in tokens]))
361 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
362 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
363 | logger.info(
364 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
365 | logger.info("label: %s" % " ".join([str(x) for x in label_ids]))
366 |
367 | features.append(
368 | InputFeatures(input_ids=input_ids,
369 | input_mask=input_mask,
370 | segment_ids=segment_ids,
371 | label_id=label_ids))
372 | return features
373 |
374 |
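The loop above propagates each word's gold label to its first subword piece only and pads continuations with "X", which evaluation later skips. The alignment in isolation, with hand-written tokenizer output:

```python
# First-subword label alignment as in convert_examples_to_features.
words = ["unaffable", "lamb"]
labels = ["B-POS", "O"]
pieces = [["un", "##aff", "##able"], ["lamb"]]  # hand-written tokenizer output

tokens, aligned = [], []
for word_pieces, label in zip(pieces, labels):
    for m, piece in enumerate(word_pieces):
        tokens.append(piece)
        aligned.append(label if m == 0 else "X")

print(tokens)   # ['un', '##aff', '##able', 'lamb']
print(aligned)  # ['B-POS', 'X', 'X', 'O']
```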
375 | def convert_mm_examples_to_features(examples, label_list, auxlabel_list, max_seq_length, tokenizer, crop_size, path_img):
376 | """Loads a data file into a list of `InputBatch`s."""
377 |
378 | label_map = {label: i for i, label in enumerate(label_list, 1)}
379 | auxlabel_map = {label: i for i, label in enumerate(auxlabel_list, 1)}
380 |
381 | features = []
382 | count = 0
383 |
384 | transform = transforms.Compose([
385 | transforms.RandomCrop(crop_size), # args.crop_size, by default it is set to be 224
386 | transforms.RandomHorizontalFlip(),
387 | transforms.ToTensor(),
388 | transforms.Normalize((0.485, 0.456, 0.406),
389 | (0.229, 0.224, 0.225))])
390 |
391 | for (ex_index, example) in enumerate(examples):
392 | textlist = example.text_a.split(' ')
393 | labellist = example.label
394 | auxlabellist = example.auxlabel
395 | imagelabellist = example.imagelabel
396 | imagelabellist = dict(sorted(imagelabellist.items()))
397 | imagelabel_value =[0]* len(imagelabellist)
398 | for i, (k, v) in enumerate(imagelabellist.items()):
399 | imagelabel_value[i]= v
400 | tokens = []
401 | labels = []
402 | auxlabels = []
403 | for i, word in enumerate(textlist):
404 | word = " "+ word
405 | token = tokenizer.tokenize(word)
406 | tokens.extend(token)
407 | label_1 = labellist[i]
408 | auxlabel_1 = auxlabellist[i]
409 | for m in range(len(token)):
410 | if m == 0:
411 | labels.append(label_1)
412 | auxlabels.append(auxlabel_1)
413 | else:
414 | labels.append("X")
415 | auxlabels.append("X")
416 | if len(tokens) >= max_seq_length - 1:
417 | tokens = tokens[0:(max_seq_length - 2)]
418 | labels = labels[0:(max_seq_length - 2)]
419 | auxlabels = auxlabels[0:(max_seq_length - 2)]
420 | ntokens = []
421 | segment_ids = []
422 | label_ids = []
423 | auxlabel_ids = []
424 | ntokens.append("")
425 | segment_ids.append(0)
426 | label_ids.append(label_map[""])
427 | auxlabel_ids.append(auxlabel_map[""])
428 | for i, token in enumerate(tokens):
429 | ntokens.append(token)
430 | segment_ids.append(0)
431 | label_ids.append(label_map[labels[i]])
432 | auxlabel_ids.append(auxlabel_map[auxlabels[i]])
433 | ntokens.append("")
434 | segment_ids.append(0)
435 | label_ids.append(label_map[""])
436 | auxlabel_ids.append(auxlabel_map[""])
437 | input_ids = tokenizer.convert_tokens_to_ids(ntokens)
438 | input_mask = [1] * len(input_ids)
439 |         added_input_mask = [1] * (len(input_ids) + 49)  # the extra 49 positions cover the 7x7 grid of regional image features
440 |
441 | while len(input_ids) < max_seq_length:
442 | input_ids.append(0)
443 | input_mask.append(0)
444 | added_input_mask.append(0)
445 | segment_ids.append(0)
446 | label_ids.append(0)
447 | auxlabel_ids.append(0)
448 |
449 | assert len(input_ids) == max_seq_length
450 | assert len(input_mask) == max_seq_length
451 | assert len(segment_ids) == max_seq_length
452 | assert len(label_ids) == max_seq_length
453 | assert len(auxlabel_ids) == max_seq_length
454 |
455 | image_name = example.img_id
456 | image_path = os.path.join(path_img, image_name)
457 |
458 | if not os.path.exists(image_path):
459 | print(image_path)
460 | try:
461 | image = image_process(image_path, transform)
462 | except:
463 | count += 1
464 | # print('image has problem!')
465 | image_path_fail = os.path.join(path_img, '17_06_4705.jpg')
466 | image = image_process(image_path_fail, transform)
467 |
468 | if ex_index < 2:
469 | logger.info("*** Example ***")
470 | logger.info("guid: %s" % (example.guid))
471 | logger.info("tokens: %s" % " ".join(
472 | [str(x) for x in tokens]))
473 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
474 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
475 | logger.info(
476 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
477 | logger.info("label: %s" % " ".join([str(x) for x in label_ids]))
478 | logger.info("auxlabel: %s" % " ".join([str(x) for x in auxlabel_ids]))
479 |
480 | features.append(
481 | MMInputFeatures(input_ids=input_ids, input_mask=input_mask, added_input_mask=added_input_mask,
482 | segment_ids=segment_ids, img_feat=image, label_id=label_ids, auxlabel_id=auxlabel_ids, imagelabel= imagelabel_value))
483 |
484 | print('the number of problematic samples: ' + str(count))
485 |
486 | return features
487 |
488 |
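The extra 49 positions in added_input_mask line up with the 7x7 feature grid that ResNet-152 produces for a 224x224 crop (overall stride 32). The arithmetic, plus the resulting mask length:

```python
# Why 49: a 224x224 crop through a stride-32 ResNet leaves a 7x7 spatial
# map, i.e. 49 regional feature vectors appended after the text positions.
crop_size = 224
total_stride = 32
grid = crop_size // total_stride   # 7
num_regions = grid * grid          # 49

max_seq_length = 128
num_real_tokens = 10
added_mask = [1] * (num_real_tokens + num_regions) \
             + [0] * (max_seq_length - num_real_tokens)
print(num_regions, len(added_mask))  # 49 177 (= max_seq_length + 49)
```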
489 | def macro_f1(y_true, y_pred):
490 | p_macro, r_macro, f_macro, support_macro \
491 | = precision_recall_fscore_support(y_true, y_pred, average='macro')
492 |     f_macro = 2 * p_macro * r_macro / (p_macro + r_macro)  # harmonic mean of the macro-averaged p and r
493 | return p_macro, r_macro, f_macro
494 |
495 | import datetime
496 | def main():
497 | start_time = datetime.datetime.now().strftime('%m-%d-%Y-%H-%M-%S_')
498 | parser = argparse.ArgumentParser()
499 |
500 |     parser.add_argument('--cuda_id', type=str, default='0',
501 |                         help='Choose which GPU(s) to run on')
502 | parser.add_argument("--bert_model", default="./model/roberta-base-cased", type=str,
503 | help="pre-trained model selected in the list: roberta-base-cased, "
504 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
505 | "bert-base-multilingual-cased, bert-base-chinese.")
506 |
507 | parser.add_argument("--task_name",
508 | default= "twitter2015", #twitter2017
509 | type=str,
510 | required=True,
511 | help="The name of the task to train.")
512 | parser.add_argument("--cache_dir",
513 | default="",
514 | type=str,
515 | help="Where do you want to store the pre-trained models downloaded from s3")
516 | parser.add_argument("--max_seq_length",
517 | default=128,
518 | type=int,
519 | help="The maximum total input sequence length after WordPiece tokenization. \n"
520 | "Sequences longer than this will be truncated, and sequences shorter \n"
521 | "than this will be padded.")
522 | parser.add_argument("--do_train",
523 | action='store_true',
524 | help="Whether to run training.")
525 | parser.add_argument("--do_eval",
526 | action='store_true',
527 | help="Whether to run eval on the dev set.")
528 | parser.add_argument("--do_lower_case",
529 | action='store_true',
530 | help="Set this flag if you are using an uncased model.")
531 | parser.add_argument("--train_batch_size",
532 | default=32,
533 | type=int,
534 | help="Total batch size for training.")
535 | parser.add_argument("--eval_batch_size",
536 | default=16,
537 | type=int,
538 | help="Total batch size for eval.")
539 | parser.add_argument("--learning_rate",
540 | default=3e-5,
541 | type=float,
542 | help="The initial learning rate for Adam.")
543 | parser.add_argument("--num_train_epochs",
544 | default=25.0,
545 | type=float,
546 | help="Total number of training epochs to perform.")
547 | parser.add_argument("--warmup_proportion",
548 | default=0.1,
549 | type=float,
550 | help="Proportion of training to perform linear learning rate warmup for. "
551 | "E.g., 0.1 = 10%% of training.")
552 | parser.add_argument("--no_cuda",
553 | action='store_true',
554 | help="Whether not to use CUDA when available")
555 | parser.add_argument("--local_rank",
556 | type=int,
557 | default=-1,
558 | help="local_rank for distributed training on gpus")
559 | parser.add_argument('--seed',
560 | type=int,
561 | default=64,
562 | help="random seed for initialization")
563 | parser.add_argument('--gradient_accumulation_steps',
564 | type=int,
565 | default=1,
566 | help="Number of updates steps to accumulate before performing a backward/update pass.")
567 | parser.add_argument('--fp16',
568 | action='store_true',
569 | help="Whether to use 16-bit float precision instead of 32-bit")
570 | parser.add_argument('--loss_scale',
571 | type=float, default=0,
572 | help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
573 | "0 (default value): dynamic loss scaling.\n"
574 | "Positive power of 2: static loss scaling value.\n")
575 |     parser.add_argument('--use_roberta', default=True, action='store_true')  # must be optional ('--'); a positional argument cannot take store_true
576 |
577 | parser.add_argument('--mm_model', default='MTCCMBert', help='model name')
578 | parser.add_argument('--layer_num1', type=int, default=1, help='number of txt2img layer')
579 | parser.add_argument('--layer_num2', type=int, default=1, help='number of img2txt layer')
580 | parser.add_argument('--layer_num3', type=int, default=1, help='number of txt2txt layer')
581 | parser.add_argument('--fine_tune_cnn', action='store_true', help='fine tune pre-trained CNN if True')
582 |     parser.add_argument('--resnet_root', default='./model/resnet', help='path to the pre-trained CNN models')
583 | parser.add_argument('--crop_size', type=int, default=224, help='crop size of image')
584 | parser.add_argument('--path_image', default='./pytorch-pretrained-BERT/twitter_subimages/', help='path to images')
585 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
586 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
587 | parser.add_argument('--alpha', type=float, default=1)
588 | parser.add_argument('--beta', type=float, default=1)
589 | parser.add_argument('--dropout_rate', type=float, default=0.2)
590 | args = parser.parse_args()
591 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_id
592 |
593 |
594 |     if args.task_name == "twitter2017": # twitter-2017 dataset; NOTE: the hard-coded paths below override --path_image
595 | args.path_image = "/mnt/nfs-storage-titan/alienware/intern18_snap/multi_modal_ABSA_pytorch_naacl/multi_modal_ABSA_pytorch_bilinear/twitter_subimages/"
596 | args.data_dir = "./data/twitter2017"
597 | args.image_filename = "./ANP_data/image_output2017.json"
598 | args.output_dir = start_time + "_twitter2017_output/"
599 | elif args.task_name == "twitter2015": # this refers to twitter-2015 dataset
600 | args.path_image = "/mnt/nfs-storage-titan/alienware/intern18_snap/multi_modal_ABSA_pytorch_naacl/multi_modal_ABSA_pytorch_bilinear/twitter15_images/"
601 | args.data_dir = "./data/twitter2015"
602 | args.image_filename = "./ANP_data/image_output2015.json"
603 | args.output_dir = start_time + "_twitter2015_output/"
604 |
605 | if args.server_ip and args.server_port:
606 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
607 | import ptvsd
608 | print("Waiting for debugger attach")
609 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
610 | ptvsd.wait_for_attach()
611 |
612 | processors = {
613 | "twitter2015": MNERProcessor,
614 | "twitter2017": MNERProcessor
615 | }
616 |
617 | if args.local_rank == -1 or args.no_cuda:
618 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
619 | n_gpu = torch.cuda.device_count()
620 | else:
621 | torch.cuda.set_device(args.local_rank)
622 | device = torch.device("cuda", args.local_rank)
623 | n_gpu = 1
624 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
625 | torch.distributed.init_process_group(backend='nccl')
626 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
627 | device, n_gpu, bool(args.local_rank != -1), args.fp16))
628 |
629 | if args.gradient_accumulation_steps < 1:
630 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
631 | args.gradient_accumulation_steps))
632 |
633 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
634 |
635 | random.seed(args.seed)
636 | np.random.seed(args.seed)
637 | torch.manual_seed(args.seed)
638 |
639 |     args.do_train = True  # NOTE: hard-coded overrides; the --do_train/--do_eval CLI flags above are ignored
640 |     args.do_eval = True
641 | if not args.do_train and not args.do_eval:
642 | raise ValueError("At least one of `do_train` or `do_eval` must be True.")
643 |
644 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
645 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
646 | if not os.path.exists(args.output_dir):
647 | os.makedirs(args.output_dir)
648 |
649 | task_name = args.task_name.lower()
650 |
651 | if task_name not in processors:
652 | raise ValueError("Task not found: %s" % (task_name))
653 |
654 | processor = processors[task_name]()
655 | label_list = processor.get_labels()
656 | auxlabel_list = processor.get_auxlabels()
657 | num_labels = len(label_list)+1 # label 0 corresponds to padding, label in label_list starts from 1
658 | auxnum_labels = len(auxlabel_list)+1 # label 0 corresponds to padding, label in label_list starts from 1
659 |
660 | start_label_id = processor.get_start_label_id()
661 | stop_label_id = processor.get_stop_label_id()
662 |
663 | trans_matrix = np.zeros((auxnum_labels,num_labels), dtype=float)
664 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
665 |
666 | train_examples = None
667 | num_train_optimization_steps = None
668 | if args.do_train:
669 | train_examples = processor.get_train_examples(args.data_dir, args.image_filename, args.path_image) # yl add
670 | num_train_optimization_steps = int(
671 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
672 | if args.local_rank != -1:
673 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
674 |
675 | # Prepare model
676 | cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
677 |
678 | if args.mm_model == 'MTCCMBert':
679 | model = MTCCMBertForMMTokenClassificationCRF.from_pretrained(args.bert_model,args.use_roberta,
680 | cache_dir=cache_dir, layer_num1=args.layer_num1, layer_num2=args.layer_num2, layer_num3=args.layer_num3,
681 | num_labels = num_labels, auxnum_labels = auxnum_labels,dropout_rate=args.dropout_rate)
682 | if args.use_roberta:
683 | roberta_dict = torch.load('./model/roberta-base-cased/pytorch_model.bin') #changed to large or base
684 | new_state_dict = model.state_dict()
685 | miss_keys = []
686 | for k in new_state_dict.keys():
687 | if k in roberta_dict.keys():
688 | new_state_dict[k] = roberta_dict[k]
689 | else:
690 | miss_keys.append(k)
691 | if len(miss_keys) > 0:
692 | logger.info('miss keys: {}'.format(miss_keys))
693 | model.load_state_dict(new_state_dict)
694 |
695 | else:
696 | print('please define your MNER Model')
697 |
698 | net = getattr(resnet, 'resnet152')()
699 | net.load_state_dict(torch.load(os.path.join(args.resnet_root, 'resnet152.pth')))
700 | encoder = myResnet(net, args.fine_tune_cnn, device)
701 | if args.fp16:
702 | model.half()
703 | encoder.half()
704 | model.to(device)
705 | encoder.to(device)
706 | if args.local_rank != -1:
707 | try:
708 | from apex.parallel import DistributedDataParallel as DDP
709 | except ImportError:
710 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
711 |
712 | model = DDP(model)
713 | encoder = DDP(encoder)
714 | elif n_gpu > 1:
715 | model = torch.nn.DataParallel(model)
716 | encoder = torch.nn.DataParallel(encoder)
717 |
718 | param_optimizer = list(model.named_parameters())
719 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
720 | optimizer_grouped_parameters = [
721 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
722 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
723 | ]
724 | if args.fp16:
725 | try:
726 | from apex.optimizers import FP16_Optimizer
727 | from apex.optimizers import FusedAdam
728 | except ImportError:
729 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
730 |
731 | optimizer = FusedAdam(optimizer_grouped_parameters,
732 | lr=args.learning_rate,
733 | bias_correction=False,
734 | max_grad_norm=1.0)
735 | if args.loss_scale == 0:
736 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
737 | else:
738 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
739 |
740 | else:
741 | optimizer = BertAdam(optimizer_grouped_parameters,
742 | lr=args.learning_rate,
743 | warmup=args.warmup_proportion,
744 | t_total=num_train_optimization_steps)
745 |
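In the fp16 branch the learning rate is adjusted by hand each step via warmup_linear, while BertAdam handles the schedule internally otherwise. A hedged sketch of such a linear warm-up-then-decay schedule (an illustration, not the library's exact implementation):

```python
# Sketch of a warmup_linear-style schedule: ramp the LR up over the first
# `warmup` fraction of training, then decay it linearly toward zero.
def warmup_linear_sketch(progress, warmup=0.1):
    if progress < warmup:
        return progress / warmup
    return max(0.0, (1.0 - progress) / (1.0 - warmup))

base_lr, total_steps = 3e-5, 1000
for step in (0, 50, 100, 500, 1000):
    scale = warmup_linear_sketch(step / total_steps)
    print(step, round(base_lr * scale, 8))
```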
746 | global_step = 0
747 | nb_tr_steps = 0
748 | tr_loss = 0
749 |
750 | output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
751 | output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
752 | output_encoder_file = os.path.join(args.output_dir, "pytorch_encoder.bin")
753 |
754 | if args.do_train:
755 | train_features = convert_mm_examples_to_features(
756 | train_examples, label_list, auxlabel_list, args.max_seq_length, tokenizer, args.crop_size, args.path_image)
757 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
758 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
759 | all_added_input_mask = torch.tensor([f.added_input_mask for f in train_features], dtype=torch.long)
760 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
761 | all_img_feats = torch.stack([f.img_feat for f in train_features])
762 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
763 | all_auxlabel_ids = torch.tensor([f.auxlabel_id for f in train_features], dtype=torch.long)
764 | all_imagelabel = torch.tensor([f.imagelabel for f in train_features], dtype=torch.float)
765 |
766 | train_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, all_img_feats,
767 | all_label_ids, all_auxlabel_ids,all_imagelabel)
768 |
769 | if args.local_rank == -1:
770 | train_sampler = RandomSampler(train_data)
771 | else:
772 | train_sampler = DistributedSampler(train_data)
773 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
774 |
775 | eval_examples = processor.get_dev_examples(args.data_dir, args.image_filename, args.path_image) # yl add
776 | eval_features = convert_mm_examples_to_features(
777 | eval_examples, label_list, auxlabel_list, args.max_seq_length, tokenizer, args.crop_size, args.path_image)
778 |
779 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
780 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
781 | all_added_input_mask = torch.tensor([f.added_input_mask for f in eval_features], dtype=torch.long)
782 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
783 | all_img_feats = torch.stack([f.img_feat for f in eval_features])
784 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
785 | all_auxlabel_ids = torch.tensor([f.auxlabel_id for f in eval_features], dtype=torch.long)
786 | all_imagelabel = torch.tensor([f.imagelabel for f in eval_features], dtype=torch.float)
787 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, all_img_feats,
788 | all_label_ids, all_auxlabel_ids,all_imagelabel)
789 |
790 | # Run prediction for full data
791 | eval_sampler = SequentialSampler(eval_data)
792 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
793 |
794 | test_eval_examples = processor.get_test_examples(args.data_dir, args.image_filename, args.path_image) # yl add
795 | test_eval_features = convert_mm_examples_to_features(
796 | test_eval_examples, label_list, auxlabel_list, args.max_seq_length, tokenizer, args.crop_size, args.path_image)
797 | all_input_ids = torch.tensor([f.input_ids for f in test_eval_features], dtype=torch.long)
798 | all_input_mask = torch.tensor([f.input_mask for f in test_eval_features], dtype=torch.long)
799 | all_added_input_mask = torch.tensor([f.added_input_mask for f in test_eval_features], dtype=torch.long)
800 | all_segment_ids = torch.tensor([f.segment_ids for f in test_eval_features], dtype=torch.long)
801 | all_img_feats = torch.stack([f.img_feat for f in test_eval_features])
802 | all_label_ids = torch.tensor([f.label_id for f in test_eval_features], dtype=torch.long)
803 | all_auxlabel_ids = torch.tensor([f.auxlabel_id for f in test_eval_features], dtype=torch.long)
804 | all_imagelabel = torch.tensor([f.imagelabel for f in test_eval_features], dtype=torch.float)
805 |
806 | test_eval_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, all_img_feats,
807 | all_label_ids, all_auxlabel_ids,all_imagelabel)
808 | # Run prediction for full data
809 | test_eval_sampler = SequentialSampler(test_eval_data)
810 | test_eval_dataloader = DataLoader(test_eval_data, sampler=test_eval_sampler, batch_size=args.eval_batch_size)
811 |
812 | max_dev_f1 = 0.0
813 | max_test_f1 = 0.0
814 | best_dev_epoch = 0
815 | best_test_epoch = 0
816 | logger.info("***** Running training *****")
817 | for train_idx in trange(int(args.num_train_epochs), desc="Epoch"):
818 | logger.info("********** Epoch: " + str(train_idx) + " **********")
819 | logger.info(" Num examples = %d", len(train_examples))
820 | logger.info(" Batch size = %d", args.train_batch_size)
821 | logger.info(" Num steps = %d", num_train_optimization_steps)
822 | model.train()
823 | encoder.train()
824 | encoder.zero_grad()
825 | tr_loss = 0
826 | nb_tr_examples, nb_tr_steps = 0, 0
827 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
828 | batch = tuple(t.to(device) for t in batch)
829 | input_ids, input_mask, added_input_mask, segment_ids, img_feats, label_ids, auxlabel_ids, imagelabel = batch
830 | with torch.no_grad():
831 | imgs_f, img_mean, img_att = encoder(img_feats)
832 |                 trans_matrix = torch.tensor(trans_matrix).to(device)  # numpy -> tensor on first use; later steps re-wrap the existing tensor
833 | neg_log_likelihood = model(input_ids, segment_ids, input_mask, added_input_mask,
834 | img_att, trans_matrix, imagelabel,args.alpha, args.beta,label_ids, auxlabel_ids)
835 | if n_gpu > 1:
836 | neg_log_likelihood = neg_log_likelihood.mean() # mean() to average on multi-gpu.
837 | if args.gradient_accumulation_steps > 1:
838 | neg_log_likelihood = neg_log_likelihood / args.gradient_accumulation_steps
839 |
840 | if args.fp16:
841 | optimizer.backward(neg_log_likelihood)
842 | else:
843 | neg_log_likelihood.backward()
844 |
845 | tr_loss += neg_log_likelihood.item()
846 | nb_tr_examples += input_ids.size(0)
847 | nb_tr_steps += 1
848 | if (step + 1) % args.gradient_accumulation_steps == 0:
849 | if args.fp16:
850 | # modify learning rate with special warm up BERT uses
851 | # if args.fp16 is False, BertAdam is used that handles this automatically
852 | lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
853 | for param_group in optimizer.param_groups:
854 | param_group['lr'] = lr_this_step
855 | optimizer.step()
856 | optimizer.zero_grad()
857 | global_step += 1
858 |
859 | logger.info("***** Running evaluation on Dev Set*****")
860 | logger.info(" Num examples = %d", len(eval_examples))
861 | logger.info(" Batch size = %d", args.eval_batch_size)
862 |
863 | model.eval()
864 | encoder.eval()
865 | eval_loss, eval_accuracy = 0, 0
866 | nb_eval_steps, nb_eval_examples = 0, 0
867 |
868 | y_true = []
869 | y_pred = []
870 | y_true_idx = []
871 | y_pred_idx = []
872 | label_map = {i: label for i, label in enumerate(label_list, 1)}
873 | label_map[0] = "PAD"
874 | for input_ids, input_mask, added_input_mask, segment_ids, img_feats, label_ids, auxlabel_ids,imagelabel in tqdm(
875 | eval_dataloader,
876 | desc="Evaluating"):
877 | input_ids = input_ids.to(device)
878 | input_mask = input_mask.to(device)
879 | added_input_mask = added_input_mask.to(device)
880 | segment_ids = segment_ids.to(device)
881 | img_feats = img_feats.to(device)
882 | label_ids = label_ids.to(device)
883 | auxlabel_ids = auxlabel_ids.to(device)
884 |
885 | with torch.no_grad():
886 | imgs_f, img_mean, img_att = encoder(img_feats)
887 | predicted_label_seq_ids = model(input_ids, segment_ids, input_mask, added_input_mask, img_att,
888 | trans_matrix, imagelabel,args.alpha, args.beta)
889 |
890 | logits = predicted_label_seq_ids
891 | label_ids = label_ids.to('cpu').numpy()
892 | input_mask = input_mask.to('cpu').numpy()
893 | for i, mask in enumerate(input_mask):
894 | temp_1 = []
895 | temp_2 = []
896 | tmp1_idx = []
897 | tmp2_idx = []
898 | for j, m in enumerate(mask):
899 | if j == 0:
900 | continue
901 | if m:
902 | if label_map[label_ids[i][j]] != "X" and label_map[
903 | label_ids[i][j]] != "":
904 | temp_1.append(label_map[label_ids[i][j]])
905 | tmp1_idx.append(label_ids[i][j])
906 | temp_2.append(label_map[logits[i][j]])
907 | tmp2_idx.append(logits[i][j])
908 | else:
909 | break
910 | y_true.append(temp_1)
911 | y_pred.append(temp_2)
912 | y_true_idx.append(tmp1_idx)
913 | y_pred_idx.append(tmp2_idx)
914 |
915 | # report = classification_report(y_true, y_pred, digits=4)
916 | sentence_list = []
917 | dev_data, imgs, _ ,_ = processor._read_mmtsv(os.path.join(args.data_dir, "valid.txt"),args.image_filename, args.path_image)
918 | for i in range(len(y_pred)):
919 | sentence = dev_data[i][0]
920 | sentence_list.append(sentence)
921 | reverse_label_map = {label: i for i, label in enumerate(label_list, 1)}
922 | acc, f1, p, r = evaluate(y_pred_idx, y_true_idx, sentence_list, reverse_label_map)
923 | logger.info("***** Dev Eval results *****")
924 | print("Overall: ", p, r, f1)
925 | per_f1, per_p, per_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'POS')
926 | print("Positive: ", per_p, per_r, per_f1)
927 | loc_f1, loc_p, loc_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEU')
928 | print("Neutral: ", loc_p, loc_r, loc_f1)
929 | org_f1, org_p, org_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEG')
930 | print("Negative: ", org_p, org_r, org_f1)
931 | F_score_dev = f1
932 |
933 | logger.info("***** Running Test evaluation *****")
934 | logger.info(" Num examples = %d", len(test_eval_examples))
935 | logger.info(" Batch size = %d", args.eval_batch_size)
936 | y_true = []
937 | y_pred = []
938 | y_true_idx = []
939 | y_pred_idx = []
940 | label_map = {i: label for i, label in enumerate(label_list, 1)}
941 | label_map[0] = "PAD"
942 | for input_ids, input_mask, added_input_mask, segment_ids, img_feats, label_ids, auxlabel_ids,imagelabel in tqdm(test_eval_dataloader,
943 | desc="Evaluating"):
944 | input_ids = input_ids.to(device)
945 | input_mask = input_mask.to(device)
946 | added_input_mask = added_input_mask.to(device)
947 | segment_ids = segment_ids.to(device)
948 | img_feats = img_feats.to(device)
949 | label_ids = label_ids.to(device)
950 | auxlabel_ids = auxlabel_ids.to(device)
951 |
952 | with torch.no_grad():
953 | imgs_f, img_mean, img_att = encoder(img_feats)
954 | predicted_label_seq_ids = model(input_ids, segment_ids, input_mask, added_input_mask, img_att,trans_matrix,imagelabel,args.alpha, args.beta)
955 |
956 | logits = predicted_label_seq_ids
957 | label_ids = label_ids.to('cpu').numpy()
958 | input_mask = input_mask.to('cpu').numpy()
959 | for i, mask in enumerate(input_mask):
960 | temp_1 = []
961 | temp_2 = []
962 | tmp1_idx = []
963 | tmp2_idx = []
964 | for j, m in enumerate(mask):
965 | if j == 0:
966 | continue
967 | if m:
968 | if label_map[label_ids[i][j]] != "X" and label_map[
969 | label_ids[i][j]] != "":
970 | temp_1.append(label_map[label_ids[i][j]])
971 | tmp1_idx.append(label_ids[i][j])
972 | temp_2.append(label_map[logits[i][j]])
973 | tmp2_idx.append(logits[i][j])
974 | else:
975 | break
976 | y_true.append(temp_1)
977 | y_pred.append(temp_2)
978 | y_true_idx.append(tmp1_idx)
979 | y_pred_idx.append(tmp2_idx)
980 |
981 | #report = classification_report(y_true, y_pred, digits=4)
982 | sentence_list = []
983 | test_data, imgs, _,_ = processor._read_mmtsv(os.path.join(args.data_dir, "test.txt"),args.image_filename, args.path_image)
984 | for i in range(len(y_pred)):
985 | sentence = test_data[i][0]
986 | sentence_list.append(sentence)
987 |
988 | reverse_label_map = {label: i for i, label in enumerate(label_list, 1)}
989 | acc, f1, p, r = evaluate(y_pred_idx, y_true_idx, sentence_list, reverse_label_map)
990 | logger.info("***** Test Eval results *****")
991 | print("Overall: ", p, r, f1)
992 | per_f1, per_p, per_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'POS')
993 | print("Positive: ", per_p, per_r, per_f1)
994 | loc_f1, loc_p, loc_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEU')
995 | print("Neutral: ", loc_p, loc_r, loc_f1)
996 | org_f1, org_p, org_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEG')
997 | print("Negative: ", org_p, org_r, org_f1)
998 | F_score_test = f1
999 |
1000 | if F_score_dev > max_dev_f1:
1001 | # Save a trained model and the associated configuration
1002 |                 model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
1003 |                 encoder_to_save = encoder.module if hasattr(encoder,
1004 |                                                              'module') else encoder  # Only save the encoder itself
1005 | torch.save(model_to_save.state_dict(), output_model_file)
1006 | torch.save(encoder_to_save.state_dict(), output_encoder_file)
1007 | with open(output_config_file, 'w') as f:
1008 | f.write(model_to_save.config.to_json_string())
1009 | label_map = {i: label for i, label in enumerate(label_list, 1)}
1010 | model_config = {"bert_model": args.bert_model, "do_lower": args.do_lower_case,
1011 | "max_seq_length": args.max_seq_length, "num_labels": len(label_list) + 1,
1012 | "label_map": label_map}
1013 | json.dump(model_config, open(os.path.join(args.output_dir, "model_config.json"), "w"))
1014 | max_dev_f1 = F_score_dev
1015 | best_dev_epoch = train_idx
1016 | if F_score_test > max_test_f1:
1017 | max_test_f1 = F_score_test
1018 | best_test_epoch = train_idx
1019 |
1020 | print("**************************************************")
1021 | print("The best epoch on the dev set: ", best_dev_epoch)
1022 | print("The best Micro-F1 score on the dev set: ", max_dev_f1)
1023 | print("The best epoch on the test set: ", best_test_epoch)
1024 | print("The best Micro-F1 score on the test set: ", max_test_f1)
1025 | print('\n')
1026 |
1027 | config = BertConfig(output_config_file)
1028 | if args.mm_model == 'MTCCMBert':
1029 | model = MTCCMBertForMMTokenClassificationCRF(config,args.use_roberta, layer_num1=args.layer_num1, layer_num2=args.layer_num2,
1030 | layer_num3=args.layer_num3, num_labels=num_labels, auxnum_labels = auxnum_labels)
1031 |
1032 | else:
1033 | print('please define your MNER Model')
1034 |
1035 | model.load_state_dict(torch.load(output_model_file))
1036 | model.to(device)
1037 | encoder_state_dict = torch.load(output_encoder_file)
1038 | encoder.load_state_dict(encoder_state_dict)
1039 | encoder.to(device)
1040 |
1041 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
1042 | eval_examples = processor.get_test_examples(args.data_dir, args.image_filename, args.path_image)
1043 | eval_features = convert_mm_examples_to_features(
1044 | eval_examples, label_list, auxlabel_list, args.max_seq_length, tokenizer, args.crop_size, args.path_image)
1045 | logger.info("***** Running Test Evaluation with the Best Model on the Dev Set*****")
1046 | logger.info(" Num examples = %d", len(eval_examples))
1047 | logger.info(" Batch size = %d", args.eval_batch_size)
1048 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
1049 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
1050 | all_added_input_mask = torch.tensor([f.added_input_mask for f in eval_features], dtype=torch.long)
1051 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
1052 | all_img_feats = torch.stack([f.img_feat for f in eval_features])
1053 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
1054 | all_auxlabel_ids = torch.tensor([f.auxlabel_id for f in eval_features], dtype=torch.long)
1055 | all_imagelabel = torch.tensor([f.imagelabel for f in eval_features], dtype=torch.float)
1056 |
1057 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, all_img_feats,
1058 | all_label_ids, all_auxlabel_ids,all_imagelabel)
1059 | # Run prediction for full data
1060 | eval_sampler = SequentialSampler(eval_data)
1061 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
1062 | model.eval()
1063 | encoder.eval()
1064 | eval_loss, eval_accuracy = 0, 0
1065 | nb_eval_steps, nb_eval_examples = 0, 0
1066 | y_true = []
1067 | y_pred = []
1068 | y_true_idx = []
1069 | y_pred_idx = []
1070 | label_map = {i : label for i, label in enumerate(label_list,1)}
1071 | label_map[0] = "PAD"
1072 | for input_ids, input_mask, added_input_mask, segment_ids, img_feats, label_ids, auxlabel_ids,imagelabel in tqdm(eval_dataloader, desc="Evaluating"):
1073 | input_ids = input_ids.to(device)
1074 | input_mask = input_mask.to(device)
1075 | added_input_mask = added_input_mask.to(device)
1076 | segment_ids = segment_ids.to(device)
1077 | img_feats = img_feats.to(device)
1078 | label_ids = label_ids.to(device)
1079 | auxlabel_ids = auxlabel_ids.to(device)
1080 | trans_matrix = torch.tensor(trans_matrix).to(device)
1081 |
1082 | with torch.no_grad():
1083 | imgs_f, img_mean, img_att = encoder(img_feats)
1084 | predicted_label_seq_ids = model(input_ids, segment_ids, input_mask, added_input_mask, img_att, trans_matrix,imagelabel,args.alpha, args.beta)
1085 |
1086 | logits = predicted_label_seq_ids
1087 | label_ids = label_ids.to('cpu').numpy()
1088 | input_mask = input_mask.to('cpu').numpy()
1089 | for i,mask in enumerate(input_mask):
1090 | temp_1 = []
1091 | temp_2 = []
1092 | tmp1_idx = []
1093 | tmp2_idx = []
1094 | for j, m in enumerate(mask):
1095 | if j == 0:
1096 | continue
1097 | if m:
1098 | if label_map[label_ids[i][j]] != "X" and label_map[label_ids[i][j]] != "":
1099 | temp_1.append(label_map[label_ids[i][j]])
1100 | tmp1_idx.append(label_ids[i][j])
1101 | temp_2.append(label_map[logits[i][j]])
1102 | tmp2_idx.append(logits[i][j])
1103 | else:
1104 | break
1105 | y_true.append(temp_1)
1106 | y_pred.append(temp_2)
1107 | y_true_idx.append(tmp1_idx)
1108 | y_pred_idx.append(tmp2_idx)
1109 |
1110 | sentence_list = []
1111 | test_data, imgs, _,_ = processor._read_mmtsv(os.path.join(args.data_dir, "test.txt"),args.image_filename, args.path_image)
1112 | output_pred_file = os.path.join(args.output_dir, "mtmner_pred.txt")
1113 | fout = open(output_pred_file, 'w', encoding='UTF-8')
1114 | for i in range(len(y_pred)):
1115 | sentence = test_data[i][0]
1116 | sentence_list.append(sentence)
1117 | img = imgs[i]
1118 | samp_pred_label = y_pred[i]
1119 | samp_true_label = y_true[i]
1120 | fout.write(img+'\n')
1121 | fout.write(' '.join(sentence)+'\n')
1122 | fout.write(' '.join(samp_pred_label)+'\n')
1123 | fout.write(' '.join(samp_true_label)+'\n'+'\n')
1124 | fout.close()
1125 | logger.info("***** Test Eval results *****")
1126 |
1127 | reverse_label_map = {label: i for i, label in enumerate(label_list, 1)}
1128 | acc, f1, p, r = evaluate(y_pred_idx, y_true_idx, sentence_list, reverse_label_map)
1129 | print("Overall: ", p, r, f1)
1130 | per_f1, per_p, per_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'POS')
1131 | print("Positive: ", per_p, per_r, per_f1)
1132 | loc_f1, loc_p, loc_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEU')
1133 | print("Neutral: ", loc_p, loc_r, loc_f1)
1134 | org_f1, org_p, org_r = evaluate_each_class(y_pred_idx, y_true_idx, sentence_list, reverse_label_map, 'NEG')
1135 | print("Negative: ", org_p, org_r, org_f1)
1136 |
1137 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
1138 | with open(output_eval_file, "w") as writer:
1139 | #logger.info("\n%s", report)
1140 | #writer.write(report)
1141 | writer.write("Overall: " + str(p) + ' ' + str(r) + ' ' + str(f1) + '\n')
1142 | writer.write("Positive: " + str(per_p) + ' ' + str(per_r) + ' ' + str(per_f1) + '\n')
1143 | writer.write("Neutral: " + str(loc_p) + ' ' + str(loc_r) + ' ' + str(loc_f1) + '\n')
1144 | writer.write("Negative: " + str(org_p) + ' ' + str(org_r) + ' ' + str(org_f1) + '\n')
1145 |
1146 |
1147 | if __name__ == "__main__":
1148 | main()
1149 |
--------------------------------------------------------------------------------
/run_cmmt_crf.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | for i in 'twitter2015' 'twitter2017'
3 | do
4 | echo 'run_cmmt_crf.py'
5 | echo ${i}
6 | PYTHONIOENCODING=utf-8 CUDA_VISIBLE_DEVICES=0 python run_cmmt_crf.py --task_name=${i}
7 | done
8 |
--------------------------------------------------------------------------------