├── __pycache__
│   ├── bert.cpython-35.pyc
│   ├── DataGen.cpython-35.pyc
│   └── custom_model.cpython-35.pyc
├── layers
│   ├── __pycache__
│   │   ├── crf.cpython-35.pyc
│   │   ├── decoder.cpython-35.pyc
│   │   ├── embedding.cpython-35.pyc
│   │   └── encoder.cpython-35.pyc
│   ├── encoder.py
│   ├── embedding.py
│   ├── decoder.py
│   └── crf.py
├── pytorch_pretrained_bert
│   ├── __init__.pyc
│   ├── file_utils.pyc
│   ├── modeling.pyc
│   ├── optimization.pyc
│   ├── tokenization.pyc
│   ├── modeling_gpt2.pyc
│   ├── modeling_openai.pyc
│   ├── tokenization_gpt2.pyc
│   ├── modeling_transfo_xl.pyc
│   ├── optimization_openai.pyc
│   ├── tokenization_openai.pyc
│   ├── tokenization_transfo_xl.pyc
│   ├── modeling_transfo_xl_utilities.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── modeling.cpython-35.pyc
│   │   ├── file_utils.cpython-35.pyc
│   │   ├── optimization.cpython-35.pyc
│   │   ├── tokenization.cpython-35.pyc
│   │   ├── modeling_gpt2.cpython-35.pyc
│   │   ├── modeling_openai.cpython-35.pyc
│   │   ├── tokenization_gpt2.cpython-35.pyc
│   │   ├── modeling_transfo_xl.cpython-35.pyc
│   │   ├── optimization_openai.cpython-35.pyc
│   │   ├── tokenization_openai.cpython-35.pyc
│   │   ├── tokenization_transfo_xl.cpython-35.pyc
│   │   └── modeling_transfo_xl_utilities.cpython-35.pyc
│   ├── __init__.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── convert_gpt2_checkpoint_to_pytorch.py
│   ├── convert_openai_checkpoint_to_pytorch.py
│   ├── __main__.py
│   ├── convert_transfo_xl_checkpoint_to_pytorch.py
│   ├── optimization_openai.py
│   ├── optimization.py
│   ├── file_utils.py
│   ├── tokenization_gpt2.py
│   ├── tokenization_openai.py
│   ├── tokenization.py
│   ├── modeling_transfo_xl_utilities.py
│   └── tokenization_transfo_xl.py
├── data
│   └── train_analysis.txt
├── README.md
├── multi_cased_L-12_H-768_A-12
│   └── bert_config.json
├── debug.py
├── predict.py
├── train.py
├── model.py
└── DataGen.py
/__pycache__/bert.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/__pycache__/bert.cpython-35.pyc
--------------------------------------------------------------------------------
/__pycache__/DataGen.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/__pycache__/DataGen.cpython-35.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/crf.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/layers/__pycache__/crf.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__init__.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/file_utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/file_utils.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/modeling.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/modeling.pyc
--------------------------------------------------------------------------------
/__pycache__/custom_model.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/__pycache__/custom_model.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/optimization.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/optimization.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/tokenization.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/tokenization.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/decoder.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/layers/__pycache__/decoder.cpython-35.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/embedding.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/layers/__pycache__/embedding.cpython-35.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/encoder.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/layers/__pycache__/encoder.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/modeling_gpt2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/modeling_gpt2.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/modeling_openai.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/modeling_openai.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/tokenization_gpt2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/tokenization_gpt2.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/modeling_transfo_xl.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/modeling_transfo_xl.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/optimization_openai.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/optimization_openai.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/tokenization_openai.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/tokenization_openai.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/tokenization_transfo_xl.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/tokenization_transfo_xl.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/modeling_transfo_xl_utilities.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/modeling_transfo_xl_utilities.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/modeling.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/modeling.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/file_utils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/file_utils.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/optimization.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/optimization.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/tokenization.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/tokenization.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/modeling_gpt2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/modeling_gpt2.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/modeling_openai.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/modeling_openai.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/tokenization_gpt2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/tokenization_gpt2.cpython-35.pyc
--------------------------------------------------------------------------------
/data/train_analysis.txt:
--------------------------------------------------------------------------------
1 | #doc1
2 | 0 Ai B-VAR ['N'] [0]
3 | 1 là B-RLT ['PO','SP'] [0,2]
4 | 2 Langston_Hughes B-ETT ['N'] [2]
5 | #doc2
6 | 0 Ai B-VAR ['N'] [0]
7 | 1 là B-RLT ['PO','SP'] [0,2]
8 | 2 Damocles B-ETT ['N'] [2]
9 |
--------------------------------------------------------------------------------
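
Note: each non-empty line in train_analysis.txt carries five whitespace-separated fields: token index, token, BIO entity tag (B-VAR/B-RLT/B-ETT in the sample), a Python-style list of relation labels, and a matching list of head-token indices; `#docN` lines delimit documents. A minimal reader sketch (a hypothetical helper, not part of the repo; field semantics inferred from the sample above):

    import ast

    def read_analysis_file(path):
        # Yield one document at a time as a list of
        # (index, token, bio_tag, rel_labels, head_ids) tuples.
        doc = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line.startswith("#doc"):
                    if doc:
                        yield doc
                    doc = []
                elif line:
                    idx, token, tag, rels, heads = line.split()
                    doc.append((int(idx), token, tag,
                                ast.literal_eval(rels), ast.literal_eval(heads)))
        if doc:
            yield doc
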
/pytorch_pretrained_bert/__pycache__/modeling_transfo_xl.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/modeling_transfo_xl.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/optimization_openai.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/optimization_openai.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/tokenization_openai.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/tokenization_openai.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/tokenization_transfo_xl.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/tokenization_transfo_xl.cpython-35.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/modeling_transfo_xl_utilities.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoanglocla9/bert-jointly-relation-entity-extraction/HEAD/pytorch_pretrained_bert/__pycache__/modeling_transfo_xl_utilities.cpython-35.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # bert-jointly-relation-entity-extraction
2 |
3 | ## References
4 |
5 | - [1] BERT PyTorch: https://github.com/huggingface/pytorch-pretrained-BERT
6 | - [2] CRF/NER PyTorch: https://github.com/sberbank-ai/ner-bert
7 | - [3] Multi-head joint entity and relation extraction ([paper](https://arxiv.org/abs/1804.07847)): https://github.com/bekou/multihead_joint_entity_relation_extraction
8 |
--------------------------------------------------------------------------------
/multi_cased_L-12_H-768_A-12/bert_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "directionality": "bidi",
4 | "hidden_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "hidden_size": 768,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 12,
11 | "num_hidden_layers": 12,
12 | "pooler_fc_size": 768,
13 | "pooler_num_attention_heads": 12,
14 | "pooler_num_fc_layers": 3,
15 | "pooler_size_per_head": 128,
16 | "pooler_type": "first_token_transform",
17 | "type_vocab_size": 2,
18 | "vocab_size": 119547
19 | }
20 |
--------------------------------------------------------------------------------
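
Note: this is the stock multilingual cased BERT-Base configuration (12 layers, 12 heads, hidden size 768, 119,547-token vocabulary). It is consumed through BertConfig.from_json_file, the same entry point convert_tf_checkpoint_to_pytorch.py uses; a quick sanity check:

    from pytorch_pretrained_bert.modeling import BertConfig

    config = BertConfig.from_json_file("multi_cased_L-12_H-768_A-12/bert_config.json")
    assert config.hidden_size == 768 and config.num_hidden_layers == 12
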
/debug.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from pytorch_pretrained_bert.modeling import *
5 | from DataGen import DataGenerator
6 | from model import BertBiLSTMCRF
7 |
8 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
9 |
10 | if __name__ == "__main__":
11 |     BERT_PRETRAINED_PATH = "./multi_cased_L-12_H-768_A-12/"
12 |     TRAIN_PATH = "/data/loclh2/QABot/data/train_analysis.txt"
13 |     batch_size = 32
14 |     shuffle = False
15 |
16 |     data_gen = DataGenerator(model=BertModel, model_name=BERT_PRETRAINED_PATH)
17 |
18 |     train_gen = data_gen.get_generator(TRAIN_PATH, batch_size, shuffle=shuffle)
19 |     # Iterate once over the data to exercise the generator.
20 |     for batch in train_gen:
21 |         pass
22 |
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.6.1"
2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
3 | from .tokenization_openai import OpenAIGPTTokenizer
4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
5 | from .tokenization_gpt2 import GPT2Tokenizer
6 |
7 | from .modeling import (BertConfig, BertModel, BertForPreTraining,
8 | BertForMaskedLM, BertForNextSentencePrediction,
9 | BertForSequenceClassification, BertForMultipleChoice,
10 | BertForTokenClassification, BertForQuestionAnswering,
11 | load_tf_weights_in_bert)
12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
14 | load_tf_weights_in_openai_gpt)
15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
16 | load_tf_weights_in_transfo_xl)
17 | from .modeling_gpt2 import (GPT2Config, GPT2Model,
18 | GPT2LMHeadModel, GPT2DoubleHeadsModel,
19 | load_tf_weights_in_gpt2)
20 |
21 | from .optimization import BertAdam
22 | from .optimization_openai import OpenAIAdam
23 |
24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
25 |
--------------------------------------------------------------------------------
/layers/encoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 |
4 | class BertBiLSTMEncoder(nn.Module):
5 | def __init__(self, embeddings,
6 | hidden_dim=128, rnn_layers=1, use_cuda=True):
7 | super(BertBiLSTMEncoder, self).__init__()
8 | self.embeddings = embeddings
9 | self.hidden_dim = hidden_dim
10 | self.rnn_layers = rnn_layers
11 | self.use_cuda = use_cuda
12 | self.lstm = nn.LSTM(
13 | self.embeddings.embedding_dim, hidden_dim // 2,
14 | rnn_layers, batch_first=True, bidirectional=True)
15 | self.hidden = None
16 | if use_cuda:
17 | self.cuda()
18 | self.init_weights()
19 | self.output_dim = hidden_dim
20 |
21 | def init_weights(self):
22 | #for p in self.lstm.parameters():
23 | # nn.init.xavier_normal(p)
24 | pass
25 |
26 |     def sort_lengths(self, inputs, input_lens):
27 |         # Sort the batch by decreasing length, as pack_padded_sequence requires.
28 |         # Sorting by position avoids list.index(), which returns the same row
29 |         # twice when two sequences share a length.
30 |         sorted_ids = sorted(range(len(input_lens)), key=lambda i: input_lens[i], reverse=True)
31 |         sorted_input_lens = [input_lens[i] for i in sorted_ids]
32 |         return inputs[sorted_ids], sorted_input_lens, sorted_ids
33 |
34 |     def forward(self, batch):
35 |         input, input_mask = batch[0], batch[1]
36 |         output = self.embeddings(*batch)
37 |         # output = self.dropout(output)
38 |         lens = input_mask.sum(-1).tolist()
39 |         output, lens, sorted_ids = self.sort_lengths(output, lens)
40 |
41 |         output = nn.utils.rnn.pack_padded_sequence(output, lens, batch_first=True)
42 |
43 |         device = "cuda" if self.use_cuda else "cpu"
44 |         output, self.hidden = self.lstm(output.to(device))
45 |
46 |         output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
47 |
48 |         # Undo the length sort so outputs line up with the original batch order.
49 |         output = output[torch.tensor(sorted_ids).argsort()]
50 |
51 |         return output, self.hidden
52 |
53 |     @classmethod
54 |     def create(cls, embeddings, hidden_dim=128, rnn_layers=1, use_cuda=True):
55 |         model = cls(
56 |             embeddings=embeddings, hidden_dim=hidden_dim, rnn_layers=rnn_layers, use_cuda=use_cuda)
57 |         return model
58 |
--------------------------------------------------------------------------------
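
Note: the forward pass above is the standard sort -> pack -> LSTM -> unpack -> unsort recipe for variable-length batches. A self-contained toy version of the same pattern (illustrative only, not repo code):

    import torch
    from torch import nn

    lstm = nn.LSTM(8, 4, batch_first=True, bidirectional=True)
    x = torch.randn(3, 5, 8)   # 3 padded sequences of max length 5
    lens = [5, 2, 4]           # true lengths

    order = sorted(range(3), key=lambda i: lens[i], reverse=True)
    packed = nn.utils.rnn.pack_padded_sequence(
        x[order], [lens[i] for i in order], batch_first=True)
    out, _ = lstm(packed)
    out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
    out = out[torch.tensor(order).argsort()]   # restore original batch order
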
/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import re
23 | import argparse
24 | import tensorflow as tf
25 | import torch
26 | import numpy as np
27 |
28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
29 |
30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
31 | # Initialise PyTorch model
32 | config = BertConfig.from_json_file(bert_config_file)
33 | print("Building PyTorch model from configuration: {}".format(str(config)))
34 | model = BertForPreTraining(config)
35 |
36 | # Load weights from tf checkpoint
37 | load_tf_weights_in_bert(model, tf_checkpoint_path)
38 |
39 | # Save pytorch-model
40 | print("Save PyTorch model to {}".format(pytorch_dump_path))
41 | torch.save(model.state_dict(), pytorch_dump_path)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | ## Required parameters
47 | parser.add_argument("--tf_checkpoint_path",
48 | default = None,
49 | type = str,
50 | required = True,
51 |                         help = "Path to the TensorFlow checkpoint.")
52 | parser.add_argument("--bert_config_file",
53 | default = None,
54 | type = str,
55 | required = True,
56 | help = "The config json file corresponding to the pre-trained BERT model. \n"
57 | "This specifies the model architecture.")
58 | parser.add_argument("--pytorch_dump_path",
59 | default = None,
60 | type = str,
61 | required = True,
62 | help = "Path to the output PyTorch model.")
63 | args = parser.parse_args()
64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
65 | args.bert_config_file,
66 | args.pytorch_dump_path)
67 |
--------------------------------------------------------------------------------
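
Note: invoked programmatically, the converter reduces to a single call. The file names below follow Google's BERT release layout and are illustrative:

    from pytorch_pretrained_bert.convert_tf_checkpoint_to_pytorch import (
        convert_tf_checkpoint_to_pytorch)

    convert_tf_checkpoint_to_pytorch(
        "multi_cased_L-12_H-768_A-12/bert_model.ckpt",    # TF checkpoint prefix
        "multi_cased_L-12_H-768_A-12/bert_config.json",   # architecture config
        "multi_cased_L-12_H-768_A-12/pytorch_model.bin")  # output weights
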
/layers/embedding.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from pytorch_pretrained_bert.modeling import BertModel, BertConfig
4 |
5 | class BertEmbedder(nn.Module):
6 | def __init__(self, model, bert_pretrained_path,
7 | freeze=True, embedding_dim=768, use_cuda=True, bert_mode="weighted",):
8 | super(BertEmbedder, self).__init__()
9 | self.bert_pretrained_path = bert_pretrained_path
10 | self.is_freeze = freeze
11 | self.embedding_dim = embedding_dim
12 | self.model = model
13 | self.use_cuda = use_cuda
14 | self.bert_mode = bert_mode
15 | if self.bert_mode == "weighted":
16 | self.bert_weights = nn.Parameter(torch.FloatTensor(12, 1))
17 | self.bert_gamma = nn.Parameter(torch.FloatTensor(1, 1))
18 |
19 | if use_cuda:
20 | self.cuda()
21 |
22 | self.init_weights()
23 |
24 | def init_weights(self):
25 | if self.bert_mode == "weighted":
26 |             nn.init.xavier_normal_(self.bert_gamma)
27 |             nn.init.xavier_normal_(self.bert_weights)
28 |
29 | def forward(self, *batch):
30 | input_ids, input_mask, input_type_ids = batch[:3]
31 | all_encoder_layers, _ = self.model(input_ids.long(), token_type_ids=input_type_ids.long(), attention_mask=input_mask.long())
32 | if self.bert_mode == "last":
33 | return all_encoder_layers[-1]
34 | elif self.bert_mode == "weighted":
35 | all_encoder_layers = torch.stack([a * b for a, b in zip(all_encoder_layers, self.bert_weights)])
36 | return self.bert_gamma * torch.sum(all_encoder_layers, dim=0)
37 |
38 | def freeze(self):
39 | for param in self.model.parameters():
40 | param.requires_grad = False
41 |
42 | def unfreeze(self):
43 | for param in self.model.parameters():
44 | param.requires_grad = True
45 |
46 | def freeze_to(self, to=-1):
47 | idx = 0
48 | if to < 0:
49 | to = len(self.model.encoder.layer) + to + 1
50 | for idx in range(to):
51 | for param in self.model.encoder.layer[idx].parameters():
52 | param.requires_grad = False
53 |         print("Embeddings frozen to {}".format(to))
54 | to = len(self.model.encoder.layer)
55 | for idx in range(idx, to):
56 | for param in self.model.encoder.layer[idx].parameters():
57 | param.requires_grad = True
58 |
59 | @classmethod
60 | def create(cls,
61 | bert_pretrained_path, embedding_dim=768, use_cuda=True, bert_mode="weighted",
62 | freeze=True):
63 | model = BertModel.from_pretrained(bert_pretrained_path)
64 | #if use_cuda:
65 | # device = torch.device("cuda")
66 | # map_location = "cuda"
67 | #else:
68 | # map_location = "cpu"
69 | # device = torch.device("cpu")
70 | #model = model.to(device)
71 | model = cls(model=model, embedding_dim=embedding_dim, use_cuda=use_cuda, bert_mode=bert_mode,
72 | bert_pretrained_path=bert_pretrained_path, freeze=freeze)
73 | if freeze:
74 | model.freeze()
75 | return model
76 |
--------------------------------------------------------------------------------
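
Note: in "weighted" mode the embedder learns an ELMo-style scalar mix over the 12 BERT layers (a weight per layer plus a global gamma), while "last" simply returns the top layer. Typical construction goes through the factory method; the arguments below just restate its defaults:

    from layers.embedding import BertEmbedder

    embedder = BertEmbedder.create(
        "./multi_cased_L-12_H-768_A-12/",  # directory with weights + config
        embedding_dim=768,
        bert_mode="weighted",
        freeze=True,        # keep BERT fixed while training the head
        use_cuda=True)
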
/pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | from io import open
21 |
22 | import torch
23 |
24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
25 | GPT2Config,
26 | GPT2Model,
27 | load_tf_weights_in_gpt2)
28 |
29 |
30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
31 | # Construct model
32 | if gpt2_config_file == "":
33 | config = GPT2Config()
34 | else:
35 | config = GPT2Config(gpt2_config_file)
36 | model = GPT2Model(config)
37 |
38 | # Load weights from numpy
39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
40 |
41 | # Save pytorch-model
42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
45 | torch.save(model.state_dict(), pytorch_weights_dump_path)
46 | print("Save configuration file to {}".format(pytorch_config_dump_path))
47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
48 | f.write(config.to_json_string())
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | ## Required parameters
54 | parser.add_argument("--gpt2_checkpoint_path",
55 | default = None,
56 | type = str,
57 | required = True,
58 |                         help = "Path to the TensorFlow checkpoint.")
59 | parser.add_argument("--pytorch_dump_folder_path",
60 | default = None,
61 | type = str,
62 | required = True,
63 | help = "Path to the output PyTorch model.")
64 | parser.add_argument("--gpt2_config_file",
65 | default = "",
66 | type = str,
67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
68 | "This specifies the model architecture.")
69 | args = parser.parse_args()
70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
71 | args.gpt2_config_file,
72 | args.pytorch_dump_folder_path)
73 |
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | from io import open
21 |
22 | import torch
23 |
24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
25 | OpenAIGPTConfig,
26 | OpenAIGPTModel,
27 | load_tf_weights_in_openai_gpt)
28 |
29 |
30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
31 | # Construct model
32 | if openai_config_file == "":
33 | config = OpenAIGPTConfig()
34 | else:
35 | config = OpenAIGPTConfig(openai_config_file)
36 | model = OpenAIGPTModel(config)
37 |
38 | # Load weights from numpy
39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)
40 |
41 | # Save pytorch-model
42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
45 | torch.save(model.state_dict(), pytorch_weights_dump_path)
46 | print("Save configuration file to {}".format(pytorch_config_dump_path))
47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
48 | f.write(config.to_json_string())
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | ## Required parameters
54 | parser.add_argument("--openai_checkpoint_folder_path",
55 | default = None,
56 | type = str,
57 | required = True,
58 |                         help = "Path to the TensorFlow checkpoint folder.")
59 | parser.add_argument("--pytorch_dump_folder_path",
60 | default = None,
61 | type = str,
62 | required = True,
63 | help = "Path to the output PyTorch model.")
64 | parser.add_argument("--openai_config_file",
65 | default = "",
66 | type = str,
67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
68 | "This specifies the model architecture.")
69 | args = parser.parse_args()
70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
71 | args.openai_config_file,
72 | args.pytorch_dump_folder_path)
73 |
--------------------------------------------------------------------------------
/layers/decoder.py:
--------------------------------------------------------------------------------
1 | from .crf import CRF
2 | import torch
3 |
4 | from torch.autograd import Variable
5 | from torch import nn
6 | from torch.nn import init
7 |
8 | class Linear(nn.Linear):
9 | def __init__(self,
10 | in_features: int,
11 | out_features: int,
12 | bias: bool = True):
13 | super(Linear, self).__init__(in_features, out_features, bias=bias)
14 | init.orthogonal_(self.weight)
15 |
16 | class Linears(nn.Module):
17 | def __init__(self,
18 | in_features,
19 | out_features,
20 | hiddens,
21 | bias=True,
22 | activation='tanh'):
23 | super(Linears, self).__init__()
24 | assert len(hiddens) > 0
25 |
26 | self.in_features = in_features
27 | self.out_features = self.output_size = out_features
28 |
29 | in_dims = [in_features] + hiddens[:-1]
30 | self.linears = nn.ModuleList([Linear(in_dim, out_dim, bias=bias)
31 | for in_dim, out_dim
32 | in zip(in_dims, hiddens)])
33 | self.output_linear = Linear(hiddens[-1], out_features, bias=bias)
34 | self.activation = activation
35 |
36 | def forward(self, inputs):
37 | linear_outputs = inputs
38 | for linear in self.linears:
39 | linear_outputs = linear.forward(linear_outputs)
40 | if self.activation == 'tanh':
41 | linear_outputs = torch.tanh(linear_outputs)
42 | else:
43 | linear_outputs = torch.relu(linear_outputs)
44 | return self.output_linear.forward(linear_outputs)
45 |
46 | class CRFDecoder(nn.Module):
47 | def __init__(self, label_size, input_dim, input_dropout=0.5, activation='tanh'):
48 | super(CRFDecoder, self).__init__()
49 | self.input_dim = input_dim
50 | self.input_dropout = nn.Dropout(p=input_dropout)
51 | self.linear = Linears(in_features=input_dim,
52 | out_features=label_size,
53 | hiddens=[input_dim // 2],
54 | activation=activation)
55 | self.crf = CRF(label_size+2)
56 | self.label_size = label_size
57 |
58 | def forward_model(self, inputs):
59 | batch_size, seq_len, input_dim = inputs.size()
60 | output = inputs.contiguous().view(-1, self.input_dim)
61 | output = self.input_dropout(output)
62 | # Fully-connected layer
63 | output = self.linear.forward(output)
64 | output = output.view(batch_size, seq_len, self.label_size)
65 | return output
66 |
67 | def forward(self, inputs, labels_mask):
68 | self.eval()
69 | lens = labels_mask.sum(-1)
70 | logits = self.forward_model(inputs)
71 | logits = self.crf.pad_logits(logits)
72 | scores, preds = self.crf.viterbi_decode(logits, lens)
73 | self.train()
74 |
75 | return preds
76 |
77 | def score(self, inputs, labels_mask, labels):
78 | lens = labels_mask.sum(-1)
79 | logits = self.forward_model(inputs)
80 | logits = self.crf.pad_logits(logits)
81 | norm_score = self.crf.calc_norm_score(logits, lens)
82 | labels = labels[:, :logits.size(1)]
83 | gold_score = self.crf.calc_gold_score(logits, labels, lens)
84 | loglik = gold_score - norm_score
85 | return -loglik.mean()
86 |
--------------------------------------------------------------------------------
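
Note: CRFDecoder.score returns the negative mean log-likelihood (gold path score minus the partition function), so it serves directly as a training loss, while calling the module runs Viterbi decoding. A sketch with toy shapes, assuming label ids range over [0, label_size):

    import torch
    from layers.decoder import CRFDecoder

    decoder = CRFDecoder(label_size=16, input_dim=128)
    feats = torch.randn(2, 7, 128)               # (batch, seq_len, encoder_dim)
    mask = torch.ones(2, 7, dtype=torch.long)    # all 7 positions are real tokens
    labels = torch.randint(0, 16, (2, 7))

    loss = decoder.score(feats, mask, labels)    # training objective
    loss.backward()
    preds = decoder(feats, mask)                 # Viterbi-decoded label ids
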
/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_pretrained_bert.modeling import *
3 | from DataGen import DataGenerator
4 | from model import BertBiLSTMCRF
5 | import pickle
6 |
7 | def convert_rel_matrix_to_str(predicted_rel, rel_list):
8 | index_rels = []
9 | label_rels = []
10 |
11 | for i, pos in enumerate(predicted_rel[0]):
12 | index_rel = []
13 | label_rel = []
14 | for j, val in enumerate(pos):
15 |             if j % 5 == 0 and val == 1:
16 |                 index_rel.append(int(j/5))
17 |                 label_rel.append("SP")
18 |             elif j % 5 == 1 and val == 1:
19 | index_rel.append(int(j/5))
20 | label_rel.append("N")
21 | elif j % 5 == 2 and val == 1:
22 | index_rel.append(int(j/5))
23 | label_rel.append("ST")
24 | elif j%5 == 3 and val == 1:
25 | index_rel.append(int(j/5))
26 | label_rel.append("PO")
27 | elif j%5 == 4 and val == 1:
28 | index_rel.append(int(j/5))
29 | label_rel.append("SC")
30 |
31 | if len(index_rel) == 0:
32 | index_rels.append([i])
33 | label_rels.append(["N"])
34 | else:
35 | index_rels.append(index_rel)
36 | label_rels.append(label_rel)
37 |
38 | return index_rels, label_rels
39 |
40 | def predict(model, featurized_sentence, rel_list, ner_list, use_extra=True):
41 | model.eval()
42 | input_tensor_1 = featurized_sentence[0].to("cuda")
43 | input_tensor_2 = featurized_sentence[1].to("cuda")
44 | input_tensor_3 = featurized_sentence[2].to("cuda")
45 | input_tensor_4 = featurized_sentence[3].to("cuda")
46 | input_tensor_5 = featurized_sentence[4].to("cuda")
47 | ner_logits, rel_logits = model([input_tensor_1, input_tensor_2, input_tensor_3, input_tensor_4, input_tensor_5])
48 |
49 | label_types = []
50 | ner_logits = ner_logits[0]
51 | for i in ner_logits:
52 | label_types.append(ner_list[i])
53 |
54 | if use_extra:
55 | rel_result = []
56 | ner_logits = ner_logits.tolist()
57 | print(rel_logits)
58 | index_rels, label_rels = convert_rel_matrix_to_str(rel_logits, rel_list)
59 | return label_types, index_rels, label_rels
60 |
61 | return label_types, None, None
62 |
63 | def load_model(save_path):
64 | with open(save_path,'rb') as f:
65 | data = pickle.load(f)
66 | return data["model"]
67 |
68 | if __name__ == "__main__":
69 | PRETRAINED_MODEL = "models/"
70 | BERT_PRETRAINED_PATH = "./multi_cased_L-12_H-768_A-12/"
71 | #VALID_PATH = "/data/loclh2/QABot/data/test_analysis.txt"
72 | batch_size = 32
73 | shuffle = True
74 | use_cuda = True
75 |
76 | data_gen = DataGenerator(model=BertModel, model_name=BERT_PRETRAINED_PATH)
77 |
78 | #model = BertBiLSTMCRF.create(16,
79 | # len(data_gen.rel_list),
80 | # BERT_PRETRAINED_PATH,
81 | # freeze=True,
82 | # rnn_layers=2,
83 | # input_dropout=0.1,
84 | # use_cuda=use_cuda,
85 | # hidden_size=64,
86 | # label_embedding_size=64,
87 | # enc_hidden_dim=64,
88 | # activation="tanh")
89 |
90 | # model.load_state_dict(torch.load(PRETRAINED_MODEL))
91 |
92 | model = load_model("models/rel_ner_v1_adam/bert_ner_epoches=50_valid_loss=10.281595188638438.pickle")
93 |
94 | sentence = "Con trai của Ronald là ai"
95 | featurized_sentence = data_gen.get_featurized_sentence(sentence)
96 |
97 | label_types, index_rels, label_rels = predict(model,
98 | featurized_sentence,
99 | data_gen.rel_list,
100 | data_gen.ner_list,
101 | use_extra=True)
102 | print("label_types : " + str(label_types))
103 | print("index_rels : " + str(index_rels))
104 | print("label_rels : " + str(label_rels))
105 |
--------------------------------------------------------------------------------
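
Note: convert_rel_matrix_to_str reads the relation matrix in blocks of five columns per head token; column offset j % 5 maps to the relation label (0: SP, 1: N, 2: ST, 3: PO, 4: SC) and j // 5 is the head-token index. A toy row makes the decoding concrete:

    # One token's row over a 3-token sentence: 3 heads x 5 relation types = 15 columns.
    row = [0, 0, 0, 0, 0,   # no relation to token 0
           0, 0, 0, 1, 0,   # j = 8  -> head 8 // 5 = 1, label 8 % 5 = 3 -> "PO"
           1, 0, 0, 0, 0]   # j = 10 -> head 10 // 5 = 2, label 10 % 5 = 0 -> "SP"

    index_rels, label_rels = convert_rel_matrix_to_str([[row]], rel_list=None)
    print(index_rels, label_rels)   # [[1, 2]] [['PO', 'SP']]
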
/pytorch_pretrained_bert/__main__.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | def main():
3 | import sys
4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
5 | "convert_tf_checkpoint_to_pytorch",
6 | "convert_openai_checkpoint",
7 | "convert_transfo_xl_checkpoint",
8 | "convert_gpt2_checkpoint",
9 | ]:
10 | print(
11 | "Should be used as one of: \n"
12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
16 | else:
17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
18 | try:
19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
20 | except ImportError:
21 |                 print("pytorch_pretrained_bert can only be used from the command line to convert TensorFlow models to PyTorch. "
22 |                     "This requires TensorFlow to be installed. Please see "
23 |                     "https://www.tensorflow.org/install/ for installation instructions.")
24 | raise
25 |
26 | if len(sys.argv) != 5:
27 | # pylint: disable=line-too-long
28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
29 | else:
30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop()
31 | TF_CONFIG = sys.argv.pop()
32 | TF_CHECKPOINT = sys.argv.pop()
33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
34 | elif sys.argv[1] == "convert_openai_checkpoint":
35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
37 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
38 | if len(sys.argv) == 5:
39 | OPENAI_GPT_CONFIG = sys.argv[4]
40 | else:
41 | OPENAI_GPT_CONFIG = ""
42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
43 | OPENAI_GPT_CONFIG,
44 | PYTORCH_DUMP_OUTPUT)
45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint":
46 | try:
47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
48 | except ImportError:
49 |                 print("pytorch_pretrained_bert can only be used from the command line to convert TensorFlow models to PyTorch. "
50 |                     "This requires TensorFlow to be installed. Please see "
51 |                     "https://www.tensorflow.org/install/ for installation instructions.")
52 | raise
53 |
54 | if 'ckpt' in sys.argv[2].lower():
55 | TF_CHECKPOINT = sys.argv[2]
56 | TF_DATASET_FILE = ""
57 | else:
58 | TF_DATASET_FILE = sys.argv[2]
59 | TF_CHECKPOINT = ""
60 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
61 | if len(sys.argv) == 5:
62 | TF_CONFIG = sys.argv[4]
63 | else:
64 | TF_CONFIG = ""
65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
66 | else:
67 | try:
68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
69 | except ImportError:
70 |                 print("pytorch_pretrained_bert can only be used from the command line to convert TensorFlow models to PyTorch. "
71 |                     "This requires TensorFlow to be installed. Please see "
72 |                     "https://www.tensorflow.org/install/ for installation instructions.")
73 | raise
74 |
75 | TF_CHECKPOINT = sys.argv[2]
76 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
77 | if len(sys.argv) == 5:
78 | TF_CONFIG = sys.argv[4]
79 | else:
80 | TF_CONFIG = ""
81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
82 | if __name__ == '__main__':
83 | main()
84 |
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert Transformer XL checkpoint and datasets."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | import os
21 | import sys
22 | from io import open
23 |
24 | import torch
25 |
26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
28 | WEIGHTS_NAME,
29 | TransfoXLConfig,
30 | TransfoXLLMHeadModel,
31 | load_tf_weights_in_transfo_xl)
32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
33 | VOCAB_NAME)
34 |
35 | if sys.version_info[0] == 2:
36 | import cPickle as pickle
37 | else:
38 | import pickle
39 |
40 | # We do this to be able to load python 2 datasets pickles
41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
42 | data_utils.Vocab = data_utils.TransfoXLTokenizer
43 | data_utils.Corpus = data_utils.TransfoXLCorpus
44 | sys.modules['data_utils'] = data_utils
45 | sys.modules['vocabulary'] = data_utils
46 |
47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
48 | transfo_xl_config_file,
49 | pytorch_dump_folder_path,
50 | transfo_xl_dataset_file):
51 | if transfo_xl_dataset_file:
52 | # Convert a pre-processed corpus (see original TensorFlow repo)
53 | with open(transfo_xl_dataset_file, "rb") as fp:
54 | corpus = pickle.load(fp, encoding="latin1")
55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
58 | corpus_vocab_dict = corpus.vocab.__dict__
59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
60 |
61 | corpus_dict_no_vocab = corpus.__dict__
62 | corpus_dict_no_vocab.pop('vocab', None)
63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
64 | print("Save dataset to {}".format(pytorch_dataset_dump_path))
65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
66 |
67 | if tf_checkpoint_path:
68 | # Convert a pre-trained TensorFlow model
69 | config_path = os.path.abspath(transfo_xl_config_file)
70 | tf_path = os.path.abspath(tf_checkpoint_path)
71 |
72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
73 | # Initialise PyTorch model
74 | if transfo_xl_config_file == "":
75 | config = TransfoXLConfig()
76 | else:
77 | config = TransfoXLConfig(transfo_xl_config_file)
78 | print("Building PyTorch model from configuration: {}".format(str(config)))
79 | model = TransfoXLLMHeadModel(config)
80 |
81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path)
82 | # Save pytorch-model
83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
86 | torch.save(model.state_dict(), pytorch_weights_dump_path)
87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
89 | f.write(config.to_json_string())
90 |
91 |
92 | if __name__ == "__main__":
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument("--pytorch_dump_folder_path",
95 | default = None,
96 | type = str,
97 | required = True,
98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.")
99 | parser.add_argument("--tf_checkpoint_path",
100 | default = "",
101 | type = str,
102 | help = "An optional path to a TensorFlow checkpoint path to be converted.")
103 | parser.add_argument("--transfo_xl_config_file",
104 | default = "",
105 | type = str,
106 | help = "An optional config json file corresponding to the pre-trained BERT model. \n"
107 | "This specifies the model architecture.")
108 | parser.add_argument("--transfo_xl_dataset_file",
109 | default = "",
110 | type = str,
111 | help = "An optional dataset file to be converted in a vocabulary.")
112 | args = parser.parse_args()
113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
114 | args.transfo_xl_config_file,
115 | args.pytorch_dump_folder_path,
116 | args.transfo_xl_dataset_file)
117 |
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/optimization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for OpenAI GPT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 |
23 | def warmup_cosine(x, warmup=0.002):
24 | s = 1 if x <= warmup else 0
25 |     return s*(x/warmup) + (1-s)*(0.5 * (1 + math.cos(math.pi * x)))
26 |
27 | def warmup_constant(x, warmup=0.002):
28 | s = 1 if x <= warmup else 0
29 | return s*(x/warmup) + (1-s)*1
30 |
31 | def warmup_linear(x, warmup=0.002):
32 | s = 1 if x <= warmup else 0
33 | return (s*(x/warmup) + (1-s))*(1-x)
34 |
35 | SCHEDULES = {
36 | 'warmup_cosine':warmup_cosine,
37 | 'warmup_constant':warmup_constant,
38 | 'warmup_linear':warmup_linear,
39 | }
40 |
41 |
42 | class OpenAIAdam(Optimizer):
43 | """Implements Open AI version of Adam algorithm with weight decay fix.
44 | """
45 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
46 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
47 | vector_l2=False, max_grad_norm=-1, **kwargs):
48 | if lr is not required and lr < 0.0:
49 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
50 | if schedule not in SCHEDULES:
51 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
52 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
53 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
54 | if not 0.0 <= b1 < 1.0:
55 | raise ValueError("Invalid b1 parameter: {}".format(b1))
56 | if not 0.0 <= b2 < 1.0:
57 | raise ValueError("Invalid b2 parameter: {}".format(b2))
58 | if not e >= 0.0:
59 | raise ValueError("Invalid epsilon value: {}".format(e))
60 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
61 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
62 | max_grad_norm=max_grad_norm)
63 | super(OpenAIAdam, self).__init__(params, defaults)
64 |
65 | def get_lr(self):
66 | lr = []
67 | for group in self.param_groups:
68 | for p in group['params']:
69 | state = self.state[p]
70 | if len(state) == 0:
71 | return [0]
72 | if group['t_total'] != -1:
73 | schedule_fct = SCHEDULES[group['schedule']]
74 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
75 | else:
76 | lr_scheduled = group['lr']
77 | lr.append(lr_scheduled)
78 | return lr
79 |
80 | def step(self, closure=None):
81 | """Performs a single optimization step.
82 |
83 | Arguments:
84 | closure (callable, optional): A closure that reevaluates the model
85 | and returns the loss.
86 | """
87 | loss = None
88 | if closure is not None:
89 | loss = closure()
90 |
91 | for group in self.param_groups:
92 | for p in group['params']:
93 | if p.grad is None:
94 | continue
95 | grad = p.grad.data
96 | if grad.is_sparse:
97 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
98 |
99 | state = self.state[p]
100 |
101 | # State initialization
102 | if len(state) == 0:
103 | state['step'] = 0
104 | # Exponential moving average of gradient values
105 | state['exp_avg'] = torch.zeros_like(p.data)
106 | # Exponential moving average of squared gradient values
107 | state['exp_avg_sq'] = torch.zeros_like(p.data)
108 |
109 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
110 | beta1, beta2 = group['b1'], group['b2']
111 |
112 | state['step'] += 1
113 |
114 | # Add grad clipping
115 | if group['max_grad_norm'] > 0:
116 | clip_grad_norm_(p, group['max_grad_norm'])
117 |
118 | # Decay the first and second moment running average coefficient
119 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
120 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
121 | denom = exp_avg_sq.sqrt().add_(group['e'])
122 |
123 | bias_correction1 = 1 - beta1 ** state['step']
124 | bias_correction2 = 1 - beta2 ** state['step']
125 |
126 | if group['t_total'] != -1:
127 | schedule_fct = SCHEDULES[group['schedule']]
128 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
129 | else:
130 | lr_scheduled = group['lr']
131 |
132 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
133 |
134 | p.data.addcdiv_(-step_size, exp_avg, denom)
135 |
136 | # Add weight decay at the end (fixed version)
137 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
138 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data)
139 |
140 | return loss
141 |
--------------------------------------------------------------------------------
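A minimal usage sketch of OpenAIAdam and the warmup schedules above, assuming a toy stand-in model; the hyperparameter values are illustrative only:

import torch
from pytorch_pretrained_bert.optimization_openai import OpenAIAdam

model = torch.nn.Linear(10, 2)        # toy stand-in for a real network
num_train_steps = 1000                # hypothetical schedule length

optimizer = OpenAIAdam(model.parameters(),
                       lr=6.25e-5,
                       schedule='warmup_linear',
                       warmup=0.002,             # ramp lr over the first 0.2% of steps
                       t_total=num_train_steps,  # -1 would disable the schedule entirely
                       max_grad_norm=1.0,
                       weight_decay=0.01)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()           # scheduled lr + fixed (decoupled) weight decay
print(optimizer.get_lr())  # current per-parameter scheduled learning rates

--------------------------------------------------------------------------------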
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_pretrained_bert.modeling import *
3 | from tqdm import tqdm
4 | from DataGen import DataGenerator
5 | from model import BertBiLSTMCRF
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import pickle
9 | import os
10 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
11 | torch.backends.cudnn.benchmark = True
12 | torch.backends.cudnn.enabled = True
13 |
14 |
15 | def evaluate(model, epoch, valid_generator, device):
16 |     # iterate over the validation batches
17 | num_valid = 0
18 | num_correct = total_loss = 0
19 | batch_size = 0
20 | idx = 0
21 |
22 | for batch_input_ids, batch_mask_ids, batch_input_type_ids, batch_label_ids, batch_rel_matrix in valid_generator:
23 | #compute loss
24 | loss = model.score([batch_input_ids.to(device), batch_mask_ids.to(device), batch_input_type_ids.to(device), batch_label_ids.to(device), batch_rel_matrix.to(device)])
25 |
26 | idx += 1
27 | total_loss += loss.mean().item()
28 |
29 | loss = total_loss / idx
30 |
31 |     print('\nValidation : Loss: {:.6f}\n'.format(loss))
32 | 
33 |     return loss, 0  # accuracy is not computed yet
34 |
35 | def draw_learning_curve(epoch, training_loss, valid_loss,
36 | save_path="models/learning_curve.png",
37 | title="Learning Curve"):
38 | plt.plot(epoch, training_loss, '-b')
39 | plt.plot(epoch, valid_loss, '--r')
40 |
41 | plt.xlabel('Epoch')
42 | plt.ylabel('Loss')
43 | plt.legend(['Training Loss', 'Validation Loss'])
44 | plt.title(title)
45 |
46 |     # save image, then clear the figure so the next call doesn't overlay old curves
47 |     plt.savefig(save_path); plt.clf()
48 |
49 | def train(model, data_gen, train_path, valid_path, start_epoch, num_epoches, save_path, device,
50 | batch_size=32,
51 | decay_rate=0.1,
52 | learning_rate=0.001,
53 | momentum=0.9,
54 | update_lr_epoches=[],
55 | shuffle=True):
56 |
57 | model.train()
58 | optimizer = torch.optim.SGD(model.parameters(),
59 | lr=learning_rate,
60 |                                 momentum=momentum)
61 | loss_history = [[], [], []]
62 |
63 |
64 | for epoch in range(start_epoch, num_epoches + start_epoch):
65 | step = 0
66 | train_gen = data_gen.get_generator(train_path, batch_size, is_shuffle=shuffle)
67 | if epoch in update_lr_epoches:
68 | learning_rate = learning_rate * decay_rate
69 | for param_group in optimizer.param_groups:
70 | param_group['lr'] = learning_rate
71 | print('Updating the learning rate at epoch: ' + str(epoch) + ', value: ' + str(learning_rate))
72 |
73 | train_loss = []
74 | for batch_input_ids, batch_mask_ids, batch_input_type_ids, batch_label_types, batch_rel_matrix in train_gen:
75 | model.zero_grad()
76 | loss = model.score([batch_input_ids.to(device), batch_mask_ids.to(device), batch_input_type_ids.to(device), batch_label_types.to(device), batch_rel_matrix.to(device)])
77 |
78 | loss.backward()
79 |
80 | optimizer.step()
81 | optimizer.zero_grad()
82 | loss = loss.data.cpu().tolist()
83 | train_loss.append(loss)
84 |             print('Training: Epoch %d, step %5d / %d loss: %.3f'%(epoch + 1, step + 1, data_gen.num_samples//batch_size + 1, loss))
85 | step += 1
86 |
87 | valid_gen = data_gen.get_generator(valid_path, batch_size, is_shuffle=shuffle)
88 | valid_loss, valid_accuracy = evaluate(model, epoch, valid_gen, device )
89 |
90 |         ### visualize the learning curve
91 | train_loss_value = np.mean(train_loss)
92 | valid_loss_value = np.mean(valid_loss)
93 |
94 | loss_history[0].append(epoch + 1)
95 | loss_history[1].append(train_loss_value)
96 | loss_history[2].append(valid_loss_value)
97 |
98 | draw_learning_curve(loss_history[0], loss_history[1], loss_history[2])
99 |
100 | model.train()
101 |
102 | save_data = {"model": model, "history": loss_history}
103 | with open(save_path + "bert_ner_epoches=" + str(epoch + 1) + "_valid_loss=" + str(valid_loss) +'.pickle', 'wb') as handle:
104 | pickle.dump(save_data, handle, protocol=2)
105 |
106 | def load_pretrain_model(model, pretrained_path):
107 | model.load_state_dict(torch.load(pretrained_path))
108 | return model
109 |
110 | if __name__ == "__main__":
111 | TRAIN_PATH = "data/train_analysis.txt"
112 | VALID_PATH = "data/test_analysis.txt"
113 | PRETRAINED_PATH = "./multi_cased_L-12_H-768_A-12/"
114 | SAVE_PATH = "models/rel_ner_v1_adam/"
115 | batch_size = 4
116 | shuffle = True
117 | use_cuda = True
118 | use_extra = True
119 |     freeze = False ## True freezes the BERT weights (no fine-tuning)
120 | start_epoch=0
121 | num_epoches = 50
122 | learning_rate = 0.001
123 | decay_rate = 0.1
124 | update_lr_epoches = [35,]
125 | momentum=0.9
126 |
127 | if use_cuda:
128 | device = "cuda"
129 | else:
130 | device = "cpu"
131 |
132 | data_gen = DataGenerator(model=BertModel,
133 | model_name=PRETRAINED_PATH,
134 | device=device)
135 |
136 | ner_size = len(data_gen.ner_list)
137 | rel_size = len(data_gen.rel_list)
138 |
139 | model = BertBiLSTMCRF.create(ner_size,
140 | rel_size,
141 | PRETRAINED_PATH,
142 | freeze=freeze,
143 | rnn_layers=2,
144 | input_dropout=0.1,
145 | use_cuda=use_cuda,
146 | use_extra=use_extra,
147 | hidden_size=64,
148 | label_embedding_size=32,
149 | enc_hidden_dim=64,
150 | activation="tanh")
151 |
152 | train(model=model,
153 | data_gen=data_gen,
154 | train_path=TRAIN_PATH,
155 | valid_path=VALID_PATH,
156 | start_epoch=start_epoch,
157 | num_epoches=num_epoches,
158 | save_path=SAVE_PATH,
159 | device=device,
160 | batch_size=batch_size,
161 | decay_rate=decay_rate,
162 | learning_rate=learning_rate,
163 | momentum=momentum,
164 | update_lr_epoches=update_lr_epoches,
165 | shuffle=shuffle)
166 |
--------------------------------------------------------------------------------
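Since train() pickles the whole model object together with its loss history, reloading for inference is a plain pickle.load; a minimal sketch (the checkpoint filename is a hypothetical instance of the pattern written above, and model.py and layers/ must be importable when unpickling):

import pickle

# Hypothetical example of the filename pattern train() writes:
# SAVE_PATH + "bert_ner_epoches=<n>_valid_loss=<loss>.pickle"
ckpt_path = "models/rel_ner_v1_adam/bert_ner_epoches=50_valid_loss=0.12.pickle"

with open(ckpt_path, "rb") as handle:
    save_data = pickle.load(handle)   # needs the model classes on the import path

model = save_data["model"]                             # the full BertBiLSTMCRF module
epochs, train_hist, valid_hist = save_data["history"]
model.eval()                                           # disable dropout before predicting

--------------------------------------------------------------------------------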
/layers/crf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | # TODO: move to utils
6 | def log_sum_exp(tensor, dim=0):
7 | """LogSumExp operation."""
8 | m, _ = torch.max(tensor, dim)
9 | m_exp = m.unsqueeze(-1).expand_as(tensor)
10 | return m + torch.log(torch.sum(torch.exp(tensor - m_exp), dim))
11 |
12 |
13 | def sequence_mask(lens, max_len=None):
14 | batch_size = lens.size(0)
15 |
16 | if max_len is None:
17 | max_len = lens.max().item()
18 |
19 | ranges = torch.arange(0, max_len).long()
20 | ranges = ranges.unsqueeze(0).expand(batch_size, max_len)
21 |
22 | if lens.data.is_cuda:
23 | ranges = ranges.cuda()
24 |
25 | lens_exp = lens.unsqueeze(1).expand_as(ranges)
26 | mask = ranges < lens_exp
27 |
28 | return mask
29 |
30 |
31 | class CRF(nn.Module):
32 | def __init__(self, label_size):
33 | super(CRF, self).__init__()
34 |
35 | self.label_size = label_size
36 | self.start = self.label_size - 2
37 | self.end = self.label_size - 1
38 | transition = torch.randn(self.label_size, self.label_size)
39 | self.transition = nn.Parameter(transition)
40 | self.initialize()
41 |
42 | def initialize(self):
43 | self.transition.data[:, self.end] = -100.0
44 | self.transition.data[self.start, :] = -100.0
45 |
46 | def pad_logits(self, logits):
47 | # lens = lens.data
48 | batch_size, seq_len, label_num = logits.size()
49 | # pads = Variable(logits.data.new(batch_size, seq_len, 2).fill_(-1000.0),
50 | # requires_grad=False)
51 | pads = logits.new_full((batch_size, seq_len, 2), -1000.0,
52 | requires_grad=False)
53 | logits = torch.cat([logits, pads], dim=2)
54 | return logits
55 |
56 | def calc_binary_score(self, labels, lens):
57 | batch_size, seq_len = labels.size()
58 |
59 | # labels_ext = Variable(labels.data.new(batch_size, seq_len + 2))
60 | labels_ext = labels.new_empty((batch_size, seq_len + 2))
61 | labels_ext[:, 0] = self.start
62 | labels_ext[:, 1:-1] = labels
63 | mask = sequence_mask(lens + 1, max_len=(seq_len + 2)).long()
64 | # pad_stop = Variable(labels.data.new(1).fill_(self.end))
65 | pad_stop = labels.new_full((1,), self.end, requires_grad=False)
66 | pad_stop = pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2)
67 | labels_ext = (1 - mask) * pad_stop + mask * labels_ext
68 | labels = labels_ext
69 |
70 | trn = self.transition
71 | trn_exp = trn.unsqueeze(0).expand(batch_size, *trn.size())
72 | lbl_r = labels[:, 1:]
73 | lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), trn.size(0))
74 | trn_row = torch.gather(trn_exp, 1, lbl_rexp)
75 |
76 |
77 |
78 | lbl_lexp = labels[:, :-1].unsqueeze(-1)
79 | trn_scr = torch.gather(trn_row, 2, lbl_lexp)
80 | trn_scr = trn_scr.squeeze(-1)
81 |
82 | mask = sequence_mask(lens + 1).float()
83 | trn_scr = trn_scr * mask
84 | score = trn_scr
85 |
86 | return score
87 |
88 | def calc_unary_score(self, logits, labels, lens):
89 | labels = labels[:, :logits.size(1)]
90 |
91 | labels_exp = labels.unsqueeze(-1)
92 |
93 | scores = torch.gather(logits, 2, labels_exp).squeeze(-1)
94 | mask = sequence_mask(lens).float()
95 | scores = scores * mask
96 | return scores
97 |
98 | def calc_gold_score(self, logits, labels, lens):
99 | unary_score = self.calc_unary_score(logits, labels, lens).sum(
100 | 1).squeeze(-1)
101 | binary_score = self.calc_binary_score(labels, lens).sum(1).squeeze(-1)
102 | return unary_score + binary_score
103 |
104 | def calc_norm_score(self, logits, lens):
105 | batch_size, seq_len, feat_dim = logits.size()
106 | # alpha = logits.data.new(batch_size, self.label_size).fill_(-10000.0)
107 | alpha = logits.new_full((batch_size, self.label_size), -100.0)
108 | alpha[:, self.start] = 0
109 | # alpha = Variable(alpha)
110 | lens_ = lens.clone()
111 |
112 | logits_t = logits.transpose(1, 0)
113 | for logit in logits_t:
114 | logit_exp = logit.unsqueeze(-1).expand(batch_size,
115 | *self.transition.size())
116 | alpha_exp = alpha.unsqueeze(1).expand(batch_size,
117 | *self.transition.size())
118 | trans_exp = self.transition.unsqueeze(0).expand_as(alpha_exp)
119 | mat = logit_exp + alpha_exp + trans_exp
120 | alpha_nxt = log_sum_exp(mat, 2).squeeze(-1)
121 |
122 | mask = (lens_ > 0).float().unsqueeze(-1).expand_as(alpha)
123 | alpha = mask * alpha_nxt + (1 - mask) * alpha
124 | lens_ = lens_ - 1
125 |
126 | alpha = alpha + self.transition[self.end].unsqueeze(0).expand_as(alpha)
127 | norm = log_sum_exp(alpha, 1).squeeze(-1)
128 |
129 | return norm
130 |
131 | def viterbi_decode(self, logits, lens):
132 | """Borrowed from pytorch tutorial
133 | Arguments:
134 | logits: [batch_size, seq_len, n_labels] FloatTensor
135 | lens: [batch_size] LongTensor
136 | """
137 | batch_size, seq_len, n_labels = logits.size()
138 | # vit = logits.data.new(batch_size, self.label_size).fill_(-10000)
139 | vit = logits.new_full((batch_size, self.label_size), -100.0)
140 | vit[:, self.start] = 0
141 | # vit = Variable(vit)
142 | c_lens = lens.clone()
143 |
144 | logits_t = logits.transpose(1, 0)
145 | pointers = []
146 | for logit in logits_t:
147 | vit_exp = vit.unsqueeze(1).expand(batch_size, n_labels, n_labels)
148 | trn_exp = self.transition.unsqueeze(0).expand_as(vit_exp)
149 | vit_trn_sum = vit_exp + trn_exp
150 | vt_max, vt_argmax = vit_trn_sum.max(2)
151 |
152 | vt_max = vt_max.squeeze(-1)
153 | vit_nxt = vt_max + logit
154 | pointers.append(vt_argmax.squeeze(-1).unsqueeze(0))
155 |
156 | mask = (c_lens > 0).float().unsqueeze(-1).expand_as(vit_nxt)
157 | vit = mask * vit_nxt + (1 - mask) * vit
158 |
159 | mask = (c_lens == 1).float().unsqueeze(-1).expand_as(vit_nxt)
160 | vit += mask * self.transition[self.end].unsqueeze(
161 | 0).expand_as(vit_nxt)
162 |
163 | c_lens = c_lens - 1
164 |
165 | pointers = torch.cat(pointers)
166 | scores, idx = vit.max(1)
167 | # idx = idx.squeeze(-1)
168 | paths = [idx.unsqueeze(1)]
169 | for argmax in reversed(pointers):
170 | idx_exp = idx.unsqueeze(-1)
171 | idx = torch.gather(argmax, 1, idx_exp)
172 | idx = idx.squeeze(-1)
173 |
174 | paths.insert(0, idx.unsqueeze(1))
175 |
176 | paths = torch.cat(paths[1:], 1)
177 | scores = scores.squeeze(-1)
178 |
179 | return scores, paths
180 |
--------------------------------------------------------------------------------
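A minimal sketch of how this CRF is driven end to end, assuming a hypothetical 5-tag emitter; pad_logits widens the emissions to make room for the internal start/end states, and the training loss is the negative log-likelihood (partition score minus gold score):

import torch
from layers.crf import CRF

n_tags = 5                                # labels produced by the encoder
crf = CRF(label_size=n_tags + 2)          # +2 reserves the internal start/end states

batch, seq_len = 2, 7
emissions = torch.randn(batch, seq_len, n_tags, requires_grad=True)
labels = torch.randint(0, n_tags, (batch, seq_len))
lens = torch.tensor([7, 5])               # true (unpadded) sequence lengths

logits = crf.pad_logits(emissions)        # pad the start/end columns with -1000 scores
nll = crf.calc_norm_score(logits, lens) - crf.calc_gold_score(logits, labels, lens)
nll.mean().backward()                     # per-sequence negative log-likelihood

scores, paths = crf.viterbi_decode(logits, lens)   # best label sequence per example

--------------------------------------------------------------------------------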
/pytorch_pretrained_bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 |
23 | def warmup_cosine(x, warmup=0.002):
24 | if x < warmup:
25 | return x/warmup
26 |     return 0.5 * (1.0 + math.cos(math.pi * x))
27 |
28 | def warmup_constant(x, warmup=0.002):
29 | if x < warmup:
30 | return x/warmup
31 | return 1.0
32 |
33 | def warmup_linear(x, warmup=0.002):
34 | if x < warmup:
35 | return x/warmup
36 | return 1.0 - x
37 |
38 | SCHEDULES = {
39 | 'warmup_cosine':warmup_cosine,
40 | 'warmup_constant':warmup_constant,
41 | 'warmup_linear':warmup_linear,
42 | }
43 |
44 |
45 | class BertAdam(Optimizer):
46 | """Implements BERT version of Adam algorithm with weight decay fix.
47 | Params:
48 | lr: learning rate
49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
50 | t_total: total number of training steps for the learning
51 | rate schedule, -1 means constant learning rate. Default: -1
52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
53 | b1: Adams b1. Default: 0.9
54 | b2: Adams b2. Default: 0.999
55 | e: Adams epsilon. Default: 1e-6
56 | weight_decay: Weight decay. Default: 0.01
57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
58 | """
59 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
60 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
61 | max_grad_norm=1.0):
62 | if lr is not required and lr < 0.0:
63 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
64 | if schedule not in SCHEDULES:
65 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
66 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
67 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
68 | if not 0.0 <= b1 < 1.0:
69 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
70 | if not 0.0 <= b2 < 1.0:
71 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
72 | if not e >= 0.0:
73 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
74 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
75 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
76 | max_grad_norm=max_grad_norm)
77 | super(BertAdam, self).__init__(params, defaults)
78 |
79 | def get_lr(self):
80 | lr = []
81 | for group in self.param_groups:
82 | for p in group['params']:
83 | state = self.state[p]
84 | if len(state) == 0:
85 | return [0]
86 | if group['t_total'] != -1:
87 | schedule_fct = SCHEDULES[group['schedule']]
88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
89 | else:
90 | lr_scheduled = group['lr']
91 | lr.append(lr_scheduled)
92 | return lr
93 |
94 | def step(self, closure=None):
95 | """Performs a single optimization step.
96 |
97 | Arguments:
98 | closure (callable, optional): A closure that reevaluates the model
99 | and returns the loss.
100 | """
101 | loss = None
102 | if closure is not None:
103 | loss = closure()
104 |
105 | for group in self.param_groups:
106 | for p in group['params']:
107 | if p.grad is None:
108 | continue
109 | grad = p.grad.data
110 | if grad.is_sparse:
111 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
112 |
113 | state = self.state[p]
114 |
115 | # State initialization
116 | if len(state) == 0:
117 | state['step'] = 0
118 | # Exponential moving average of gradient values
119 | state['next_m'] = torch.zeros_like(p.data)
120 | # Exponential moving average of squared gradient values
121 | state['next_v'] = torch.zeros_like(p.data)
122 |
123 | next_m, next_v = state['next_m'], state['next_v']
124 | beta1, beta2 = group['b1'], group['b2']
125 |
126 | # Add grad clipping
127 | if group['max_grad_norm'] > 0:
128 | clip_grad_norm_(p, group['max_grad_norm'])
129 |
130 | # Decay the first and second moment running average coefficient
131 | # In-place operations to update the averages at the same time
132 | next_m.mul_(beta1).add_(1 - beta1, grad)
133 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
134 | update = next_m / (next_v.sqrt() + group['e'])
135 |
136 | # Just adding the square of the weights to the loss function is *not*
137 | # the correct way of using L2 regularization/weight decay with Adam,
138 | # since that will interact with the m and v parameters in strange ways.
139 | #
140 | # Instead we want to decay the weights in a manner that doesn't interact
141 | # with the m/v parameters. This is equivalent to adding the square
142 | # of the weights to the loss with plain (non-momentum) SGD.
143 | if group['weight_decay'] > 0.0:
144 | update += group['weight_decay'] * p.data
145 |
146 | if group['t_total'] != -1:
147 | schedule_fct = SCHEDULES[group['schedule']]
148 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
149 | else:
150 | lr_scheduled = group['lr']
151 |
152 | update_with_lr = lr_scheduled * update
153 | p.data.add_(-update_with_lr)
154 |
155 | state['step'] += 1
156 |
157 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
158 | # No bias correction
159 | # bias_correction1 = 1 - beta1 ** state['step']
160 | # bias_correction2 = 1 - beta2 ** state['step']
161 |
162 | return loss
163 |
--------------------------------------------------------------------------------
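The decoupled weight decay described in the comments above is typically applied selectively; a minimal sketch of the usual grouping (the no_decay name filter, the toy stand-in model, and all values are illustrative assumptions):

import torch
from pytorch_pretrained_bert.optimization import BertAdam

model = torch.nn.Linear(768, 2)      # stand-in for a BERT-based model
num_train_steps = 1000               # hypothetical schedule length

# Biases and LayerNorm parameters are usually excluded from weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
grouped_params = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
optimizer = BertAdam(grouped_params, lr=5e-5, warmup=0.1, t_total=num_train_steps)

--------------------------------------------------------------------------------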
/model.py:
--------------------------------------------------------------------------------
1 | from pytorch_pretrained_bert.modeling import *
2 | from torch.nn import BCEWithLogitsLoss
3 | from layers.embedding import BertEmbedder
4 | from layers.encoder import BertBiLSTMEncoder
5 | from layers.decoder import CRFDecoder
6 | from torch.autograd import Variable
7 | import torch.nn.functional as F
8 |
9 |
10 | class BertForMultiHeadProblem(BertPreTrainedModel):
11 | def __init__(self, config, num_labels):
12 | super(BertForMultiHeadProblem, self).__init__(config)
13 | self.num_labels = num_labels
14 | self.bert = BertModel(config)
15 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
16 | self.classifier = nn.Linear(config.hidden_size, num_labels)
17 | self.apply(self.init_bert_weights)
18 |
19 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
20 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
21 | sequence_output = self.dropout(sequence_output)
22 | logits = self.classifier(sequence_output)
23 |
24 | if labels is not None:
25 | loss_bce = BCEWithLogitsLoss()
26 | # Only keep active parts of the loss
27 | if attention_mask is not None:
28 | active_loss = attention_mask.view(-1) == 1
29 | active_logits = logits.view(-1, self.num_labels)[active_loss]
30 | active_labels = labels.view(-1, self.num_labels)[active_loss]
31 | loss = loss_bce(active_logits, active_labels)
32 | else:
33 |                 loss = loss_bce(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
34 | return loss
35 | else:
36 | return logits
37 |
38 | class MultiHeadLayers(nn.Module):
39 | def __init__(self, label_embedding_size, hidden_size, lstm_hidden_size, ner_size, rel_size, activation="tanh", dropout=0.1):
40 | super(MultiHeadLayers, self).__init__()
41 | self.dropout_1 = nn.Dropout(p=dropout)
42 | self.dropout_2 = nn.Dropout(p=dropout)
43 |
44 | self.activation = activation
45 | self.rel_size = rel_size
46 |
47 | self.label_embedding_size = label_embedding_size
48 |
49 | self.u_a = nn.Parameter(torch.randn(lstm_hidden_size + label_embedding_size, hidden_size))
50 | self.w_a = nn.Parameter(torch.randn(lstm_hidden_size + label_embedding_size, hidden_size))
51 | self.v = nn.Parameter(torch.randn(hidden_size, rel_size))
52 | self.b_s = nn.Parameter(torch.randn(hidden_size))
53 |
54 | if self.label_embedding_size != 0:
55 | self.label_embedding = nn.Embedding(ner_size, label_embedding_size)
56 |
57 |
58 | def broadcasting(self, left, right):
59 | left = left.transpose(1, 0)
60 | left = left.unsqueeze(3)
61 |
62 | right = right.transpose(2, 1)
63 | right = right.unsqueeze(0)
64 |
65 | B = left + right
66 | B = B.transpose(1, 0).transpose(3, 2)
67 |
68 | return B
69 |
70 | def forward(self, lstm_output, pred_ner):
71 | lstm_output = self.dropout_1(lstm_output)
72 | if self.label_embedding_size != 0:
73 | embeded_label = self.label_embedding(pred_ner)
74 | z = torch.cat([lstm_output, embeded_label], dim=2)
75 | else:
76 | z = lstm_output
77 | left = torch.einsum('aij,jk->aik', z, self.u_a)
78 | right = torch.einsum('aij,jk->aik', z, self.w_a)
79 |
80 | outer_sum = self.broadcasting(left, right)
81 |
82 |         outer_sum_bias = outer_sum + self.b_s
83 |
84 | if self.activation=="tanh":
85 | output = torch.tanh(outer_sum_bias)
86 | elif self.activation=="relu":
87 | output = torch.relu(outer_sum_bias)
88 |
89 | output = self.dropout_2(output)
90 |
91 | g = torch.einsum('aijk,kp->aijp', output, self.v)
92 |
93 | g = g.view(g.size(0), g.size(1), g.size(2) * self.rel_size)
94 |
95 | sigmoid = torch.nn.Sigmoid()
96 | probas = sigmoid(g)
97 | predictedRel = torch.round(probas)
98 |
99 | return predictedRel
100 |
101 | def score(self, lstm_output, gold_ner_labels, gold_rel_labels):
102 | lstm_output = self.dropout_1(lstm_output)
103 |
104 | if self.label_embedding_size != 0:
105 | embeded_label = self.label_embedding(gold_ner_labels)[:, :lstm_output.size(1), :]
106 | z = torch.cat([lstm_output, embeded_label], dim=2)
107 | else:
108 | z = lstm_output
109 |
110 | left = torch.einsum('aij,jk->aik', z, self.u_a)
111 | right = torch.einsum('aij,jk->aik', z, self.w_a)
112 |
113 | outer_sum = self.broadcasting(left, right)
114 |
115 |         outer_sum_bias = outer_sum + self.b_s
116 |
117 | if self.activation == "tanh":
118 | output = torch.tanh(outer_sum_bias)
119 | elif self.activation == "relu":
120 | output = torch.relu(outer_sum_bias)
121 |
122 | output = self.dropout_2(output)
123 |
124 | g = torch.einsum('aijk,kp->aijp', output, self.v)
125 |
126 | g = g.view(g.size(0), g.size(1), g.size(2) * self.rel_size)
127 |
128 | loss_bce = BCEWithLogitsLoss(reduction="mean")
129 |
130 | active_rel_labels = gold_rel_labels[:, :g.size(1), :g.size(2)]
131 |
132 | loss = loss_bce(g, active_rel_labels)
133 |
134 | return loss
135 |
136 |
137 | class BertBiLSTMCRF(nn.Module):
138 | def __init__(self, encoder, decoder, extra=None, use_cuda=True, use_extra=True):
139 | super(BertBiLSTMCRF, self).__init__()
140 | self.encoder = encoder
141 | self.extra = extra
142 | self.decoder = decoder
143 | self.use_cuda = use_cuda
144 | self.use_extra = use_extra
145 | if use_cuda:
146 | self.cuda()
147 | #print(list(self.parameters()))
148 |
149 | def forward(self, batch):
150 | output, hidden = self.encoder(batch)
151 | predictedNer = self.decoder(output, batch[-3])
152 |
153 | if self.use_extra:
154 |             predictedRel = self.extra(output, predictedNer)  # reuse the labels decoded above
155 | return predictedNer, predictedRel
156 | else:
157 | return predictedNer, None
158 |
159 | def score(self, batch):
160 | output, _ = self.encoder(batch)
161 | lossNER = self.decoder.score(output, batch[-3], batch[-2].long())
162 | if self.use_extra:
163 | lossREL = self.extra.score(output, batch[-2].long(), batch[-1])
164 | return lossNER + lossREL
165 | else:
166 | return lossNER
167 |
168 | @classmethod
169 | def create(cls,
170 | ner_size,
171 | rel_size,
172 | bert_pretrained_path, embedding_dim=768, bert_mode="weighted",
173 | freeze=True,
174 | enc_hidden_dim=128, rnn_layers=1,
175 | input_dropout=0.1,
176 | use_cuda=True,
177 | use_extra=True,
178 | meta_dim=None,
179 | hidden_size=64,
180 | label_embedding_size=64,
181 | activation="tanh"):
182 |
183 | embedder = BertEmbedder.create(bert_pretrained_path, embedding_dim, use_cuda, bert_mode, freeze)
184 | encoder = BertBiLSTMEncoder.create(embedder, enc_hidden_dim, rnn_layers, use_cuda)
185 |
186 | extra = None
187 | if use_extra:
188 | extra = MultiHeadLayers(label_embedding_size, hidden_size, enc_hidden_dim, ner_size, rel_size, activation, input_dropout)
189 |
190 | decoder = CRFDecoder(ner_size, encoder.output_dim, input_dropout, activation=activation)
191 |
192 | return cls(encoder, decoder, extra, use_cuda, use_extra)
193 |
--------------------------------------------------------------------------------
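The broadcasting trick in MultiHeadLayers builds an outer sum over all token pairs, so the relation scores come out as one score vector per (head, tail) pair; a minimal shape-check sketch with toy dimensions (all sizes illustrative):

import torch
from model import MultiHeadLayers

# Toy dimensions for illustration only.
batch, seq, lstm_dim, lbl_emb, hidden, ner, rel = 2, 6, 32, 8, 16, 5, 3
layer = MultiHeadLayers(lbl_emb, hidden, lstm_dim, ner, rel)

lstm_out = torch.randn(batch, seq, lstm_dim)       # encoder output
pred_ner = torch.randint(0, ner, (batch, seq))     # decoded NER labels

preds = layer(lstm_out, pred_ner)   # rounded sigmoid scores per token pair
print(preds.shape)                  # torch.Size([2, 6, 18]) == [batch, seq, seq * rel]

--------------------------------------------------------------------------------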
/pytorch_pretrained_bert/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import json
9 | import logging
10 | import os
11 | import shutil
12 | import tempfile
13 | from functools import wraps
14 | from hashlib import sha256
15 | import sys
16 | from io import open
17 |
18 | import boto3
19 | import requests
20 | from botocore.exceptions import ClientError
21 | from tqdm import tqdm
22 |
23 | try:
24 | from urllib.parse import urlparse
25 | except ImportError:
26 | from urlparse import urlparse
27 |
28 | try:
29 | from pathlib import Path
30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
31 | Path.home() / '.pytorch_pretrained_bert'))
32 | except AttributeError:
33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
35 |
36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
37 |
38 |
39 | def url_to_filename(url, etag=None):
40 | """
41 | Convert `url` into a hashed filename in a repeatable way.
42 | If `etag` is specified, append its hash to the url's, delimited
43 | by a period.
44 | """
45 | url_bytes = url.encode('utf-8')
46 | url_hash = sha256(url_bytes)
47 | filename = url_hash.hexdigest()
48 |
49 | if etag:
50 | etag_bytes = etag.encode('utf-8')
51 | etag_hash = sha256(etag_bytes)
52 | filename += '.' + etag_hash.hexdigest()
53 |
54 | return filename
55 |
56 |
57 | def filename_to_url(filename, cache_dir=None):
58 | """
59 | Return the url and etag (which may be ``None``) stored for `filename`.
60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
61 | """
62 | if cache_dir is None:
63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
65 | cache_dir = str(cache_dir)
66 |
67 | cache_path = os.path.join(cache_dir, filename)
68 | if not os.path.exists(cache_path):
69 | raise EnvironmentError("file {} not found".format(cache_path))
70 |
71 | meta_path = cache_path + '.json'
72 | if not os.path.exists(meta_path):
73 | raise EnvironmentError("file {} not found".format(meta_path))
74 |
75 | with open(meta_path, encoding="utf-8") as meta_file:
76 | metadata = json.load(meta_file)
77 | url = metadata['url']
78 | etag = metadata['etag']
79 |
80 | return url, etag
81 |
82 |
83 | def cached_path(url_or_filename, cache_dir=None):
84 | """
85 | Given something that might be a URL (or might be a local path),
86 | determine which. If it's a URL, download the file and cache it, and
87 | return the path to the cached file. If it's already a local path,
88 | make sure the file exists and then return the path.
89 | """
90 | if cache_dir is None:
91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
93 | url_or_filename = str(url_or_filename)
94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
95 | cache_dir = str(cache_dir)
96 |
97 |
98 | parsed = urlparse(url_or_filename)
99 |
100 | if parsed.scheme in ('http', 'https', 's3'):
101 | # URL, so get it from the cache (downloading if necessary)
102 | return get_from_cache(url_or_filename, cache_dir)
103 | elif os.path.exists(url_or_filename):
104 | # File, and it exists.
105 | return url_or_filename
106 | elif parsed.scheme == '':
107 | # File, but it doesn't exist.
108 | raise EnvironmentError("file {} not found".format(url_or_filename))
109 | else:
110 | # Something unknown
111 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
112 |
113 |
114 | def split_s3_path(url):
115 | """Split a full s3 path into the bucket name and path."""
116 | parsed = urlparse(url)
117 | if not parsed.netloc or not parsed.path:
118 | raise ValueError("bad s3 path {}".format(url))
119 | bucket_name = parsed.netloc
120 | s3_path = parsed.path
121 | # Remove '/' at beginning of path.
122 | if s3_path.startswith("/"):
123 | s3_path = s3_path[1:]
124 | return bucket_name, s3_path
125 |
126 |
127 | def s3_request(func):
128 | """
129 | Wrapper function for s3 requests in order to create more helpful error
130 | messages.
131 | """
132 |
133 | @wraps(func)
134 | def wrapper(url, *args, **kwargs):
135 | try:
136 | return func(url, *args, **kwargs)
137 | except ClientError as exc:
138 | if int(exc.response["Error"]["Code"]) == 404:
139 | raise EnvironmentError("file {} not found".format(url))
140 | else:
141 | raise
142 |
143 | return wrapper
144 |
145 |
146 | @s3_request
147 | def s3_etag(url):
148 | """Check ETag on S3 object."""
149 | s3_resource = boto3.resource("s3")
150 | bucket_name, s3_path = split_s3_path(url)
151 | s3_object = s3_resource.Object(bucket_name, s3_path)
152 | return s3_object.e_tag
153 |
154 |
155 | @s3_request
156 | def s3_get(url, temp_file):
157 | """Pull a file directly from S3."""
158 | s3_resource = boto3.resource("s3")
159 | bucket_name, s3_path = split_s3_path(url)
160 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
161 |
162 |
163 | def http_get(url, temp_file):
164 | req = requests.get(url, stream=True)
165 | content_length = req.headers.get('Content-Length')
166 | total = int(content_length) if content_length is not None else None
167 | progress = tqdm(unit="B", total=total)
168 | for chunk in req.iter_content(chunk_size=1024):
169 | if chunk: # filter out keep-alive new chunks
170 | progress.update(len(chunk))
171 | temp_file.write(chunk)
172 | progress.close()
173 |
174 |
175 | def get_from_cache(url, cache_dir=None):
176 | """
177 | Given a URL, look for the corresponding dataset in the local cache.
178 | If it's not there, download it. Then return the path to the cached file.
179 | """
180 | if cache_dir is None:
181 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
182 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
183 | cache_dir = str(cache_dir)
184 |
185 | if not os.path.exists(cache_dir):
186 | os.makedirs(cache_dir)
187 |
188 | # Get eTag to add to filename, if it exists.
189 | if url.startswith("s3://"):
190 | etag = s3_etag(url)
191 | else:
192 | response = requests.head(url, allow_redirects=True)
193 | if response.status_code != 200:
194 | raise IOError("HEAD request failed for url {} with status code {}"
195 | .format(url, response.status_code))
196 | etag = response.headers.get("ETag")
197 |
198 | filename = url_to_filename(url, etag)
199 |
200 | # get cache path to put the file
201 | cache_path = os.path.join(cache_dir, filename)
202 |
203 | if not os.path.exists(cache_path):
204 | # Download to temporary file, then copy to cache dir once finished.
205 | # Otherwise you get corrupt cache entries if the download gets interrupted.
206 | with tempfile.NamedTemporaryFile() as temp_file:
207 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
208 |
209 | # GET file object
210 | if url.startswith("s3://"):
211 | s3_get(url, temp_file)
212 | else:
213 | http_get(url, temp_file)
214 |
215 | # we are copying the file before closing it, so flush to avoid truncation
216 | temp_file.flush()
217 | # shutil.copyfileobj() starts at the current position, so go to the start
218 | temp_file.seek(0)
219 |
220 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
221 | with open(cache_path, 'wb') as cache_file:
222 | shutil.copyfileobj(temp_file, cache_file)
223 |
224 | logger.info("creating metadata file for %s", cache_path)
225 | meta = {'url': url, 'etag': etag}
226 | meta_path = cache_path + '.json'
227 | with open(meta_path, 'w', encoding="utf-8") as meta_file:
228 | json.dump(meta, meta_file)
229 |
230 | logger.info("removing temp file %s", temp_file.name)
231 |
232 | return cache_path
233 |
234 |
235 | def read_set_from_file(filename):
236 | '''
237 | Extract a de-duped collection (set) of text from a file.
238 | Expected file format is one item per line.
239 | '''
240 | collection = set()
241 | with open(filename, 'r', encoding='utf-8') as file_:
242 | for line in file_:
243 | collection.add(line.rstrip())
244 | return collection
245 |
246 |
247 | def get_file_extension(path, dot=True, lower=True):
248 | ext = os.path.splitext(path)[1]
249 | ext = ext if dot else ext[1:]
250 | return ext.lower() if lower else ext
251 |
--------------------------------------------------------------------------------
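A minimal sketch of the caching behaviour above; the URL is only an example of an http(s)/s3 resource, and the first call needs network access:

from pytorch_pretrained_bert.file_utils import cached_path

url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt"
local_path = cached_path(url)   # downloads on first call, pure cache hit afterwards
print(local_path)               # ~/.pytorch_pretrained_bert/<sha256(url)>.<sha256(etag)>

# Existing local paths pass straight through unchanged:
assert cached_path(local_path) == local_path

--------------------------------------------------------------------------------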
/pytorch_pretrained_bert/tokenization_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import json
20 | import logging
21 | import os
22 | import regex as re
23 | from io import open
24 |
25 | try:
26 | from functools import lru_cache
27 | except ImportError:
28 | # Just a dummy decorator to get the checks to run on python2
29 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
30 | def lru_cache():
31 | return lambda func: func
32 |
33 | from .file_utils import cached_path
34 |
35 | logger = logging.getLogger(__name__)
36 |
37 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
38 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
39 | }
40 | PRETRAINED_MERGES_ARCHIVE_MAP = {
41 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
42 | }
43 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
44 | 'gpt2': 1024,
45 | }
46 | VOCAB_NAME = 'vocab.json'
47 | MERGES_NAME = 'merges.txt'
48 |
49 | @lru_cache()
50 | def bytes_to_unicode():
51 | """
52 |     Returns a dict mapping utf-8 bytes to unicode strings.
53 |     The reversible bpe codes work on unicode strings.
54 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
55 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
56 |     This is a significant percentage of your normal, say, 32K bpe vocab.
57 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
58 |     while avoiding the whitespace/control characters the bpe code barfs on.
59 | """
60 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
61 | cs = bs[:]
62 | n = 0
63 | for b in range(2**8):
64 | if b not in bs:
65 | bs.append(b)
66 | cs.append(2**8+n)
67 | n += 1
68 | cs = [chr(n) for n in cs]
69 | return dict(zip(bs, cs))
70 |
71 | def get_pairs(word):
72 | """Return set of symbol pairs in a word.
73 |
74 | Word is represented as tuple of symbols (symbols being variable-length strings).
75 | """
76 | pairs = set()
77 | prev_char = word[0]
78 | for char in word[1:]:
79 | pairs.add((prev_char, char))
80 | prev_char = char
81 | return pairs
82 |
83 | class GPT2Tokenizer(object):
84 | """
85 | GPT-2 BPE tokenizer. Peculiarities:
86 | - Byte-level BPE
87 | """
88 | @classmethod
89 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
90 | """
91 | Instantiate a PreTrainedBertModel from a pre-trained model file.
92 | Download and cache the pre-trained model file if needed.
93 | """
94 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
95 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
96 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
97 | else:
98 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
99 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
100 | # redirect to the cache, if necessary
101 | try:
102 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
103 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
104 | except EnvironmentError:
105 | logger.error(
106 | "Model name '{}' was not found in model name list ({}). "
107 | "We assumed '{}' was a path or url but couldn't find files {} and {} "
108 | "at this path or url.".format(
109 | pretrained_model_name_or_path,
110 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
111 | pretrained_model_name_or_path,
112 | vocab_file, merges_file))
113 | return None
114 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
115 | logger.info("loading vocabulary file {}".format(vocab_file))
116 | logger.info("loading merges file {}".format(merges_file))
117 | else:
118 | logger.info("loading vocabulary file {} from cache at {}".format(
119 | vocab_file, resolved_vocab_file))
120 | logger.info("loading merges file {} from cache at {}".format(
121 | merges_file, resolved_merges_file))
122 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
123 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
124 | # than the number of positional embeddings
125 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
126 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
127 | # Instantiate tokenizer.
128 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
129 | return tokenizer
130 |
131 | def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
132 | self.max_len = max_len if max_len is not None else int(1e12)
133 | self.encoder = json.load(open(vocab_file))
134 | self.decoder = {v:k for k,v in self.encoder.items()}
135 | self.errors = errors # how to handle errors in decoding
136 | self.byte_encoder = bytes_to_unicode()
137 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
138 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
139 | bpe_merges = [tuple(merge.split()) for merge in bpe_data]
140 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
141 | self.cache = {}
142 |
143 |         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
144 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
145 |
146 | def __len__(self):
147 | return len(self.encoder)
148 |
149 | def bpe(self, token):
150 | if token in self.cache:
151 | return self.cache[token]
152 | word = tuple(token)
153 | pairs = get_pairs(word)
154 |
155 | if not pairs:
156 | return token
157 |
158 | while True:
159 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
160 | if bigram not in self.bpe_ranks:
161 | break
162 | first, second = bigram
163 | new_word = []
164 | i = 0
165 | while i < len(word):
166 | try:
167 | j = word.index(first, i)
168 | new_word.extend(word[i:j])
169 | i = j
170 |                 except ValueError:
171 | new_word.extend(word[i:])
172 | break
173 |
174 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
175 | new_word.append(first+second)
176 | i += 2
177 | else:
178 | new_word.append(word[i])
179 | i += 1
180 | new_word = tuple(new_word)
181 | word = new_word
182 | if len(word) == 1:
183 | break
184 | else:
185 | pairs = get_pairs(word)
186 | word = ' '.join(word)
187 | self.cache[token] = word
188 | return word
189 |
190 | def encode(self, text):
191 | bpe_tokens = []
192 | for token in re.findall(self.pat, text):
193 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
194 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
195 | if len(bpe_tokens) > self.max_len:
196 | raise ValueError(
197 | "Token indices sequence length is longer than the specified maximum "
198 | " sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
199 | " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
200 | )
201 | return bpe_tokens
202 |
203 | def decode(self, tokens):
204 | text = ''.join([self.decoder[token] for token in tokens])
205 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
206 | return text
207 |
--------------------------------------------------------------------------------
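A minimal sketch of the byte-level guarantee behind bytes_to_unicode and a round trip through the tokenizer; from_pretrained('gpt2') downloads the vocab and merges files into the local cache on first use:

from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer, bytes_to_unicode

# The byte/unicode table is a bijection over all 256 byte values, which is
# what lets byte-level BPE encode any string without UNK tokens.
table = bytes_to_unicode()
assert len(table) == 256 and len(set(table.values())) == 256

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ids = tokenizer.encode("Hello world")
assert tokenizer.decode(ids) == "Hello world"   # byte-level BPE round-trips losslessly

--------------------------------------------------------------------------------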
/pytorch_pretrained_bert/tokenization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import json
20 | import logging
21 | import os
22 | import re
23 | import sys
24 | from io import open
25 |
26 | from tqdm import tqdm
27 |
28 | from .file_utils import cached_path
29 | from .tokenization import BasicTokenizer
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
35 | }
36 | PRETRAINED_MERGES_ARCHIVE_MAP = {
37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
38 | }
39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
40 | 'openai-gpt': 512,
41 | }
42 | VOCAB_NAME = 'vocab.json'
43 | MERGES_NAME = 'merges.txt'
44 |
45 | def get_pairs(word):
46 | """
47 | Return set of symbol pairs in a word.
48 | word is represented as tuple of symbols (symbols being variable-length strings)
49 | """
50 | pairs = set()
51 | prev_char = word[0]
52 | for char in word[1:]:
53 | pairs.add((prev_char, char))
54 | prev_char = char
55 | return pairs
56 |
57 | def text_standardize(text):
58 | """
59 | fixes some issues the spacy tokenizer had on books corpus
60 | also does some whitespace standardization
61 | """
62 | text = text.replace('—', '-')
63 | text = text.replace('–', '-')
64 | text = text.replace('―', '-')
65 | text = text.replace('…', '...')
66 | text = text.replace('´', "'")
67 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
68 | text = re.sub(r'\s*\n\s*', ' \n ', text)
69 | text = re.sub(r'[^\S\n]+', ' ', text)
70 | return text.strip()
71 |
72 | class OpenAIGPTTokenizer(object):
73 | """
74 | BPE tokenizer. Peculiarities:
75 | - lower case all inputs
76 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
77 | - argument special_tokens and function set_special_tokens:
78 | can be used to add additional symbols (ex: "__classify__") to a vocabulary.
79 | """
80 | @classmethod
81 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
82 | """
83 | Instantiate a PreTrainedBertModel from a pre-trained model file.
84 | Download and cache the pre-trained model file if needed.
85 | """
86 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
87 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
88 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
89 | else:
90 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
91 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
92 | # redirect to the cache, if necessary
93 | try:
94 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
95 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
96 | except EnvironmentError:
97 | logger.error(
98 | "Model name '{}' was not found in model name list ({}). "
99 | "We assumed '{}' was a path or url but couldn't find files {} and {} "
100 | "at this path or url.".format(
101 | pretrained_model_name_or_path,
102 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
103 | pretrained_model_name_or_path,
104 | vocab_file, merges_file))
105 | return None
106 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
107 | logger.info("loading vocabulary file {}".format(vocab_file))
108 | logger.info("loading merges file {}".format(merges_file))
109 | else:
110 | logger.info("loading vocabulary file {} from cache at {}".format(
111 | vocab_file, resolved_vocab_file))
112 | logger.info("loading merges file {} from cache at {}".format(
113 | merges_file, resolved_merges_file))
114 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
115 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
116 | # than the number of positional embeddings
117 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
118 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
119 | # Instantiate tokenizer.
120 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
121 | return tokenizer
122 |
123 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
124 | try:
125 | import ftfy
126 | import spacy
127 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
128 | self.fix_text = ftfy.fix_text
129 | except ImportError:
130 | logger.warning("ftfy or spacy is not installed; using BERT's BasicTokenizer instead of SpaCy & ftfy.")
131 | self.nlp = BasicTokenizer(do_lower_case=True,
132 | never_split=special_tokens if special_tokens is not None else [])
133 | self.fix_text = None
134 |
135 | self.max_len = max_len if max_len is not None else int(1e12)
136 | self.encoder = json.load(open(vocab_file, encoding="utf-8"))
137 | self.decoder = {v:k for k,v in self.encoder.items()}
138 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
139 | merges = [tuple(merge.split()) for merge in merges]
140 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
141 | self.cache = {}
142 | self.set_special_tokens(special_tokens)
143 |
144 | def __len__(self):
145 | return len(self.encoder) + len(self.special_tokens)
146 |
147 | def set_special_tokens(self, special_tokens):
148 | """ Add a list of additional tokens to the encoder.
149 | The additional tokens are indexed starting from the last index of the
150 | current vocabulary in the order of the `special_tokens` list.
151 | """
152 | if not special_tokens:
153 | self.special_tokens = {}
154 | self.special_tokens_decoder = {}
155 | return
156 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
157 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
158 | if self.fix_text is None:
159 | # Using BERT's BasicTokenizer: we can update the tokenizer
160 | self.nlp.never_split = special_tokens
161 | logger.info("Special tokens {}".format(self.special_tokens))
162 |
163 | def bpe(self, token):
164 | word = tuple(token[:-1]) + (token[-1] + '</w>',)
165 | if token in self.cache:
166 | return self.cache[token]
167 | pairs = get_pairs(word)
168 |
169 | if not pairs:
170 | return token+'</w>'
171 |
172 | while True:
173 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
174 | if bigram not in self.bpe_ranks:
175 | break
176 | first, second = bigram
177 | new_word = []
178 | i = 0
179 | while i < len(word):
180 | try:
181 | j = word.index(first, i)
182 | new_word.extend(word[i:j])
183 | i = j
184 | except ValueError:
185 | new_word.extend(word[i:])
186 | break
187 |
188 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
189 | new_word.append(first+second)
190 | i += 2
191 | else:
192 | new_word.append(word[i])
193 | i += 1
194 | new_word = tuple(new_word)
195 | word = new_word
196 | if len(word) == 1:
197 | break
198 | else:
199 | pairs = get_pairs(word)
200 | word = ' '.join(word)
201 | if word == '\n  </w>':
202 | word = '\n</w>'
203 | self.cache[token] = word
204 | return word
205 |
206 | def tokenize(self, text):
207 | """ Tokenize a string. """
208 | split_tokens = []
209 | if self.fix_text is None:
210 | # Using BERT's BasicTokenizer
211 | text = self.nlp.tokenize(text)
212 | for token in text:
213 | split_tokens.extend([t for t in self.bpe(token).split(' ')])
214 | else:
215 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
216 | text = self.nlp(text_standardize(self.fix_text(text)))
217 | for token in text:
218 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
219 | return split_tokens
220 |
221 | def convert_tokens_to_ids(self, tokens):
222 | """ Converts a sequence of tokens into ids using the vocab. """
223 | ids = []
224 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
225 | if tokens in self.special_tokens:
226 | return self.special_tokens[tokens]
227 | else:
228 | return self.encoder.get(tokens, 0)
229 | for token in tokens:
230 | if token in self.special_tokens:
231 | ids.append(self.special_tokens[token])
232 | else:
233 | ids.append(self.encoder.get(token, 0))
234 | if len(ids) > self.max_len:
235 | raise ValueError(
236 | "Token indices sequence length is longer than the specified maximum "
237 | " sequence length for this OpenAI GPT model ({} > {}). Running this"
238 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
239 | )
240 | return ids
241 |
242 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
243 | """Converts a sequence of ids in BPE tokens using the vocab."""
244 | tokens = []
245 | for i in ids:
246 | if i in self.special_tokens_decoder:
247 | if not skip_special_tokens:
248 | tokens.append(self.special_tokens_decoder[i])
249 | else:
250 | tokens.append(self.decoder[i])
251 | return tokens
252 |
253 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
254 | """Converts a sequence of ids in a string."""
255 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
256 | out_string = ''.join(tokens).replace('', ' ').strip()
257 | if clean_up_tokenization_spaces:
258 | out_string = out_string.replace('', '')
259 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ','
260 | ).replace(" n't", "n't").replace(" 'm", "'m").replace(" 're", "'re").replace(" do not", " don't"
261 | ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
262 | ).replace(" 've", "'ve")
263 | return out_string
264 |
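# Minimal usage sketch (added; assumes network access to fetch the vocab and
# merges files mapped above, and that either spacy/ftfy or BERT's
# BasicTokenizer is available):
#   tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
#   tokens = tokenizer.tokenize("Hello world")
#   ids = tokenizer.convert_tokens_to_ids(tokens)
#   text = tokenizer.decode(ids, clean_up_tokenization_spaces=True)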
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import collections
20 | import logging
21 | import os
22 | import unicodedata
23 | from io import open
24 |
25 | from .file_utils import cached_path
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
37 | }
38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
39 | 'bert-base-uncased': 512,
40 | 'bert-large-uncased': 512,
41 | 'bert-base-cased': 512,
42 | 'bert-large-cased': 512,
43 | 'bert-base-multilingual-uncased': 512,
44 | 'bert-base-multilingual-cased': 512,
45 | 'bert-base-chinese': 512,
46 | }
47 | VOCAB_NAME = 'vocab.txt'
48 |
49 |
50 | def load_vocab(vocab_file):
51 | """Loads a vocabulary file into a dictionary."""
52 | vocab = collections.OrderedDict()
53 | index = 0
54 | with open(vocab_file, "r", encoding="utf-8") as reader:
55 | while True:
56 | token = reader.readline()
57 | if not token:
58 | break
59 | token = token.strip()
60 | vocab[token] = index
61 | index += 1
62 | return vocab
63 |
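# Illustrative note (added): for a vocab.txt containing the three lines
# "[PAD]", "[UNK]", "the", load_vocab returns
# OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('the', 2)]).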
64 |
65 | def whitespace_tokenize(text):
66 | """Runs basic whitespace cleaning and splitting on a piece of text."""
67 | text = text.strip()
68 | if not text:
69 | return []
70 | tokens = text.split()
71 | return tokens
72 |
73 |
74 | class BertTokenizer(object):
75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
76 |
77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None,
78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
79 | if not os.path.isfile(vocab_file):
80 | raise ValueError(
81 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
82 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
83 | self.vocab = load_vocab(vocab_file)
84 | self.ids_to_tokens = collections.OrderedDict(
85 | [(ids, tok) for tok, ids in self.vocab.items()])
86 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
87 | never_split=never_split)
88 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
89 | self.max_len = max_len if max_len is not None else int(1e12)
90 |
91 | def tokenize(self, text):
92 | split_tokens = []
93 | for token in self.basic_tokenizer.tokenize(text):
94 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
95 | split_tokens.append(sub_token)
96 | return split_tokens
97 |
98 | def convert_tokens_to_ids(self, tokens):
99 | """Converts a sequence of tokens into ids using the vocab."""
100 | ids = []
101 | for token in tokens:
102 | ids.append(self.vocab[token])
103 | if len(ids) > self.max_len:
104 | raise ValueError(
105 | "Token indices sequence length is longer than the specified maximum "
106 | " sequence length for this BERT model ({} > {}). Running this"
107 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
108 | )
109 | return ids
110 |
111 | def convert_ids_to_tokens(self, ids):
112 | """Converts a sequence of ids in wordpiece tokens using the vocab."""
113 | tokens = []
114 | for i in ids:
115 | tokens.append(self.ids_to_tokens[i])
116 | return tokens
117 |
118 | @classmethod
119 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
120 | """
121 | Instantiate a BertTokenizer from a pre-trained vocabulary file.
122 | Download and cache the pre-trained model file if needed.
123 | """
124 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
125 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
126 | else:
127 | vocab_file = pretrained_model_name_or_path
128 | if os.path.isdir(vocab_file):
129 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
130 | # redirect to the cache, if necessary
131 | try:
132 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
133 | except EnvironmentError:
134 | logger.error(
135 | "Model name '{}' was not found in model name list ({}). "
136 | "We assumed '{}' was a path or url but couldn't find any file "
137 | "associated to this path or url.".format(
138 | pretrained_model_name_or_path,
139 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
140 | vocab_file))
141 | return None
142 | if resolved_vocab_file == vocab_file:
143 | logger.info("loading vocabulary file {}".format(vocab_file))
144 | else:
145 | logger.info("loading vocabulary file {} from cache at {}".format(
146 | vocab_file, resolved_vocab_file))
147 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
148 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
149 | # than the number of positional embeddings
150 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
151 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
152 | # Instantiate tokenizer.
153 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
154 | return tokenizer
155 |
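# Minimal usage sketch (added; 'bert-base-cased' is one of the keys in
# PRETRAINED_VOCAB_ARCHIVE_MAP above, downloaded and cached on first use):
#   tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
#   tokens = tokenizer.tokenize("unaffable")  # e.g. ["un", "##aff", "##able"],
#                                             # exact pieces depend on the vocab
#   ids = tokenizer.convert_tokens_to_ids(tokens)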
156 |
157 | class BasicTokenizer(object):
158 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
159 |
160 | def __init__(self,
161 | do_lower_case=True,
162 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
163 | """Constructs a BasicTokenizer.
164 |
165 | Args:
166 | do_lower_case: Whether to lower case the input.
167 | """
168 | self.do_lower_case = do_lower_case
169 | self.never_split = never_split
170 |
171 | def tokenize(self, text):
172 | """Tokenizes a piece of text."""
173 | text = self._clean_text(text)
174 | # This was added on November 1st, 2018 for the multilingual and Chinese
175 | # models. This is also applied to the English models now, but it doesn't
176 | # matter since the English models were not trained on any Chinese data
177 | # and generally don't have any Chinese data in them (there are Chinese
178 | # characters in the vocabulary because Wikipedia does have some Chinese
179 | # words in the English Wikipedia).
180 | text = self._tokenize_chinese_chars(text)
181 | orig_tokens = whitespace_tokenize(text)
182 | split_tokens = []
183 | for token in orig_tokens:
184 | if self.do_lower_case and token not in self.never_split:
185 | token = token.lower()
186 | token = self._run_strip_accents(token)
187 | split_tokens.extend(self._run_split_on_punc(token))
188 |
189 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
190 | return output_tokens
191 |
192 | def _run_strip_accents(self, text):
193 | """Strips accents from a piece of text."""
194 | text = unicodedata.normalize("NFD", text)
195 | output = []
196 | for char in text:
197 | cat = unicodedata.category(char)
198 | if cat == "Mn":
199 | continue
200 | output.append(char)
201 | return "".join(output)
202 |
203 | def _run_split_on_punc(self, text):
204 | """Splits punctuation on a piece of text."""
205 | if text in self.never_split:
206 | return [text]
207 | chars = list(text)
208 | i = 0
209 | start_new_word = True
210 | output = []
211 | while i < len(chars):
212 | char = chars[i]
213 | if _is_punctuation(char):
214 | output.append([char])
215 | start_new_word = True
216 | else:
217 | if start_new_word:
218 | output.append([])
219 | start_new_word = False
220 | output[-1].append(char)
221 | i += 1
222 |
223 | return ["".join(x) for x in output]
224 |
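# Illustrative note (added): each punctuation character becomes its own token,
# e.g. _run_split_on_punc("hello!!") -> ["hello", "!", "!"], while members of
# never_split such as "[SEP]" pass through unchanged.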
225 | def _tokenize_chinese_chars(self, text):
226 | """Adds whitespace around any CJK character."""
227 | output = []
228 | for char in text:
229 | cp = ord(char)
230 | if self._is_chinese_char(cp):
231 | output.append(" ")
232 | output.append(char)
233 | output.append(" ")
234 | else:
235 | output.append(char)
236 | return "".join(output)
237 |
238 | def _is_chinese_char(self, cp):
239 | """Checks whether CP is the codepoint of a CJK character."""
240 | # This defines a "chinese character" as anything in the CJK Unicode block:
241 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
242 | #
243 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
244 | # despite its name. The modern Korean Hangul alphabet is a different block,
245 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
246 | # space-separated words, so they are not treated specially and handled
247 | # like all of the other languages.
248 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
249 | (cp >= 0x3400 and cp <= 0x4DBF) or #
250 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
251 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
252 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
253 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
254 | (cp >= 0xF900 and cp <= 0xFAFF) or #
255 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
256 | return True
257 |
258 | return False
259 |
260 | def _clean_text(self, text):
261 | """Performs invalid character removal and whitespace cleanup on text."""
262 | output = []
263 | for char in text:
264 | cp = ord(char)
265 | if cp == 0 or cp == 0xfffd or _is_control(char):
266 | continue
267 | if _is_whitespace(char):
268 | output.append(" ")
269 | else:
270 | output.append(char)
271 | return "".join(output)
272 |
273 |
274 | class WordpieceTokenizer(object):
275 | """Runs WordPiece tokenization."""
276 |
277 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
278 | self.vocab = vocab
279 | self.unk_token = unk_token
280 | self.max_input_chars_per_word = max_input_chars_per_word
281 |
282 | def tokenize(self, text):
283 | """Tokenizes a piece of text into its word pieces.
284 |
285 | This uses a greedy longest-match-first algorithm to perform tokenization
286 | using the given vocabulary.
287 |
288 | For example:
289 | input = "unaffable"
290 | output = ["un", "##aff", "##able"]
291 |
292 | Args:
293 | text: A single token or whitespace separated tokens. This should have
294 | already been passed through `BasicTokenizer`.
295 |
296 | Returns:
297 | A list of wordpiece tokens.
298 | """
299 |
300 | output_tokens = []
301 | for token in whitespace_tokenize(text):
302 | chars = list(token)
303 | if len(chars) > self.max_input_chars_per_word:
304 | output_tokens.append(self.unk_token)
305 | continue
306 |
307 | is_bad = False
308 | start = 0
309 | sub_tokens = []
310 | while start < len(chars):
311 | end = len(chars)
312 | cur_substr = None
313 | while start < end:
314 | substr = "".join(chars[start:end])
315 | if start > 0:
316 | substr = "##" + substr
317 | if substr in self.vocab:
318 | cur_substr = substr
319 | break
320 | end -= 1
321 | if cur_substr is None:
322 | is_bad = True
323 | break
324 | sub_tokens.append(cur_substr)
325 | start = end
326 |
327 | if is_bad:
328 | output_tokens.append(self.unk_token)
329 | else:
330 | output_tokens.extend(sub_tokens)
331 | return output_tokens
332 |
333 |
334 | def _is_whitespace(char):
335 | """Checks whether `chars` is a whitespace character."""
336 | # \t, \n, and \r are technically control characters but we treat them
337 | # as whitespace since they are generally considered as such.
338 | if char == " " or char == "\t" or char == "\n" or char == "\r":
339 | return True
340 | cat = unicodedata.category(char)
341 | if cat == "Zs":
342 | return True
343 | return False
344 |
345 |
346 | def _is_control(char):
347 | """Checks whether `chars` is a control character."""
348 | # These are technically control characters but we count them as whitespace
349 | # characters.
350 | if char == "\t" or char == "\n" or char == "\r":
351 | return False
352 | cat = unicodedata.category(char)
353 | if cat.startswith("C"):
354 | return True
355 | return False
356 |
357 |
358 | def _is_punctuation(char):
359 | """Checks whether `chars` is a punctuation character."""
360 | cp = ord(char)
361 | # We treat all non-letter/number ASCII as punctuation.
362 | # Characters such as "^", "$", and "`" are not in the Unicode
363 | # Punctuation class but we treat them as punctuation anyways, for
364 | # consistency.
365 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
366 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat.startswith("P"):
370 | return True
371 | return False
372 |
--------------------------------------------------------------------------------
/DataGen.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import torch
3 | from torch import nn
4 |
5 | import pytorch_pretrained_bert as _bert
6 | from random import shuffle
7 |
8 | import numpy as np
9 |
10 | #from dougu import flatten, lines
11 |
12 |
13 | _device = torch.device("cpu")
14 |
15 | class DataGenerator:
16 |
17 | MASK = "[MASK]"
18 | CLS = "[CLS]"
19 | SEP = "[SEP]"
20 | ner_list = ['O', 'I-CAT', 'B-TIM', 'I-ETT', 'B-RLT', 'B-ETT', 'B-VAR', 'I-RLT', 'I-TIM', 'B-CAT', 'I-VAR', 'X', '[CLS]', '[SEP]']
21 | rel_list = ["SP", "N", "ST", "PO", "SC"]
22 |
23 | def __init__(self, model, model_name, device=None, half_precision=False, rel_max_len=32):
24 | self.model_name = model_name
25 | self.device = device or _device
26 | do_lower_case = "uncased" in model_name
27 | self.tokenizer = _bert.BertTokenizer.from_pretrained(self.model_name, do_lower_case=do_lower_case)
28 |
29 | maybe_model_wrapper = model.from_pretrained(model_name).to(device=self.device)
30 | try:
31 | self.model = maybe_model_wrapper.bert
32 | except AttributeError:
33 | self.model = maybe_model_wrapper
34 | if half_precision:
35 | self.model.half()
36 | self.max_len = \
37 | self.model.embeddings.position_embeddings.weight.size(0)
38 | self.dim = self.model.embeddings.position_embeddings.weight.size(1)
39 | self.rel_max_len = rel_max_len
40 | def tokenize(self, text, masked_idxs=None):
41 | tokenized_text = self.tokenizer.tokenize(text)
42 | if masked_idxs is not None:
43 | for idx in masked_idxs:
44 | tokenized_text[idx] = self.MASK
45 | tokenized = [self.CLS] + tokenized_text + [self.SEP]
46 | return tokenized
47 |
48 | def tokenize_to_ids(self, text, masked_idxs=None, pad=True):
49 | tokens = self.tokenize(text, masked_idxs)
50 | return self.convert_tokens_to_ids(tokens, pad=pad)
51 |
52 | def convert_tokens_to_ids(self, tokens, pad=True):
53 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
54 | ids = torch.tensor([token_ids]).to(device=self.device)
55 | assert ids.size(1) < self.max_len
56 | if pad:
57 | padded_ids = torch.zeros(1, self.max_len).to(ids)
58 | padded_ids[0, :ids.size(1)] = ids
59 | mask = torch.zeros(1, self.max_len).to(ids)
60 | mask[0, :ids.size(1)] = 1
61 |
62 | return padded_ids, mask
63 | else:
64 | return ids
65 |
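# Illustrative note (added): with pad=True the call returns two tensors of
# shape (1, max_len); e.g. for 5 real tokens, padded_ids holds the 5 ids
# followed by zeros and mask reads [1, 1, 1, 1, 1, 0, 0, ...].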
66 | def padding_labels(self, true_labels, pad=True):
67 | true_labels = torch.tensor([true_labels]).to(device=self.device)
68 | assert true_labels.size(1) < self.max_len
69 | if pad:
70 |
71 | padded_labels = torch.zeros(1, self.max_len).to(true_labels)
72 | padded_labels[0, :true_labels.size(1)] = true_labels
73 | return padded_labels
74 | else:
75 | return_true_labels = torch.zeros(1, true_labels.size(1))
76 | return_true_labels[0,:] = true_labels
77 | return return_true_labels
78 |
79 | def parseData(self, file_name):
80 | data = {}
81 | with open(file_name, "r") as f:
82 | current_id = ""
83 | for line in f:
84 | if len(line) <= 2:
85 | continue
86 | if "#" in line:
87 | current_id = line.strip()
88 | data[current_id] = {}
89 | data[current_id]["tokens"] = []
90 | data[current_id]["label_types"] = []
91 | data[current_id]["index_rels"] = []
92 | data[current_id]["label_rels"] = []
93 | else:
94 | info = line.split("\t")
95 |
96 | data[current_id]["tokens"].append(info[1])
97 | data[current_id]["label_types"].append(info[2])
98 |
99 | index_rels_info = info[4].strip().replace("[","").replace("]","").replace(" ", "").split(",")
100 | data[current_id]["index_rels"].append([int(i) for i in index_rels_info])
101 |
102 | label_rels_info = info[3].strip().replace("[","").replace("]","").replace(" ", "").split(",")
103 | data[current_id]["label_rels"].append([str(i).replace("'", "") for i in label_rels_info])
104 | return data
105 |
106 | def convertData(self, original_data, padding=True):
107 | ver_1_data = {}
108 |
109 | for id in original_data:
110 | ver_1_item = {}
111 | ver_1_item["tokens"] = []
112 | ver_1_item["label_types"] = original_data[id]["label_types"]
113 | ver_1_item["index_rels"] = original_data[id]["index_rels"]
114 | ver_1_item["label_rels"] = original_data[id]["label_rels"]
115 | shift_idx = 0
116 | for idx, orig_token in enumerate(original_data[id]["tokens"]):
117 | sub_orig_tokens = orig_token.split("_")
118 | ## tokens
119 | ver_1_item["tokens"] += sub_orig_tokens
120 | if len(sub_orig_tokens) > 1:
121 | ## label_types
122 | orig_item_label_type = ver_1_item["label_types"][idx + shift_idx].replace("B-","").replace("I-", "")
123 | if ver_1_item["label_types"][idx + shift_idx] == "O":
124 | ver_1_item["label_types"] = ver_1_item["label_types"][:idx + shift_idx + 1] + ["O"] * (len(sub_orig_tokens) - 1) + ver_1_item["label_types"][idx+shift_idx + 1:]
125 | else:
126 | ver_1_item["label_types"] = ver_1_item["label_types"][:idx + shift_idx + 1] + ["I-" + orig_item_label_type] * (len(sub_orig_tokens) - 1) + ver_1_item["label_types"][idx+shift_idx + 1:]
127 | ## label_rels
128 | ver_1_item["label_rels"] = ver_1_item["label_rels"][:idx + shift_idx] + [["N"]] * (len(sub_orig_tokens) - 1) + ver_1_item["label_rels"][idx + shift_idx:]
129 |
130 | ## index_rels
131 | prev_index_rels = [] # ver_1_item["label_rels"][:idx + shift_idx]
132 | next_index_rels = [] # ver_1_item["label_rels"][idx+shift_idx:]
133 |
134 | for index_rels in ver_1_item["index_rels"][:idx + shift_idx]:
135 | prev_index_rel = []
136 | for index_rel in index_rels:
137 | if index_rel >= idx + shift_idx:
138 | prev_index_rel.append(index_rel + len(sub_orig_tokens) - 1)
139 | else:
140 | prev_index_rel.append(index_rel)
141 | prev_index_rels.append(prev_index_rel)
142 |
143 | for index_rels in ver_1_item["index_rels"][idx+shift_idx:]:
144 | next_index_rel = []
145 | for index_rel in index_rels:
146 | if index_rel >= idx + shift_idx:
147 | next_index_rel.append(index_rel + len(sub_orig_tokens) - 1)
148 | else:
149 | next_index_rel.append(index_rel)
150 |
151 | next_index_rels.append(next_index_rel)
152 |
153 | ver_1_item["index_rels"] = prev_index_rels + [[(idx + shift_idx + i)] for i in range(len(sub_orig_tokens) - 1)] + next_index_rels
154 | shift_idx += len(sub_orig_tokens) - 1
155 | ver_1_data[id] = ver_1_item
156 |
157 |
158 | bert_data = {}
159 | for id in ver_1_data:
160 | bert_item = {}
161 | bert_item["subword_ids"] = []
162 | bert_item["mask"] = []
163 | bert_item["token_starts"] = []
164 |
165 | bert_item["label_types"] = ver_1_data[id]["label_types"]
166 | bert_item["index_rels"] = ver_1_data[id]["index_rels"]
167 | bert_item["label_rels"] = ver_1_data[id]["label_rels"]
168 |
169 |
170 | shift_idx = 0
171 |
172 | bert_tokens = []
173 | bert_token_starts = []
174 | for idx, ver_1_token in enumerate(ver_1_data[id]["tokens"]):
175 | sub_ver_1_tokens = self.tokenizer.tokenize(ver_1_token)
176 | bert_token_starts.append(1 + len(sub_ver_1_tokens))
177 | ## tokens
178 | bert_tokens += sub_ver_1_tokens
179 | if len(sub_ver_1_tokens) > 1:
180 | ## label_types
181 | bert_item["label_types"] = bert_item["label_types"][:idx + shift_idx + 1] + ["X"] * (len(sub_ver_1_tokens) - 1) + bert_item["label_types"][idx+shift_idx + 1:]
182 |
183 | ## label_rels
184 | bert_item["label_rels"] = bert_item["label_rels"][:idx + shift_idx] + [["N"]] * (len(sub_ver_1_tokens) - 1) + bert_item["label_rels"][idx + shift_idx:]
185 | ## index_rels
186 | prev_index_rels = [] # bert_item["label_rels"][:idx + shift_idx]
187 | next_index_rels = [] # bert_item["label_rels"][idx+shift_idx:]
188 | for index_rels in bert_item["index_rels"][:idx + shift_idx]:
189 | prev_index_rel = []
190 | for index_rel in index_rels:
191 | if index_rel >= idx + shift_idx:
192 | prev_index_rel.append(index_rel + len(sub_ver_1_tokens) - 1)
193 | else:
194 | prev_index_rel.append(index_rel)
195 | prev_index_rels.append(prev_index_rel)
196 |
197 | for index_rels in bert_item["index_rels"][idx+shift_idx:]:
198 | next_index_rel = []
199 | for index_rel in index_rels:
200 | if index_rel >= idx + shift_idx:
201 | next_index_rel.append(index_rel + len(sub_ver_1_tokens) - 1)
202 | else:
203 | next_index_rel.append(index_rel)
204 |
205 | next_index_rels.append(next_index_rel)
206 |
207 | bert_item["index_rels"] = prev_index_rels + [[(idx + shift_idx + i)] for i in range(len(sub_ver_1_tokens) - 1)] + next_index_rels
208 | shift_idx += len(sub_ver_1_tokens) - 1
209 |
210 | ######################################################
211 | bert_tokens = [self.CLS] + bert_tokens + [self.SEP]
212 | bert_item["subword_ids"], bert_item["mask"] = self.convert_tokens_to_ids(bert_tokens)
213 |
214 | bert_item["label_types"] = [self.CLS] + bert_item["label_types"] + [self.SEP]
215 |
216 | bert_item["label_rels"] = [["N"]] + bert_item["label_rels"] + [["N"]]
217 |
218 | bert_item_index_rels_ = []
219 | for i in bert_item["index_rels"]:
220 | row = []
221 | for j in i:
222 | row.append(j+1)
223 | bert_item_index_rels_.append(row)
224 | bert_item["index_rels"] = [[0]] + bert_item_index_rels_ + [[len(bert_item_index_rels_) + 1]]
225 |
226 | bert_item["token_starts"] = torch.zeros(1, self.max_len).to(bert_item["subword_ids"])
227 | bert_item["token_starts"][0, bert_token_starts] = 1
228 |
229 | bert_item_label_types = []
230 | for label_type in bert_item["label_types"]:
231 | bert_item_label_types.append(self.ner_list.index(label_type))
232 |
233 | bert_item["label_types"] = self.padding_labels(bert_item_label_types, padding)
234 |
235 | bert_item["rel_matrix"] = self.convert_rel_ids(bert_item["index_rels"], bert_item["label_rels"])
236 |
237 | bert_item["token_type_ids"] = torch.zeros(bert_item["mask"].size(0), bert_item["mask"].size(1))
238 | bert_item["token_type_ids"][:, :len(bert_tokens)] = 1
239 | bert_item["token_type_ids"] = bert_item["token_type_ids"].long()
240 | #################################################################
241 | bert_data[id] = bert_item
242 |
243 |
244 | return bert_data
245 |
246 | def convert_rel_ids(self, index_rels, label_rels):
247 | num_rels = len(self.rel_list)
248 | rel_matrix = torch.zeros(1, self.max_len, self.rel_max_len * num_rels)
249 |
250 | for i, (idxs, labels) in enumerate(zip(index_rels, label_rels)):
251 | for idx, label in zip(idxs, labels):
252 | rel_matrix[0, i, idx * num_rels + self.rel_list.index(label)] = 1
253 | return rel_matrix
254 |
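# Illustrative note (added): the relation matrix one-hot encodes (head index,
# relation label) pairs per token. With rel_list = ["SP", "N", "ST", "PO", "SC"]
# (num_rels = 5), a token i pointing at position idx = 3 with label "ST" sets
# rel_matrix[0, i, 3 * 5 + 2] = 1, since self.rel_list.index("ST") == 2.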
255 | def subword_tokenize_to_ids(self, tokens):
256 | subwords, token_start_idxs = self.subword_tokenize(tokens)
257 | subword_ids, mask = self.convert_tokens_to_ids(subwords)
258 | token_starts = torch.zeros(1, self.max_len).to(subword_ids)
259 | token_starts[0, token_start_idxs] = 1
260 | return subword_ids, mask, token_starts
261 |
262 | def segment_ids(self, segment1_len, segment2_len):
263 | ids = [0] * segment1_len + [1] * segment2_len
264 | return torch.tensor([ids]).to(device=self.device)
265 |
266 | def get_featurized_sentence(self, sentence, padding=True):
267 | tokens = self.tokenize(sentence)
268 | if padding:
269 | input_ids, mask = self.convert_tokens_to_ids(tokens, padding)
270 | else:
271 | input_ids = self.convert_tokens_to_ids(tokens, padding)
272 |
273 | input_types = torch.zeros(input_ids.size(0), input_ids.size(1))
274 | input_types[:, :len(tokens)] = 1
275 |
276 | empty_ners = torch.zeros(input_ids.size(0), input_ids.size(1))
277 | empty_rels = torch.zeros(input_ids.size(0), input_ids.size(1), len(self.rel_list))
278 | return input_ids, mask, input_types, empty_ners, empty_rels
279 |
280 | def get_featurized_sentences(self, file_name, padding=True):
281 | original_data = self.parseData(file_name)
282 | bertData = self.convertData(original_data, padding)
283 |
284 | featurized_sentences = []
285 | for id in bertData:
286 | features = {}
287 | features["bert_ids"], features["bert_mask"], features["bert_token_starts"], features["ett_tags"], features["rel_matrix"], features["token_type_ids"] = \
288 | bertData[id]["subword_ids"], bertData[id]["mask"], bertData[id]["token_starts"], bertData[id]["label_types"], bertData[id]["rel_matrix"], bertData[id]["token_type_ids"]
289 | featurized_sentences.append(features)
290 |
291 | self.num_samples = len(featurized_sentences)
292 | return featurized_sentences
293 |
294 | def get_generator(self, file_name, batch_size=32, padding=True, is_shuffle=True):
295 | batch_input_ids = []
296 | batch_mask_ids = []
297 | batch_ner_ids = []
298 | batch_input_type_ids = []
299 | batch_rel_matrix = []
300 | featurized_sentences = self.get_featurized_sentences(file_name, padding)
301 | if is_shuffle:
302 | shuffle(featurized_sentences)
303 | for idx, features in enumerate(featurized_sentences):
304 | batch_input_ids.append(features["bert_ids"])
305 | batch_input_type_ids.append(features["token_type_ids"])
306 | batch_mask_ids.append(features["bert_mask"])
307 | batch_ner_ids.append(features["ett_tags"])
308 | batch_rel_matrix.append(features["rel_matrix"])
309 | if idx % batch_size == batch_size-1:
310 | return_batch_input_ids = torch.cat(batch_input_ids, dim=0)
311 | return_batch_mask_ids = torch.cat(batch_mask_ids, dim=0)
312 | return_batch_input_type_ids = torch.cat(batch_input_type_ids, dim=0)
313 | return_batch_label_ids = torch.cat(batch_ner_ids, dim=0)
314 | return_batch_rel_matrix = torch.cat(batch_rel_matrix, dim=0)
315 |
316 | batch_input_ids = []
317 | batch_mask_ids = []
318 | batch_ner_ids = []
319 | batch_input_type_ids = []
320 | batch_rel_matrix = []
321 |
322 | yield return_batch_input_ids, return_batch_mask_ids, return_batch_input_type_ids, return_batch_label_ids, return_batch_rel_matrix
323 |
324 | if len(batch_input_ids) != 0:
325 | yield torch.cat(batch_input_ids, dim=0), torch.cat(batch_mask_ids, dim=0), torch.cat(batch_input_type_ids, dim=0), torch.cat(batch_ner_ids, dim=0), torch.cat(batch_rel_matrix, dim=0)
326 |
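# Minimal usage sketch (added; the arguments follow the constructor above, and
# data/train_analysis.txt ships with the repository):
#   gen = DataGenerator(_bert.BertModel, 'bert-base-multilingual-cased')
#   for ids, mask, type_ids, ner_ids, rel_matrix in gen.get_generator(
#           "data/train_analysis.txt", batch_size=16):
#       ...  # feed the batch to the joint entity/relation model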
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/modeling_transfo_xl_utilities.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Utilities for PyTorch Transformer XL model.
17 | Directly adapted from https://github.com/kimiyoung/transformer-xl.
18 | """
19 |
20 | from collections import defaultdict
21 |
22 | import numpy as np
23 |
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 |
28 | # CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
29 | # CUDA_MINOR = int(torch.version.cuda.split('.')[1])
30 |
31 | class ProjectedAdaptiveLogSoftmax(nn.Module):
32 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
33 | keep_order=False):
34 | super(ProjectedAdaptiveLogSoftmax, self).__init__()
35 |
36 | self.n_token = n_token
37 | self.d_embed = d_embed
38 | self.d_proj = d_proj
39 |
40 | self.cutoffs = cutoffs + [n_token]
41 | self.cutoff_ends = [0] + self.cutoffs
42 | self.div_val = div_val
43 |
44 | self.shortlist_size = self.cutoffs[0]
45 | self.n_clusters = len(self.cutoffs) - 1
46 | self.head_size = self.shortlist_size + self.n_clusters
47 |
48 | if self.n_clusters > 0:
49 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
50 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
51 |
52 | self.out_layers = nn.ModuleList()
53 | self.out_projs = nn.ParameterList()
54 |
55 | if div_val == 1:
56 | for i in range(len(self.cutoffs)):
57 | if d_proj != d_embed:
58 | self.out_projs.append(
59 | nn.Parameter(torch.Tensor(d_proj, d_embed))
60 | )
61 | else:
62 | self.out_projs.append(None)
63 |
64 | self.out_layers.append(nn.Linear(d_embed, n_token))
65 | else:
66 | for i in range(len(self.cutoffs)):
67 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
68 | d_emb_i = d_embed // (div_val ** i)
69 |
70 | self.out_projs.append(
71 | nn.Parameter(torch.Tensor(d_proj, d_emb_i))
72 | )
73 |
74 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
75 |
76 | self.keep_order = keep_order
77 |
78 | def _compute_logit(self, hidden, weight, bias, proj):
79 | if proj is None:
80 | logit = F.linear(hidden, weight, bias=bias)
81 | else:
82 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
83 | proj_hid = F.linear(hidden, proj.t().contiguous())
84 | logit = F.linear(proj_hid, weight, bias=bias)
85 | # else:
86 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
87 | # if bias is not None:
88 | # logit = logit + bias
89 |
90 | return logit
91 |
92 | def forward(self, hidden, target=None, keep_order=False):
93 | '''
94 | Params:
95 | hidden :: [len*bsz x d_proj]
96 | target :: [len*bsz]
97 | Return:
98 | if target is None:
99 | out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary
100 | else:
101 | out :: [len*bsz] negative log likelihood of the target tokens
102 | We could replace this implementation with the native PyTorch one
103 | if theirs had an option to set the bias on all clusters.
104 | here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
105 | '''
106 |
107 | if target is not None:
108 | target = target.view(-1)
109 | if hidden.size(0) != target.size(0):
110 | raise RuntimeError('Input and target should have the same size '
111 | 'in the batch dimension.')
112 |
113 | if self.n_clusters == 0:
114 | logit = self._compute_logit(hidden, self.out_layers[0].weight,
115 | self.out_layers[0].bias, self.out_projs[0])
116 | if target is not None:
117 | output = -F.log_softmax(logit, dim=-1) \
118 | .gather(1, target.unsqueeze(1)).squeeze(1)
119 | else:
120 | output = F.log_softmax(logit, dim=-1)
121 | else:
122 | # construct weights and biases
123 | weights, biases = [], []
124 | for i in range(len(self.cutoffs)):
125 | if self.div_val == 1:
126 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
127 | weight_i = self.out_layers[0].weight[l_idx:r_idx]
128 | bias_i = self.out_layers[0].bias[l_idx:r_idx]
129 | else:
130 | weight_i = self.out_layers[i].weight
131 | bias_i = self.out_layers[i].bias
132 |
133 | if i == 0:
134 | weight_i = torch.cat(
135 | [weight_i, self.cluster_weight], dim=0)
136 | bias_i = torch.cat(
137 | [bias_i, self.cluster_bias], dim=0)
138 |
139 | weights.append(weight_i)
140 | biases.append(bias_i)
141 |
142 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
143 |
144 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
145 | head_logprob = F.log_softmax(head_logit, dim=1)
146 |
147 | if target is None:
148 | out = hidden.new_empty((head_logit.size(0), self.n_token))
149 | else:
150 | out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)
151 |
152 | offset = 0
153 | cutoff_values = [0] + self.cutoffs
154 | for i in range(len(cutoff_values) - 1):
155 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
156 |
157 | if target is not None:
158 | mask_i = (target >= l_idx) & (target < r_idx)
159 | indices_i = mask_i.nonzero().squeeze()
160 |
161 | if indices_i.numel() == 0:
162 | continue
163 |
164 | target_i = target.index_select(0, indices_i) - l_idx
165 | head_logprob_i = head_logprob.index_select(0, indices_i)
166 | hidden_i = hidden.index_select(0, indices_i)
167 | else:
168 | hidden_i = hidden
169 |
170 | if i == 0:
171 | if target is not None:
172 | logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
173 | else:
174 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
175 | else:
176 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
177 |
178 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
179 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
180 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster
181 | if target is not None:
182 | logprob_i = head_logprob_i[:, cluster_prob_idx] \
183 | + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
184 | else:
185 | logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
186 | out[:, l_idx:r_idx] = logprob_i
187 |
188 | if target is not None:
189 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
190 | out.index_copy_(0, indices_i, -logprob_i)
191 | else:
192 | out[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
193 | offset += logprob_i.size(0)
194 |
195 | return out
196 |
197 |
198 | def log_prob(self, hidden):
199 | r""" Computes log probabilities for all :math:`n\_classes`
200 | From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py
201 | Args:
202 | hidden (Tensor): a minibatch of examples
203 | Returns:
204 | log-probabilities for each class :math:`c`
205 | in range :math:`0 <= c <= n\_classes`, where :math:`n\_classes` is a
206 | parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
207 | Shape:
208 | - Input: :math:`(N, in\_features)`
209 | - Output: :math:`(N, n\_classes)`
210 | """
211 | if self.n_clusters == 0:
212 | logit = self._compute_logit(hidden, self.out_layers[0].weight,
213 | self.out_layers[0].bias, self.out_projs[0])
214 | return F.log_softmax(logit, dim=-1)
215 | else:
216 | # construct weights and biases
217 | weights, biases = [], []
218 | for i in range(len(self.cutoffs)):
219 | if self.div_val == 1:
220 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
221 | weight_i = self.out_layers[0].weight[l_idx:r_idx]
222 | bias_i = self.out_layers[0].bias[l_idx:r_idx]
223 | else:
224 | weight_i = self.out_layers[i].weight
225 | bias_i = self.out_layers[i].bias
226 |
227 | if i == 0:
228 | weight_i = torch.cat(
229 | [weight_i, self.cluster_weight], dim=0)
230 | bias_i = torch.cat(
231 | [bias_i, self.cluster_bias], dim=0)
232 |
233 | weights.append(weight_i)
234 | biases.append(bias_i)
235 |
236 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
237 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
238 |
239 | out = hidden.new_empty((head_logit.size(0), self.n_token))
240 | head_logprob = F.log_softmax(head_logit, dim=1)
241 |
242 | cutoff_values = [0] + self.cutoffs
243 | for i in range(len(cutoff_values) - 1):
244 | start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1]
245 |
246 | if i == 0:
247 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
248 | else:
249 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
250 |
251 | tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i)
252 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
253 |
254 | logprob_i = head_logprob[:, self.cutoffs[0] + i - 1, None] + tail_logprob_i
255 | out[:, start_idx:stop_idx] = logprob_i
256 |
257 | return out
258 |
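# Minimal usage sketch (added; the sizes are arbitrary): with cutoffs
# [100, 1000] and n_token 10000, tokens 0-99 live in the head softmax and the
# two tail clusters are scored only when a target falls in them:
#   crit = ProjectedAdaptiveLogSoftmax(n_token=10000, d_embed=64, d_proj=64,
#                                      cutoffs=[100, 1000], div_val=1)
#   hidden = torch.randn(8, 64)                         # [len*bsz x d_proj]
#   nll = crit(hidden, torch.randint(0, 10000, (8,)))   # [len*bsz] NLL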
259 |
260 | class LogUniformSampler(object):
261 | def __init__(self, range_max, n_sample):
262 | """
263 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
264 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
265 |
266 | expected count can be approximated by 1 - (1 - p)^n
267 | and we use a numerically stable version -expm1(num_tries * log1p(-p))
268 |
269 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
270 | """
271 | with torch.no_grad():
272 | self.range_max = range_max
273 | log_indices = torch.arange(1., range_max+2., 1.).log_()
274 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
275 | # print('P', self.dist.numpy().tolist()[-30:])
276 |
277 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
278 |
279 | self.n_sample = n_sample
280 |
281 | def sample(self, labels):
282 | """
283 | labels: [b1, b2]
284 | Return
285 | true_log_probs: [b1, b2]
286 | samp_log_probs: [n_sample]
287 | neg_samples: [n_sample]
288 | """
289 |
290 | # neg_samples = torch.empty(0).long()
291 | n_sample = self.n_sample
292 | n_tries = 2 * n_sample
293 |
294 | with torch.no_grad():
295 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
296 | device = labels.device
297 | neg_samples = neg_samples.to(device)
298 | true_log_probs = self.log_q[labels].to(device)
299 | samp_log_probs = self.log_q[neg_samples].to(device)
300 | return true_log_probs, samp_log_probs, neg_samples
301 |
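# Illustrative note (added): log_q above approximates
# log(1 - (1 - p)^(2 * n_sample)), the log expected count of each class when
# drawing 2 * n_sample times with replacement, computed via log1p/expm1 so that
# tiny tail probabilities do not underflow.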
302 | def sample_logits(embedding, bias, labels, inputs, sampler):
303 | """
304 | embedding: an nn.Embedding layer
305 | bias: [n_vocab]
306 | labels: [b1, b2]
307 | inputs: [b1, b2, n_emb]
308 | sampler: you may use a LogUniformSampler
309 | Return
310 | logits: [b1, b2, 1 + n_sample]
311 | """
312 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
313 | n_sample = neg_samples.size(0)
314 | b1, b2 = labels.size(0), labels.size(1)
315 | all_ids = torch.cat([labels.view(-1), neg_samples])
316 | all_w = embedding(all_ids)
317 | true_w = all_w[: -n_sample].view(b1, b2, -1)
318 | sample_w = all_w[- n_sample:].view(n_sample, -1)
319 |
320 | all_b = bias[all_ids]
321 | true_b = all_b[: -n_sample].view(b1, b2)
322 | sample_b = all_b[- n_sample:]
323 |
324 | hit = (labels[:, :, None] == neg_samples).detach()
325 |
326 | true_logits = torch.einsum('ijk,ijk->ij',
327 | [true_w, inputs]) + true_b - true_log_probs
328 | sample_logits = torch.einsum('lk,ijk->ijl',
329 | [sample_w, inputs]) + sample_b - samp_log_probs
330 | sample_logits.masked_fill_(hit, -1e30)
331 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
332 |
333 | return logits
334 |
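# Illustrative note (added): column 0 of the last logits dimension is always
# the true class, so a cross-entropy loss over these sampled logits uses an
# all-zeros target, e.g.
#   F.cross_entropy(logits.view(-1, logits.size(-1)),
#                   torch.zeros(b1 * b2, dtype=torch.long))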
335 |
336 | # class LogUniformSampler(object):
337 | # def __init__(self, range_max, unique=False):
338 | # """
339 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
340 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
341 | # """
342 | # self.range_max = range_max
343 | # log_indices = torch.arange(1., range_max+2., 1.).log_()
344 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
345 |
346 | # self.unique = unique
347 |
348 | # if self.unique:
349 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
350 |
351 | # def sample(self, n_sample, labels):
352 | # pos_sample, new_labels = labels.unique(return_inverse=True)
353 | # n_pos_sample = pos_sample.size(0)
354 | # n_neg_sample = n_sample - n_pos_sample
355 |
356 | # if self.unique:
357 | # self.exclude_mask.index_fill_(0, pos_sample, 1)
358 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
359 | # self.exclude_mask.index_fill_(0, pos_sample, 0)
360 | # else:
361 | # sample_dist = self.dist
362 |
363 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample)
364 |
365 | # sample = torch.cat([pos_sample, neg_sample])
366 | # sample_prob = self.dist[sample]
367 |
368 | # return new_labels, sample, sample_prob
369 |
370 |
371 | if __name__ == '__main__':
372 | S, B = 3, 4
373 | n_vocab = 10000
374 | n_sample = 5
375 | H = 32
376 |
377 | labels = torch.LongTensor(S, B).random_(0, n_vocab)
378 |
379 | # sampler = LogUniformSampler(n_vocab, unique=False)
380 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
381 |
382 | sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True)
383 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
384 |
385 | # print('true_probs', true_probs.numpy().tolist())
386 | # print('samp_probs', samp_probs.numpy().tolist())
387 | # print('neg_samples', neg_samples.numpy().tolist())
388 |
389 | # print('sum', torch.sum(sampler.dist).item())
390 |
391 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
392 |
393 | embedding = nn.Embedding(n_vocab, H)
394 | bias = torch.zeros(n_vocab)
395 | inputs = torch.Tensor(S, B, H).normal_()
396 |
397 | logits = sample_logits(embedding, bias, labels, inputs, sampler)
398 | print('logits', logits.detach().numpy().tolist())
399 | print('logits shape', logits.size())
400 | # note: sample_logits returns only the logits; the true class always sits
401 | # at index 0 of the last dimension, so no separate out_labels are produced
402 |
403 |
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/tokenization_transfo_xl.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Tokenization classes for Transformer XL model.
17 | Adapted from https://github.com/kimiyoung/transformer-xl.
18 | """
19 | from __future__ import (absolute_import, division, print_function,
20 | unicode_literals)
21 |
22 | import glob
23 | import logging
24 | import os
25 | import sys
26 | from collections import Counter, OrderedDict
27 | from io import open
28 | import unicodedata
29 |
30 | import torch
31 | import numpy as np
32 |
33 | from .file_utils import cached_path
34 |
35 | if sys.version_info[0] == 2:
36 | import cPickle as pickle
37 | else:
38 | import pickle
39 |
40 |
41 | logger = logging.getLogger(__name__)
42 |
43 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
44 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
45 | }
46 | VOCAB_NAME = 'vocab.bin'
47 |
48 | PRETRAINED_CORPUS_ARCHIVE_MAP = {
49 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
50 | }
51 | CORPUS_NAME = 'corpus.bin'
52 |
53 | class TransfoXLTokenizer(object):
54 | """
55 | Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
56 | """
57 | @classmethod
58 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
59 | """
60 | Instantiate a TransfoXLTokenizer from a pre-trained vocabulary file.
61 | Download and cache the vocabulary file if needed.
62 | """
63 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
64 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
65 | else:
66 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
67 | # redirect to the cache, if necessary
68 | try:
69 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
70 | except EnvironmentError:
71 | logger.error(
72 | "Model name '{}' was not found in model name list ({}). "
73 | "We assumed '{}' was a path or url but couldn't find files {} "
74 | "at this path or url.".format(
75 | pretrained_model_name_or_path,
76 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
77 | pretrained_model_name_or_path,
78 | vocab_file))
79 | return None
80 | if resolved_vocab_file == vocab_file:
81 | logger.info("loading vocabulary file {}".format(vocab_file))
82 | else:
83 | logger.info("loading vocabulary file {} from cache at {}".format(
84 | vocab_file, resolved_vocab_file))
85 |
86 | # Instantiate tokenizer.
87 | tokenizer = cls(*inputs, **kwargs)
88 | vocab_dict = torch.load(resolved_vocab_file)
89 | for key, value in vocab_dict.items():
90 | tokenizer.__dict__[key] = value
91 | return tokenizer
92 |
93 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
94 | delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
95 | self.counter = Counter()
96 | self.special = special
97 | self.min_freq = min_freq
98 | self.max_size = max_size
99 | self.lower_case = lower_case
100 | self.delimiter = delimiter
101 | self.vocab_file = vocab_file
102 | self.never_split = never_split
103 |
104 | def count_file(self, path, verbose=False, add_eos=False):
105 | if verbose: print('counting file {} ...'.format(path))
106 | assert os.path.exists(path)
107 |
108 | sents = []
109 | with open(path, 'r', encoding='utf-8') as f:
110 | for idx, line in enumerate(f):
111 | if verbose and idx > 0 and idx % 500000 == 0:
112 | print(' line {}'.format(idx))
113 | symbols = self.tokenize(line, add_eos=add_eos)
114 | self.counter.update(symbols)
115 | sents.append(symbols)
116 |
117 | return sents
118 |
119 | def count_sents(self, sents, verbose=False):
120 | """
121 | sents : a list of sentences, each a list of tokenized symbols
122 | """
123 | if verbose: print('counting {} sents ...'.format(len(sents)))
124 | for idx, symbols in enumerate(sents):
125 | if verbose and idx > 0 and idx % 500000 == 0:
126 | print(' line {}'.format(idx))
127 | self.counter.update(symbols)
128 |
129 | def _build_from_file(self, vocab_file):
130 | self.idx2sym = []
131 | self.sym2idx = OrderedDict()
132 |
133 | with open(vocab_file, 'r', encoding='utf-8') as f:
134 | for line in f:
135 | symb = line.strip().split()[0]
136 | self.add_symbol(symb)
137 | if '<UNK>' in self.sym2idx:
138 | self.unk_idx = self.sym2idx['<UNK>']
139 | elif '<unk>' in self.sym2idx:
140 | self.unk_idx = self.sym2idx['<unk>']
141 | else:
142 | raise ValueError('No <unk> token in vocabulary')
143 |
144 | def build_vocab(self):
145 | if self.vocab_file:
146 | print('building vocab from {}'.format(self.vocab_file))
147 | self._build_from_file(self.vocab_file)
148 | print('final vocab size {}'.format(len(self)))
149 | else:
150 | print('building vocab with min_freq={}, max_size={}'.format(
151 | self.min_freq, self.max_size))
152 | self.idx2sym = []
153 | self.sym2idx = OrderedDict()
154 |
155 | for sym in self.special:
156 | self.add_special(sym)
157 |
158 | for sym, cnt in self.counter.most_common(self.max_size):
159 | if cnt < self.min_freq: break
160 | self.add_symbol(sym)
161 |
162 | print('final vocab size {} from {} unique tokens'.format(
163 | len(self), len(self.counter)))
164 |
165 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
166 | add_double_eos=False):
167 | if verbose: print('encoding file {} ...'.format(path))
168 | assert os.path.exists(path)
169 | encoded = []
170 | with open(path, 'r', encoding='utf-8') as f:
171 | for idx, line in enumerate(f):
172 | if verbose and idx > 0 and idx % 500000 == 0:
173 | print(' line {}'.format(idx))
174 | symbols = self.tokenize(line, add_eos=add_eos,
175 | add_double_eos=add_double_eos)
176 | encoded.append(self.convert_to_tensor(symbols))
177 |
178 | if ordered:
179 | encoded = torch.cat(encoded)
180 |
181 | return encoded
182 |
183 | def encode_sents(self, sents, ordered=False, verbose=False):
184 | if verbose: print('encoding {} sents ...'.format(len(sents)))
185 | encoded = []
186 | for idx, symbols in enumerate(sents):
187 | if verbose and idx > 0 and idx % 500000 == 0:
188 | print(' line {}'.format(idx))
189 | encoded.append(self.convert_to_tensor(symbols))
190 |
191 | if ordered:
192 | encoded = torch.cat(encoded)
193 |
194 | return encoded
195 |
196 | def add_special(self, sym):
197 | if sym not in self.sym2idx:
198 | self.idx2sym.append(sym)
199 | self.sym2idx[sym] = len(self.idx2sym) - 1
200 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
201 |
202 | def add_symbol(self, sym):
203 | if sym not in self.sym2idx:
204 | self.idx2sym.append(sym)
205 | self.sym2idx[sym] = len(self.idx2sym) - 1
206 |
207 | def get_sym(self, idx):
208 | assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
209 | return self.idx2sym[idx]
210 |
211 | def get_idx(self, sym):
212 | if sym in self.sym2idx:
213 | return self.sym2idx[sym]
214 | else:
215 | # print('encounter unk {}'.format(sym))
216 | # assert '<eos>' not in sym
217 | if hasattr(self, 'unk_idx'):
218 | return self.sym2idx.get(sym, self.unk_idx)
219 | # Backward compatibility with pre-trained models
220 | elif '<unk>' in self.sym2idx:
221 | return self.sym2idx['<unk>']
222 | elif '<UNK>' in self.sym2idx:
223 | return self.sym2idx['<UNK>']
224 | else:
225 | raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
226 |
227 | def convert_ids_to_tokens(self, indices):
228 | """Converts a sequence of indices in symbols using the vocab."""
229 | return [self.get_sym(idx) for idx in indices]
230 |
231 | def convert_tokens_to_ids(self, symbols):
232 | """Converts a sequence of symbols into ids using the vocab."""
233 | return [self.get_idx(sym) for sym in symbols]
234 |
235 | def convert_to_tensor(self, symbols):
236 | return torch.LongTensor(self.convert_tokens_to_ids(symbols))
237 |
238 | def decode(self, indices, exclude=None):
239 |         """Converts a sequence of indices into a string."""
240 | if exclude is None:
241 | return ' '.join([self.get_sym(idx) for idx in indices])
242 | else:
243 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
244 |
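    # Round-trip sketch with a toy vocabulary (assuming the tokenizer can be
    # constructed with its defaults; these symbols are illustrative only):
    #
    #     vocab = TransfoXLTokenizer()
    #     for sym in ['<unk>', 'hello', 'world']:
    #         vocab.add_symbol(sym)
    #     ids = vocab.convert_tokens_to_ids(['hello', 'world'])  # [1, 2]
    #     vocab.decode(ids)                                      # 'hello world'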
245 | def __len__(self):
246 | return len(self.idx2sym)
247 |
248 | def _run_split_on_punc(self, text):
249 | """Splits punctuation on a piece of text."""
250 | if text in self.never_split:
251 | return [text]
252 | chars = list(text)
253 | i = 0
254 | start_new_word = True
255 | output = []
256 | while i < len(chars):
257 | char = chars[i]
258 | if _is_punctuation(char):
259 | output.append([char])
260 | start_new_word = True
261 | else:
262 | if start_new_word:
263 | output.append([])
264 | start_new_word = False
265 | output[-1].append(char)
266 | i += 1
267 |
268 | return ["".join(x) for x in output]
269 |
270 | def _run_strip_accents(self, text):
271 | """Strips accents from a piece of text."""
272 | text = unicodedata.normalize("NFD", text)
273 | output = []
274 | for char in text:
275 | cat = unicodedata.category(char)
276 | if cat == "Mn":
277 | continue
278 | output.append(char)
279 | return "".join(output)
280 |
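    # e.g. _run_strip_accents('résumé') == 'resume': NFD splits off the combining
    # accents (category Mn), which are then dropped.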
281 | def _clean_text(self, text):
282 | """Performs invalid character removal and whitespace cleanup on text."""
283 | output = []
284 | for char in text:
285 | cp = ord(char)
286 | if cp == 0 or cp == 0xfffd or _is_control(char):
287 | continue
288 | if _is_whitespace(char):
289 | output.append(" ")
290 | else:
291 | output.append(char)
292 | return "".join(output)
293 |
294 | def whitespace_tokenize(self, text):
295 |         """Cleans the text and splits it on self.delimiter; an empty delimiter means character-level tokenization."""
296 | text = text.strip()
297 | if not text:
298 | return []
299 | if self.delimiter == '':
300 | tokens = text
301 | else:
302 | tokens = text.split(self.delimiter)
303 | return tokens
304 |
305 | def tokenize(self, line, add_eos=False, add_double_eos=False):
306 | line = self._clean_text(line)
307 | line = line.strip()
308 |
309 | symbols = self.whitespace_tokenize(line)
310 |
311 | split_symbols = []
312 | for symbol in symbols:
313 | if self.lower_case and symbol not in self.never_split:
314 | symbol = symbol.lower()
315 | symbol = self._run_strip_accents(symbol)
316 | split_symbols.extend(self._run_split_on_punc(symbol))
317 |
318 |         if add_double_eos: # lm1b
319 |             return ['<S>'] + split_symbols + ['<S>']
320 |         elif add_eos:
321 |             return split_symbols + ['<eos>']
322 | else:
323 | return split_symbols
324 |
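# End-to-end tokenizer sketch (lower_case=True assumed; the exact output
# depends on the constructor flags):
#
#     vocab.tokenize('Hello, world!', add_eos=True)
#     # -> ['hello', ',', 'world', '!', '<eos>']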
325 |
326 | class LMOrderedIterator(object):
327 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
328 | """
329 |             data -- LongTensor -- a single, strictly ordered stream of token ids
330 | """
331 | self.bsz = bsz
332 | self.bptt = bptt
333 | self.ext_len = ext_len if ext_len is not None else 0
334 |
335 | self.device = device
336 |
337 | # Work out how cleanly we can divide the dataset into bsz parts.
338 | self.n_step = data.size(0) // bsz
339 |
340 | # Trim off any extra elements that wouldn't cleanly fit (remainders).
341 | data = data.narrow(0, 0, self.n_step * bsz)
342 |
343 | # Evenly divide the data across the bsz batches.
344 | self.data = data.view(bsz, -1).t().contiguous().to(device)
345 |
346 | # Number of mini-batches
347 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
348 |
349 | def get_batch(self, i, bptt=None):
350 | if bptt is None: bptt = self.bptt
351 | seq_len = min(bptt, self.data.size(0) - 1 - i)
352 |
353 | end_idx = i + seq_len
354 | beg_idx = max(0, i - self.ext_len)
355 |
356 | data = self.data[beg_idx:end_idx]
357 | target = self.data[i+1:i+1+seq_len]
358 |
359 | data_out = data.transpose(0, 1).contiguous().to(self.device)
360 | target_out = target.transpose(0, 1).contiguous().to(self.device)
361 |
362 | return data_out, target_out, seq_len
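    # get_batch shape sketch: with ext_len == 0 both tensors are [bsz x seq_len]
    # and target is the input shifted left by one position, i.e. for t < seq_len - 1,
    # target[:, t] == data[:, t + 1] (next-token prediction).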
363 |
364 | def get_fixlen_iter(self, start=0):
365 | for i in range(start, self.data.size(0) - 1, self.bptt):
366 | yield self.get_batch(i)
367 |
368 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
369 | max_len = self.bptt + max_deviation * std
370 | i = start
371 | while True:
372 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
373 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
374 | data, target, seq_len = self.get_batch(i, bptt)
375 | i += seq_len
376 | yield data, target, seq_len
377 | if i >= self.data.size(0) - 2:
378 | break
379 |
380 | def __iter__(self):
381 | return self.get_fixlen_iter()
382 |
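# Usage sketch for LMOrderedIterator (the file name is hypothetical):
#
#     ids = vocab.encode_file('train.txt', ordered=True)
#     for data, target, seq_len in LMOrderedIterator(ids, bsz=4, bptt=32):
#         pass  # data, target: [4 x seq_len] LongTensors on the chosen device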
383 |
384 | class LMShuffledIterator(object):
385 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
386 | """
387 | data -- list[LongTensor] -- there is no order among the LongTensors
388 | """
389 | self.data = data
390 |
391 | self.bsz = bsz
392 | self.bptt = bptt
393 | self.ext_len = ext_len if ext_len is not None else 0
394 |
395 | self.device = device
396 | self.shuffle = shuffle
397 |
398 | def get_sent_stream(self):
399 | # index iterator
400 | epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
401 | else np.array(range(len(self.data)))
402 |
403 | # sentence iterator
404 | for idx in epoch_indices:
405 | yield self.data[idx]
406 |
407 | def stream_iterator(self, sent_stream):
408 | # streams for each data in the batch
409 | streams = [None] * self.bsz
410 |
411 | data = torch.LongTensor(self.bptt, self.bsz)
412 | target = torch.LongTensor(self.bptt, self.bsz)
413 |
414 | n_retain = 0
415 |
416 | while True:
417 | # data : [n_retain+bptt x bsz]
418 | # target : [bptt x bsz]
419 | data[n_retain:].fill_(-1)
420 | target.fill_(-1)
421 |
422 | valid_batch = True
423 |
424 | for i in range(self.bsz):
425 | n_filled = 0
426 | try:
427 | while n_filled < self.bptt:
428 | if streams[i] is None or len(streams[i]) <= 1:
429 | streams[i] = next(sent_stream)
430 | # number of new tokens to fill in
431 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
432 | # first n_retain tokens are retained from last batch
433 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
434 | streams[i][:n_new]
435 | target[n_filled:n_filled+n_new, i] = \
436 | streams[i][1:n_new+1]
437 | streams[i] = streams[i][n_new:]
438 | n_filled += n_new
439 | except StopIteration:
440 | valid_batch = False
441 | break
442 |
443 | if not valid_batch:
444 | return
445 |
446 | data_out = data.transpose(0, 1).contiguous().to(self.device)
447 | target_out = target.transpose(0, 1).contiguous().to(self.device)
448 |
449 | yield data_out, target_out, self.bptt
450 |
451 | n_retain = min(data.size(0), self.ext_len)
452 | if n_retain > 0:
453 | data[:n_retain] = data[-n_retain:]
454 | data.resize_(n_retain + self.bptt, data.size(1))
455 |
456 | def __iter__(self):
457 | # sent_stream is an iterator
458 | sent_stream = self.get_sent_stream()
459 |
460 | for batch in self.stream_iterator(sent_stream):
461 | yield batch
462 |
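# Usage sketch for LMShuffledIterator: it consumes a list of per-sentence
# LongTensors and splices them into fixed-size [bsz x bptt] batches:
#
#     sents = vocab.encode_file('valid.txt', ordered=False, add_double_eos=True)
#     for data, target, seq_len in LMShuffledIterator(sents, bsz=4, bptt=32, shuffle=True):
#         pass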
463 |
464 | class LMMultiFileIterator(LMShuffledIterator):
465 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
466 | shuffle=False):
467 |
468 | self.paths = paths
469 | self.vocab = vocab
470 |
471 | self.bsz = bsz
472 | self.bptt = bptt
473 | self.ext_len = ext_len if ext_len is not None else 0
474 |
475 | self.device = device
476 | self.shuffle = shuffle
477 |
478 | def get_sent_stream(self, path):
479 | sents = self.vocab.encode_file(path, add_double_eos=True)
480 | if self.shuffle:
481 | np.random.shuffle(sents)
482 | sent_stream = iter(sents)
483 |
484 | return sent_stream
485 |
486 | def __iter__(self):
487 | if self.shuffle:
488 | np.random.shuffle(self.paths)
489 |
490 | for path in self.paths:
491 | # sent_stream is an iterator
492 | sent_stream = self.get_sent_stream(path)
493 | for batch in self.stream_iterator(sent_stream):
494 | yield batch
495 |
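# Usage sketch for LMMultiFileIterator (the shard pattern is hypothetical,
# mirroring the lm1b layout used in build_corpus below):
#
#     paths = glob.glob('news.en-*')
#     for data, target, seq_len in LMMultiFileIterator(paths, vocab, bsz=4, bptt=32, shuffle=True):
#         pass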
496 |
497 | class TransfoXLCorpus(object):
498 | @classmethod
499 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
500 | """
501 | Instantiate a pre-processed corpus.
502 | """
503 | vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
504 | if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
505 | corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
506 | else:
507 | corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
508 | # redirect to the cache, if necessary
509 | try:
510 | resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
511 | except EnvironmentError:
512 | logger.error(
513 | "Corpus '{}' was not found in corpus list ({}). "
514 | "We assumed '{}' was a path or url but couldn't find files {} "
515 | "at this path or url.".format(
516 | pretrained_model_name_or_path,
517 |                     ', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
518 | pretrained_model_name_or_path,
519 | corpus_file))
520 | return None
521 | if resolved_corpus_file == corpus_file:
522 | logger.info("loading corpus file {}".format(corpus_file))
523 | else:
524 | logger.info("loading corpus file {} from cache at {}".format(
525 | corpus_file, resolved_corpus_file))
526 |
527 |         # Instantiate the corpus and attach the tokenizer.
528 | corpus = cls(*inputs, **kwargs)
529 | corpus_dict = torch.load(resolved_corpus_file)
530 | for key, value in corpus_dict.items():
531 | corpus.__dict__[key] = value
532 | corpus.vocab = vocab
533 | if corpus.train is not None:
534 | corpus.train = torch.tensor(corpus.train, dtype=torch.long)
535 | if corpus.valid is not None:
536 | corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
537 | if corpus.test is not None:
538 | corpus.test = torch.tensor(corpus.test, dtype=torch.long)
539 | return corpus
540 |
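    # Loading sketch; 'transfo-xl-wt103' is the published shortcut name expected
    # to appear in PRETRAINED_CORPUS_ARCHIVE_MAP:
    #
    #     corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')
    #     te_iter = corpus.get_iterator('test', bsz=16, bptt=128, device='cpu')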
541 | def __init__(self, *args, **kwargs):
542 | self.vocab = TransfoXLTokenizer(*args, **kwargs)
543 | self.dataset = None
544 | self.train = None
545 | self.valid = None
546 | self.test = None
547 |
548 | def build_corpus(self, path, dataset):
549 | self.dataset = dataset
550 |
551 | if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
552 | self.vocab.count_file(os.path.join(path, 'train.txt'))
553 | self.vocab.count_file(os.path.join(path, 'valid.txt'))
554 | self.vocab.count_file(os.path.join(path, 'test.txt'))
555 | elif self.dataset == 'wt103':
556 | self.vocab.count_file(os.path.join(path, 'train.txt'))
557 | elif self.dataset == 'lm1b':
558 | train_path_pattern = os.path.join(
559 | path, '1-billion-word-language-modeling-benchmark-r13output',
560 | 'training-monolingual.tokenized.shuffled', 'news.en-*')
561 | train_paths = glob.glob(train_path_pattern)
562 | # the vocab will load from file when build_vocab() is called
563 |
564 | self.vocab.build_vocab()
565 |
566 | if self.dataset in ['ptb', 'wt2', 'wt103']:
567 | self.train = self.vocab.encode_file(
568 | os.path.join(path, 'train.txt'), ordered=True)
569 | self.valid = self.vocab.encode_file(
570 | os.path.join(path, 'valid.txt'), ordered=True)
571 | self.test = self.vocab.encode_file(
572 | os.path.join(path, 'test.txt'), ordered=True)
573 | elif self.dataset in ['enwik8', 'text8']:
574 | self.train = self.vocab.encode_file(
575 | os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
576 | self.valid = self.vocab.encode_file(
577 | os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
578 | self.test = self.vocab.encode_file(
579 | os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
580 | elif self.dataset == 'lm1b':
581 | self.train = train_paths
582 | self.valid = self.vocab.encode_file(
583 | os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
584 | self.test = self.vocab.encode_file(
585 | os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
586 |
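    # build_corpus sketch for a local WikiText-103 style directory that contains
    # train.txt / valid.txt / test.txt (the path and constructor kwargs are
    # hypothetical):
    #
    #     corpus = TransfoXLCorpus(special=['<eos>'], lower_case=False)
    #     corpus.build_corpus('./data/wikitext-103', 'wt103')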
587 | def get_iterator(self, split, *args, **kwargs):
588 | if split == 'train':
589 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
590 | data_iter = LMOrderedIterator(self.train, *args, **kwargs)
591 | elif self.dataset == 'lm1b':
592 | kwargs['shuffle'] = True
593 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
594 | elif split in ['valid', 'test']:
595 | data = self.valid if split == 'valid' else self.test
596 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
597 | data_iter = LMOrderedIterator(data, *args, **kwargs)
598 | elif self.dataset == 'lm1b':
599 | data_iter = LMShuffledIterator(data, *args, **kwargs)
600 |
601 | return data_iter
602 |
603 |
604 | def get_lm_corpus(datadir, dataset):
605 | fn = os.path.join(datadir, 'cache.pt')
606 | fn_pickle = os.path.join(datadir, 'cache.pkl')
607 |     if os.path.exists(fn):
608 |         print('Loading cached dataset...')
609 |         corpus = torch.load(fn)
610 |     elif os.path.exists(fn_pickle):
611 |         print('Loading cached dataset from pickle...')
612 |         with open(fn_pickle, "rb") as fp:
613 |             corpus = pickle.load(fp)
614 | else:
615 | print('Producing dataset {}...'.format(dataset))
616 | kwargs = {}
617 | if dataset in ['wt103', 'wt2']:
618 |         kwargs['special'] = ['<eos>']
619 | kwargs['lower_case'] = False
620 | elif dataset == 'ptb':
621 |         kwargs['special'] = ['<eos>']
622 | kwargs['lower_case'] = True
623 | elif dataset == 'lm1b':
624 | kwargs['special'] = []
625 | kwargs['lower_case'] = False
626 | kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
627 | elif dataset in ['enwik8', 'text8']:
628 |         pass
629 |     corpus = TransfoXLCorpus(**kwargs)
630 |     corpus.build_corpus(datadir, dataset)
631 |     torch.save(corpus, fn)
632 |
633 | return corpus
634 |
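# get_lm_corpus sketch: the first call preprocesses and caches to cache.pt in
# datadir; later calls reload the cache (directory layout as in build_corpus):
#
#     corpus = get_lm_corpus('./data/wikitext-103', 'wt103')
#     tr_iter = corpus.get_iterator('train', bsz=16, bptt=128)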
635 | def _is_whitespace(char):
636 |     """Checks whether `char` is a whitespace character."""
637 |     # \t, \n, and \r are technically control characters but we treat them
638 | # as whitespace since they are generally considered as such.
639 | if char == " " or char == "\t" or char == "\n" or char == "\r":
640 | return True
641 | cat = unicodedata.category(char)
642 | if cat == "Zs":
643 | return True
644 | return False
645 |
646 |
647 | def _is_control(char):
648 |     """Checks whether `char` is a control character."""
649 | # These are technically control characters but we count them as whitespace
650 | # characters.
651 | if char == "\t" or char == "\n" or char == "\r":
652 | return False
653 | cat = unicodedata.category(char)
654 | if cat.startswith("C"):
655 | return True
656 | return False
657 |
658 |
659 | def _is_punctuation(char):
660 |     """Checks whether `char` is a punctuation character."""
661 | cp = ord(char)
662 | # We treat all non-letter/number ASCII as punctuation.
663 | # Characters such as "^", "$", and "`" are not in the Unicode
664 | # Punctuation class but we treat them as punctuation anyways, for
665 | # consistency.
666 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
667 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
668 | return True
669 | cat = unicodedata.category(char)
670 | if cat.startswith("P"):
671 | return True
672 | return False
673 |
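# Behaviour sketch for the three character predicates:
#
#     _is_whitespace(' ')   # True (also '\t', '\n', '\r' and Unicode Zs)
#     _is_control('\x07')   # True (category Cc; tab/newline/CR are excluded)
#     _is_punctuation('^')  # True (non-letter/number ASCII counts as punctuation)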
--------------------------------------------------------------------------------