├── __pycache__
│   ├── bert.cpython-35.pyc
│   ├── DataGen.cpython-35.pyc
│   └── custom_model.cpython-35.pyc
├── layers
│   ├── __pycache__
│   │   ├── crf.cpython-35.pyc
│   │   ├── decoder.cpython-35.pyc
│   │   ├── embedding.cpython-35.pyc
│   │   └── encoder.cpython-35.pyc
│   ├── encoder.py
│   ├── .ipynb_checkpoints
│   │   ├── encoder-checkpoint.py
│   │   ├── embedding-checkpoint.py
│   │   ├── decoder-checkpoint.py
│   │   └── crf-checkpoint.py
│   ├── embedding.py
│   ├── decoder.py
│   └── crf.py
├── pytorch_pretrained_bert
│   ├── __init__.pyc
│   ├── file_utils.pyc
│   ├── modeling.pyc
│   ├── optimization.pyc
│   ├── tokenization.pyc
│   ├── modeling_gpt2.pyc
│   ├── modeling_openai.pyc
│   ├── tokenization_gpt2.pyc
│   ├── modeling_transfo_xl.pyc
│   ├── optimization_openai.pyc
│   ├── tokenization_openai.pyc
│   ├── tokenization_transfo_xl.pyc
│   ├── modeling_transfo_xl_utilities.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── modeling.cpython-35.pyc
│   │   ├── file_utils.cpython-35.pyc
│   │   ├── optimization.cpython-35.pyc
│   │   ├── tokenization.cpython-35.pyc
│   │   ├── modeling_gpt2.cpython-35.pyc
│   │   ├── modeling_openai.cpython-35.pyc
│   │   ├── tokenization_gpt2.cpython-35.pyc
│   │   ├── modeling_transfo_xl.cpython-35.pyc
│   │   ├── optimization_openai.cpython-35.pyc
│   │   ├── tokenization_openai.cpython-35.pyc
│   │   ├── tokenization_transfo_xl.cpython-35.pyc
│   │   └── modeling_transfo_xl_utilities.cpython-35.pyc
│   ├── __init__.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── convert_gpt2_checkpoint_to_pytorch.py
│   ├── convert_openai_checkpoint_to_pytorch.py
│   ├── __main__.py
│   ├── convert_transfo_xl_checkpoint_to_pytorch.py
│   ├── optimization_openai.py
│   ├── optimization.py
│   ├── file_utils.py
│   ├── tokenization_gpt2.py
│   ├── tokenization_openai.py
│   ├── tokenization.py
│   ├── modeling_transfo_xl_utilities.py
│   └── tokenization_transfo_xl.py
├── data
│   └── train_analysis.txt
├── README.md
├── multi_cased_L-12_H-768_A-12
│   └── bert_config.json
├── debug.py
├── predict.py
├── train.py
├── model.py
└── DataGen.py
--------------------------------------------------------------------------------
/data/train_analysis.txt:
--------------------------------------------------------------------------------
1 | #doc1
2 | 0 Ai B-VAR ['N'] [0]
3 | 1 là B-RLT ['PO','SP'] [0,2]
4 | 2 Langston_Hughes B-ETT ['N'] [2]
5 | #doc2
6 | 0 Ai B-VAR ['N'] [0]
7 | 1 là B-RLT ['PO','SP'] [0,2]
8 | 2 Damocles B-ETT ['N'] [2]
9 | 
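Each document block above starts with a `#docN` header; every following line gives the token index, the token, a BIO entity tag (`B-VAR`, `B-RLT`, `B-ETT`, ...), a list of relation labels, and a list of head-token indices. A minimal reader sketch for this format — the function name and the whitespace-splitting assumption are ours, not the repo's (`DataGen.py` has its own loader):

```python
import ast

def read_analysis_file(path):
    """Yield each #doc block as a list of
    (index, token, entity_tag, relation_labels, head_indices) tuples."""
    doc = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith("#doc"):
                if doc:
                    yield doc
                doc = []
                continue
            idx, token, tag, rels, heads = line.split()
            doc.append((int(idx), token, tag,
                        ast.literal_eval(rels), ast.literal_eval(heads)))
    if doc:
        yield doc

# e.g. [(0, 'Ai', 'B-VAR', ['N'], [0]), (1, 'là', 'B-RLT', ['PO', 'SP'], [0, 2]), ...]
for doc in read_analysis_file("data/train_analysis.txt"):
    print(doc)
```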
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # bert-jointly-relation-entity-extraction
2 | 
3 | ## References
4 | [1] BERT PyTorch: https://github.com/huggingface/pytorch-pretrained-BERT
5 | [2] CRF, NER pytorch: https://github.com/sberbank-ai/ner-bert
6 | [3] Multi-head joint entity and relation extraction: https://github.com/bekou/multihead_joint_entity_relation_extraction
7 | Article for [3]: https://arxiv.org/abs/1804.07847
8 | 
--------------------------------------------------------------------------------
/multi_cased_L-12_H-768_A-12/bert_config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "attention_probs_dropout_prob": 0.1,
3 |   "directionality": "bidi",
4 |   "hidden_act": "gelu",
5 |   "hidden_dropout_prob": 0.1,
6 |   "hidden_size": 768,
7 |   "initializer_range": 0.02,
8 |   "intermediate_size": 3072,
9 |   "max_position_embeddings": 512,
10 |   "num_attention_heads": 12,
11 |   "num_hidden_layers": 12,
12 |   "pooler_fc_size": 768,
13 |   "pooler_num_attention_heads": 12,
14 |   "pooler_num_fc_layers": 3,
15 |   "pooler_size_per_head": 128,
16 |   "pooler_type": "first_token_transform",
17 |   "type_vocab_size": 2,
18 |   "vocab_size": 119547
19 | }
20 | 
--------------------------------------------------------------------------------
/debug.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from pytorch_pretrained_bert.modeling import *
4 | from DataGen import DataGenerator
5 | from model import BertBiLSTMCRF
6 | 
7 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
8 | 
9 | if __name__ == "__main__":
10 |     BERT_PRETRAINED_PATH = "./multi_cased_L-12_H-768_A-12/"
11 |     TRAIN_PATH = "/data/loclh2/QABot/data/train_analysis.txt"
12 |     batch_size = 32
13 |     shuffle = False
14 | 
15 |     data_gen = DataGenerator(model=BertModel, model_name=BERT_PRETRAINED_PATH)
16 | 
17 |     train_gen = data_gen.get_generator(TRAIN_PATH, batch_size, shuffle=shuffle)
18 |     for batch in train_gen:
19 |         pass
20 | 
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.6.1"
2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
3 | from .tokenization_openai import OpenAIGPTTokenizer
4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
5 | from .tokenization_gpt2 import GPT2Tokenizer
6 | 
7 | from .modeling import (BertConfig, BertModel, BertForPreTraining,
8 |                        BertForMaskedLM, BertForNextSentencePrediction,
9 |                        BertForSequenceClassification, BertForMultipleChoice,
10 |                        BertForTokenClassification, BertForQuestionAnswering,
11 |                        load_tf_weights_in_bert)
12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
13 |                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
14 |                               load_tf_weights_in_openai_gpt)
15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
16 |                                   load_tf_weights_in_transfo_xl)
17 | from .modeling_gpt2 import (GPT2Config, GPT2Model,
18 |                             GPT2LMHeadModel, GPT2DoubleHeadsModel,
19 |                             load_tf_weights_in_gpt2)
20 | 
21 | from .optimization import BertAdam
22 | from .optimization_openai import OpenAIAdam
23 | 
24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
25 | 
--------------------------------------------------------------------------------
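debug.py above points BERT_PRETRAINED_PATH at the bundled multi_cased_L-12_H-768_A-12/ directory. A minimal sanity-check sketch for loading that model directly with the vendored pytorch_pretrained_bert package; it assumes the directory also holds the converted pytorch_model.bin and vocab.txt, which are not included in this dump:

```python
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

BERT_PRETRAINED_PATH = "./multi_cased_L-12_H-768_A-12/"

tokenizer = BertTokenizer.from_pretrained(BERT_PRETRAINED_PATH)
model = BertModel.from_pretrained(BERT_PRETRAINED_PATH)
model.eval()

tokens = tokenizer.tokenize("Con trai của Ronald là ai")
ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
with torch.no_grad():
    # by default this BertModel returns the hidden states of all 12 encoder layers
    all_layers, pooled = model(ids)
print(len(all_layers), all_layers[-1].shape)  # 12, then (1, seq_len, 768)
```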
/layers/encoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | 
4 | class BertBiLSTMEncoder(nn.Module):
5 |     def __init__(self, embeddings,
6 |                  hidden_dim=128, rnn_layers=1, use_cuda=True):
7 |         super(BertBiLSTMEncoder, self).__init__()
8 |         self.embeddings = embeddings
9 |         self.hidden_dim = hidden_dim
10 |         self.rnn_layers = rnn_layers
11 |         self.use_cuda = use_cuda
12 |         self.lstm = nn.LSTM(
13 |             self.embeddings.embedding_dim, hidden_dim // 2,
14 |             rnn_layers, batch_first=True, bidirectional=True)
15 |         self.hidden = None
16 |         if use_cuda:
17 |             self.cuda()
18 |         self.init_weights()
19 |         self.output_dim = hidden_dim
20 | 
21 |     def init_weights(self):
22 |         #for p in self.lstm.parameters():
23 |         #    nn.init.xavier_normal(p)
24 |         pass
25 | 
26 |     def sort_lengths(self, inputs, input_lens):
27 |         # Sort the batch by descending length, as pack_padded_sequence requires.
28 |         # Sorting index positions (rather than calling input_lens.index on each
29 |         # sorted value) keeps duplicate lengths from mapping to the same sample,
30 |         # and indexing the tensor directly keeps the autograd graph intact.
31 |         sorted_ids = sorted(range(len(input_lens)),
32 |                             key=lambda i: input_lens[i], reverse=True)
33 |         sorted_input_lens = [input_lens[i] for i in sorted_ids]
34 |         sorted_inputs = inputs[torch.tensor(sorted_ids, device=inputs.device)]
35 |         return sorted_inputs, sorted_input_lens
36 | 
37 |     def forward(self, batch):
38 |         input, input_mask = batch[0], batch[1]
39 |         output = self.embeddings(*batch)
40 |         # output = self.dropout(output)
41 |         lens = input_mask.sum(-1).tolist()
42 |         # NB: samples are returned in length-sorted order, not input order
43 |         output, lens = self.sort_lengths(output, lens)
44 | 
45 |         output = nn.utils.rnn.pack_padded_sequence(output, lens, batch_first=True)
46 | 
47 |         if self.use_cuda:
48 |             output, self.hidden = self.lstm(output.to("cuda"))
49 |         else:
50 |             output, self.hidden = self.lstm(output.to("cpu"))
51 | 
52 |         output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
53 | 
54 |         return output, self.hidden
55 | 
56 |     @classmethod
57 |     def create(cls, embeddings, hidden_dim=128, rnn_layers=1, use_cuda=True):
58 |         model = cls(
59 |             embeddings=embeddings, hidden_dim=hidden_dim, rnn_layers=rnn_layers, use_cuda=use_cuda)
60 |         return model
--------------------------------------------------------------------------------
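For context on the length sorting above: before PyTorch 1.1 added enforce_sorted=False, pack_padded_sequence required batches sorted by descending length. A self-contained illustration of the pack → LSTM → pad round trip the encoder performs:

```python
import torch
from torch import nn

lstm = nn.LSTM(input_size=8, hidden_size=4, batch_first=True, bidirectional=True)

# a padded batch of 3 sequences with true lengths 5, 3, 2 (already descending)
x = torch.randn(3, 5, 8)
lens = [5, 3, 2]

packed = nn.utils.rnn.pack_padded_sequence(x, lens, batch_first=True)
out_packed, (h, c) = lstm(packed)
out, out_lens = nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True)

print(out.shape)               # torch.Size([3, 5, 8]) -- 2 directions * hidden_size 4
print(out_lens)                # tensor([5, 3, 2])
print(out[2, 2:].abs().sum())  # tensor(0.) -- padding positions stay zero
```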
/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import os
22 | import re
23 | import argparse
24 | import tensorflow as tf
25 | import torch
26 | import numpy as np
27 | 
28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
29 | 
30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
31 |     # Initialise PyTorch model
32 |     config = BertConfig.from_json_file(bert_config_file)
33 |     print("Building PyTorch model from configuration: {}".format(str(config)))
34 |     model = BertForPreTraining(config)
35 | 
36 |     # Load weights from tf checkpoint
37 |     load_tf_weights_in_bert(model, tf_checkpoint_path)
38 | 
39 |     # Save pytorch-model
40 |     print("Save PyTorch model to {}".format(pytorch_dump_path))
41 |     torch.save(model.state_dict(), pytorch_dump_path)
42 | 
43 | 
44 | if __name__ == "__main__":
45 |     parser = argparse.ArgumentParser()
46 |     ## Required parameters
47 |     parser.add_argument("--tf_checkpoint_path",
48 |                         default = None,
49 |                         type = str,
50 |                         required = True,
51 |                         help = "Path to the TensorFlow checkpoint.")
52 |     parser.add_argument("--bert_config_file",
53 |                         default = None,
54 |                         type = str,
55 |                         required = True,
56 |                         help = "The config json file corresponding to the pre-trained BERT model. \n"
\n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.bert_config_file, 66 | args.pytorch_dump_path) 67 | -------------------------------------------------------------------------------- /layers/embedding.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from pytorch_pretrained_bert.modeling import BertModel, BertConfig 4 | 5 | class BertEmbedder(nn.Module): 6 | def __init__(self, model, bert_pretrained_path, 7 | freeze=True, embedding_dim=768, use_cuda=True, bert_mode="weighted",): 8 | super(BertEmbedder, self).__init__() 9 | self.bert_pretrained_path = bert_pretrained_path 10 | self.is_freeze = freeze 11 | self.embedding_dim = embedding_dim 12 | self.model = model 13 | self.use_cuda = use_cuda 14 | self.bert_mode = bert_mode 15 | if self.bert_mode == "weighted": 16 | self.bert_weights = nn.Parameter(torch.FloatTensor(12, 1)) 17 | self.bert_gamma = nn.Parameter(torch.FloatTensor(1, 1)) 18 | 19 | if use_cuda: 20 | self.cuda() 21 | 22 | self.init_weights() 23 | 24 | def init_weights(self): 25 | if self.bert_mode == "weighted": 26 | nn.init.xavier_normal(self.bert_gamma) 27 | nn.init.xavier_normal(self.bert_weights) 28 | 29 | def forward(self, *batch): 30 | input_ids, input_mask, input_type_ids = batch[:3] 31 | all_encoder_layers, _ = self.model(input_ids.long(), token_type_ids=input_type_ids.long(), attention_mask=input_mask.long()) 32 | if self.bert_mode == "last": 33 | return all_encoder_layers[-1] 34 | elif self.bert_mode == "weighted": 35 | all_encoder_layers = torch.stack([a * b for a, b in zip(all_encoder_layers, self.bert_weights)]) 36 | return self.bert_gamma * torch.sum(all_encoder_layers, dim=0) 37 | 38 | def freeze(self): 39 | for param in self.model.parameters(): 40 | param.requires_grad = False 41 | 42 | def unfreeze(self): 43 | for param in self.model.parameters(): 44 | param.requires_grad = True 45 | 46 | def freeze_to(self, to=-1): 47 | idx = 0 48 | if to < 0: 49 | to = len(self.model.encoder.layer) + to + 1 50 | for idx in range(to): 51 | for param in self.model.encoder.layer[idx].parameters(): 52 | param.requires_grad = False 53 | print("Embeddings freezed to {}".format(to)) 54 | to = len(self.model.encoder.layer) 55 | for idx in range(idx, to): 56 | for param in self.model.encoder.layer[idx].parameters(): 57 | param.requires_grad = True 58 | 59 | @classmethod 60 | def create(cls, 61 | bert_pretrained_path, embedding_dim=768, use_cuda=True, bert_mode="weighted", 62 | freeze=True): 63 | model = BertModel.from_pretrained(bert_pretrained_path) 64 | #if use_cuda: 65 | # device = torch.device("cuda") 66 | # map_location = "cuda" 67 | #else: 68 | # map_location = "cpu" 69 | # device = torch.device("cpu") 70 | #model = model.to(device) 71 | model = cls(model=model, embedding_dim=embedding_dim, use_cuda=use_cuda, bert_mode=bert_mode, 72 | bert_pretrained_path=bert_pretrained_path, freeze=freeze) 73 | if freeze: 74 | model.freeze() 75 | return model 76 | -------------------------------------------------------------------------------- /layers/.ipynb_checkpoints/embedding-checkpoint.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from pytorch_pretrained_bert.modeling 
/pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /layers/decoder.py: -------------------------------------------------------------------------------- 1 | from .crf import CRF 2 | import torch 3 | 4 | from torch.autograd import Variable 5 | from torch import nn 6 | from torch.nn import init 7 | 8 | class Linear(nn.Linear): 9 | def __init__(self, 10 | in_features: int, 11 | out_features: int, 12 | bias: bool = True): 13 | super(Linear, self).__init__(in_features, out_features, bias=bias) 14 | init.orthogonal_(self.weight) 15 | 16 | class Linears(nn.Module): 17 | def __init__(self, 18 | in_features, 19 | out_features, 20 | hiddens, 21 | bias=True, 22 | activation='tanh'): 23 | super(Linears, self).__init__() 24 | assert len(hiddens) > 0 25 | 26 | self.in_features = in_features 27 | self.out_features = self.output_size = out_features 28 | 29 | in_dims = [in_features] + hiddens[:-1] 30 | self.linears = nn.ModuleList([Linear(in_dim, out_dim, bias=bias) 31 | for in_dim, out_dim 32 | in zip(in_dims, hiddens)]) 33 | self.output_linear = Linear(hiddens[-1], out_features, bias=bias) 34 | self.activation = activation 35 | 36 | def forward(self, inputs): 37 | linear_outputs = inputs 38 | for linear in self.linears: 39 | linear_outputs = linear.forward(linear_outputs) 40 | if self.activation == 'tanh': 41 | linear_outputs = torch.tanh(linear_outputs) 42 | else: 43 | linear_outputs = torch.relu(linear_outputs) 44 | return self.output_linear.forward(linear_outputs) 45 | 46 | class CRFDecoder(nn.Module): 47 | def __init__(self, label_size, input_dim, input_dropout=0.5, activation='tanh'): 48 | super(CRFDecoder, self).__init__() 49 | self.input_dim = input_dim 50 | self.input_dropout = nn.Dropout(p=input_dropout) 51 | self.linear = Linears(in_features=input_dim, 52 | out_features=label_size, 53 | hiddens=[input_dim // 2], 54 | activation=activation) 55 | self.crf = CRF(label_size+2) 56 | self.label_size = label_size 57 | 58 | def forward_model(self, inputs): 59 | batch_size, seq_len, input_dim = inputs.size() 60 | output = inputs.contiguous().view(-1, self.input_dim) 61 | output = self.input_dropout(output) 62 | # Fully-connected layer 63 | output = self.linear.forward(output) 64 | output = output.view(batch_size, seq_len, self.label_size) 65 | return output 66 | 67 | def forward(self, inputs, labels_mask): 68 | self.eval() 69 | lens = labels_mask.sum(-1) 70 | logits = self.forward_model(inputs) 71 | logits = self.crf.pad_logits(logits) 72 | scores, preds = self.crf.viterbi_decode(logits, lens) 73 | self.train() 74 | 75 | return preds 76 | 77 | def score(self, inputs, labels_mask, labels): 78 | lens = labels_mask.sum(-1) 79 | logits = self.forward_model(inputs) 80 | logits = self.crf.pad_logits(logits) 81 | norm_score = self.crf.calc_norm_score(logits, lens) 82 | labels = labels[:, :logits.size(1)] 83 | gold_score = self.crf.calc_gold_score(logits, labels, lens) 84 | loglik = gold_score - norm_score 85 | return -loglik.mean() 86 | -------------------------------------------------------------------------------- /layers/.ipynb_checkpoints/decoder-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .crf import CRF 2 | import torch 3 | 4 | from torch.autograd import Variable 5 | from torch import nn 
/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_pretrained_bert.modeling import *
3 | from DataGen import DataGenerator
4 | from model import BertBiLSTMCRF
5 | import pickle
6 | 
7 | def convert_rel_matrix_to_str(predicted_rel, rel_list):
8 |     index_rels = []
9 |     label_rels = []
10 | 
11 |     for i, pos in enumerate(predicted_rel[0]):
12 |         index_rel = []
13 |         label_rel = []
14 |         for j, val in enumerate(pos):  # the matrix packs 5 relation types per token position
15 |             if j % 5 == 0 and val == 1:  # relation type "SP"
16 |                 index_rel.append(int(j/5))
17 |                 label_rel.append("SP")
18 |             elif j % 5 == 1 and val == 1:  # relation type "N"
19 |                 index_rel.append(int(j/5))
20 |                 label_rel.append("N")
21 |             elif j % 5 == 2 and val == 1:
22 |                 index_rel.append(int(j/5))
23 | 
label_rel.append("ST") 24 | elif j%5 == 3 and val == 1: 25 | index_rel.append(int(j/5)) 26 | label_rel.append("PO") 27 | elif j%5 == 4 and val == 1: 28 | index_rel.append(int(j/5)) 29 | label_rel.append("SC") 30 | 31 | if len(index_rel) == 0: 32 | index_rels.append([i]) 33 | label_rels.append(["N"]) 34 | else: 35 | index_rels.append(index_rel) 36 | label_rels.append(label_rel) 37 | 38 | return index_rels, label_rels 39 | 40 | def predict(model, featurized_sentence, rel_list, ner_list, use_extra=True): 41 | model.eval() 42 | input_tensor_1 = featurized_sentence[0].to("cuda") 43 | input_tensor_2 = featurized_sentence[1].to("cuda") 44 | input_tensor_3 = featurized_sentence[2].to("cuda") 45 | input_tensor_4 = featurized_sentence[3].to("cuda") 46 | input_tensor_5 = featurized_sentence[4].to("cuda") 47 | ner_logits, rel_logits = model([input_tensor_1, input_tensor_2, input_tensor_3, input_tensor_4, input_tensor_5]) 48 | 49 | label_types = [] 50 | ner_logits = ner_logits[0] 51 | for i in ner_logits: 52 | label_types.append(ner_list[i]) 53 | 54 | if use_extra: 55 | rel_result = [] 56 | ner_logits = ner_logits.tolist() 57 | print(rel_logits) 58 | index_rels, label_rels = convert_rel_matrix_to_str(rel_logits, rel_list) 59 | return label_types, index_rels, label_rels 60 | 61 | return label_types, None, None 62 | 63 | def load_model(save_path): 64 | with open(save_path,'rb') as f: 65 | data = pickle.load(f) 66 | return data["model"] 67 | 68 | if __name__ == "__main__": 69 | PRETRAINED_MODEL = "models/" 70 | BERT_PRETRAINED_PATH = "./multi_cased_L-12_H-768_A-12/" 71 | #VALID_PATH = "/data/loclh2/QABot/data/test_analysis.txt" 72 | batch_size = 32 73 | shuffle = True 74 | use_cuda = True 75 | 76 | data_gen = DataGenerator(model=BertModel, model_name=BERT_PRETRAINED_PATH) 77 | 78 | #model = BertBiLSTMCRF.create(16, 79 | # len(data_gen.rel_list), 80 | # BERT_PRETRAINED_PATH, 81 | # freeze=True, 82 | # rnn_layers=2, 83 | # input_dropout=0.1, 84 | # use_cuda=use_cuda, 85 | # hidden_size=64, 86 | # label_embedding_size=64, 87 | # enc_hidden_dim=64, 88 | # activation="tanh") 89 | 90 | # model.load_state_dict(torch.load(PRETRAINED_MODEL)) 91 | 92 | model = load_model("models/rel_ner_v1_adam/bert_ner_epoches=50_valid_loss=10.281595188638438.pickle") 93 | 94 | sentence = "Con trai của Ronald là ai" 95 | featurized_sentence = data_gen.get_featurized_sentence(sentence) 96 | 97 | label_types, index_rels, label_rels = predict(model, 98 | featurized_sentence, 99 | data_gen.rel_list, 100 | data_gen.ner_list, 101 | use_extra=True) 102 | print("label_types : " + str(label_types)) 103 | print("index_rels : " + str(index_rels)) 104 | print("label_rels : " + str(label_rels)) 105 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint 
TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), 
pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained BERT model. \n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted in a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | 23 | def warmup_cosine(x, warmup=0.002): 24 | s = 1 if x <= warmup else 0 25 | return s*(x/warmup) + (1-s)*(0.5 * (1 + torch.cos(math.pi * x))) 26 | 27 | def warmup_constant(x, warmup=0.002): 28 | s = 1 if x <= warmup else 0 29 | return s*(x/warmup) + (1-s)*1 30 | 31 | def warmup_linear(x, warmup=0.002): 32 | s = 1 if x <= warmup else 0 33 | return (s*(x/warmup) + (1-s))*(1-x) 34 | 35 | SCHEDULES = { 36 | 'warmup_cosine':warmup_cosine, 37 | 'warmup_constant':warmup_constant, 38 | 'warmup_linear':warmup_linear, 39 | } 40 | 41 | 42 | class OpenAIAdam(Optimizer): 43 | """Implements Open AI version of Adam algorithm with weight decay fix. 
44 | """ 45 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 46 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0, 47 | vector_l2=False, max_grad_norm=-1, **kwargs): 48 | if lr is not required and lr < 0.0: 49 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 50 | if schedule not in SCHEDULES: 51 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 52 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 53 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 54 | if not 0.0 <= b1 < 1.0: 55 | raise ValueError("Invalid b1 parameter: {}".format(b1)) 56 | if not 0.0 <= b2 < 1.0: 57 | raise ValueError("Invalid b2 parameter: {}".format(b2)) 58 | if not e >= 0.0: 59 | raise ValueError("Invalid epsilon value: {}".format(e)) 60 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 61 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 62 | max_grad_norm=max_grad_norm) 63 | super(OpenAIAdam, self).__init__(params, defaults) 64 | 65 | def get_lr(self): 66 | lr = [] 67 | for group in self.param_groups: 68 | for p in group['params']: 69 | state = self.state[p] 70 | if len(state) == 0: 71 | return [0] 72 | if group['t_total'] != -1: 73 | schedule_fct = SCHEDULES[group['schedule']] 74 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 75 | else: 76 | lr_scheduled = group['lr'] 77 | lr.append(lr_scheduled) 78 | return lr 79 | 80 | def step(self, closure=None): 81 | """Performs a single optimization step. 82 | 83 | Arguments: 84 | closure (callable, optional): A closure that reevaluates the model 85 | and returns the loss. 86 | """ 87 | loss = None 88 | if closure is not None: 89 | loss = closure() 90 | 91 | for group in self.param_groups: 92 | for p in group['params']: 93 | if p.grad is None: 94 | continue 95 | grad = p.grad.data 96 | if grad.is_sparse: 97 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 98 | 99 | state = self.state[p] 100 | 101 | # State initialization 102 | if len(state) == 0: 103 | state['step'] = 0 104 | # Exponential moving average of gradient values 105 | state['exp_avg'] = torch.zeros_like(p.data) 106 | # Exponential moving average of squared gradient values 107 | state['exp_avg_sq'] = torch.zeros_like(p.data) 108 | 109 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 110 | beta1, beta2 = group['b1'], group['b2'] 111 | 112 | state['step'] += 1 113 | 114 | # Add grad clipping 115 | if group['max_grad_norm'] > 0: 116 | clip_grad_norm_(p, group['max_grad_norm']) 117 | 118 | # Decay the first and second moment running average coefficient 119 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 120 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 121 | denom = exp_avg_sq.sqrt().add_(group['e']) 122 | 123 | bias_correction1 = 1 - beta1 ** state['step'] 124 | bias_correction2 = 1 - beta2 ** state['step'] 125 | 126 | if group['t_total'] != -1: 127 | schedule_fct = SCHEDULES[group['schedule']] 128 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 129 | else: 130 | lr_scheduled = group['lr'] 131 | 132 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 133 | 134 | p.data.addcdiv_(-step_size, exp_avg, denom) 135 | 136 | # Add weight decay at the end (fixed version) 137 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 138 | 
p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 139 | 140 | return loss 141 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_pretrained_bert.modeling import * 3 | from tqdm import tqdm 4 | from DataGen import DataGenerator 5 | from model import BertBiLSTMCRF 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pickle 9 | import os  # used just below; don't rely on os leaking in through the star-import 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 11 | torch.backends.cudnn.benchmark = True 12 | torch.backends.cudnn.enabled = True 13 | 14 | 15 | def evaluate(model, epoch, valid_generator, device): 16 | # iterate over the validation generator and average the model loss 17 | num_valid = 0 18 | num_correct = total_loss = 0 19 | batch_size = 0 20 | idx = 0 21 | 22 | for batch_input_ids, batch_mask_ids, batch_input_type_ids, batch_label_ids, batch_rel_matrix in valid_generator: 23 | # compute loss 24 | loss = model.score([batch_input_ids.to(device), batch_mask_ids.to(device), batch_input_type_ids.to(device), batch_label_ids.to(device), batch_rel_matrix.to(device)]) 25 | 26 | idx += 1 27 | total_loss += loss.mean().item() 28 | 29 | loss = total_loss / idx 30 | # accuracy is not computed here; the zeros below are placeholders 31 | print('\nValidation : Loss: {:.6f} Accuracy: {}/{} ({:.4f}%)\n'.format(loss, 0, 0, 0)) 32 | 33 | return loss, 0 34 | 35 | def draw_learning_curve(epoch, training_loss, valid_loss, 36 | save_path="models/learning_curve.png", 37 | title="Learning Curve"): 38 | plt.plot(epoch, training_loss, '-b') 39 | plt.plot(epoch, valid_loss, '--r') 40 | 41 | plt.xlabel('Epoch') 42 | plt.ylabel('Loss') 43 | plt.legend(['Training Loss', 'Validation Loss']) 44 | plt.title(title) 45 | 46 | # save image 47 | plt.savefig(save_path)  # must be called before plt.show() 48 | 49 | def train(model, data_gen, train_path, valid_path, start_epoch, num_epoches, save_path, device, 50 | batch_size=32, 51 | decay_rate=0.1, 52 | learning_rate=0.001, 53 | momentum=0.9, 54 | update_lr_epoches=[], 55 | shuffle=True): 56 | 57 | model.train() 58 | optimizer = torch.optim.SGD(model.parameters(), 59 | lr=learning_rate, 60 | momentum=momentum)  # use the momentum argument instead of a hard-coded 0.9 61 | loss_history = [[], [], []] 62 | 63 | 64 | for epoch in range(start_epoch, num_epoches + start_epoch): 65 | step = 0 66 | train_gen = data_gen.get_generator(train_path, batch_size, is_shuffle=shuffle) 67 | if epoch in update_lr_epoches: 68 | learning_rate = learning_rate * decay_rate 69 | for param_group in optimizer.param_groups: 70 | param_group['lr'] = learning_rate 71 | print('Updating the learning rate at epoch: ' + str(epoch) + ', value: ' + str(learning_rate)) 72 | 73 | train_loss = [] 74 | for batch_input_ids, batch_mask_ids, batch_input_type_ids, batch_label_types, batch_rel_matrix in train_gen: 75 | model.zero_grad() 76 | loss = model.score([batch_input_ids.to(device), batch_mask_ids.to(device), batch_input_type_ids.to(device), batch_label_types.to(device), batch_rel_matrix.to(device)]) 77 | 78 | loss.backward() 79 | 80 | optimizer.step() 81 | optimizer.zero_grad() 82 | loss = loss.data.cpu().tolist() 83 | train_loss.append(loss) 84 | print('Training: Epoch %d, step %5d / %d loss: %.3f'%(epoch + 1, step + 1, data_gen.num_samples/batch_size + 1, loss)) 85 | step += 1 86 | 87 | valid_gen = data_gen.get_generator(valid_path, batch_size, is_shuffle=shuffle) 88 | valid_loss, valid_accuracy = evaluate(model, epoch, valid_gen, device) 89 | 90 | ### visualize the learning curve 91 | train_loss_value = np.mean(train_loss) 92 | valid_loss_value = np.mean(valid_loss) 93 | 94 |
loss_history[0].append(epoch + 1) 95 | loss_history[1].append(train_loss_value) 96 | loss_history[2].append(valid_loss_value) 97 | 98 | draw_learning_curve(loss_history[0], loss_history[1], loss_history[2]) 99 | 100 | model.train() 101 | 102 | save_data = {"model": model, "history": loss_history} 103 | with open(save_path + "bert_ner_epoches=" + str(epoch + 1) + "_valid_loss=" + str(valid_loss) + '.pickle', 'wb') as handle: 104 | pickle.dump(save_data, handle, protocol=2) 105 | 106 | def load_pretrain_model(model, pretrained_path): 107 | model.load_state_dict(torch.load(pretrained_path)) 108 | return model 109 | 110 | if __name__ == "__main__": 111 | TRAIN_PATH = "data/train_analysis.txt" 112 | VALID_PATH = "data/test_analysis.txt" 113 | PRETRAINED_PATH = "./multi_cased_L-12_H-768_A-12/" 114 | SAVE_PATH = "models/rel_ner_v1_adam/" 115 | batch_size = 4 116 | shuffle = True 117 | use_cuda = True 118 | use_extra = True 119 | freeze = False  ## True: freeze the BERT encoder; False: fine-tune it 120 | start_epoch = 0 121 | num_epoches = 50 122 | learning_rate = 0.001 123 | decay_rate = 0.1 124 | update_lr_epoches = [35,] 125 | momentum = 0.9 126 | 127 | if use_cuda: 128 | device = "cuda" 129 | else: 130 | device = "cpu" 131 | 132 | data_gen = DataGenerator(model=BertModel, 133 | model_name=PRETRAINED_PATH, 134 | device=device) 135 | 136 | ner_size = len(data_gen.ner_list) 137 | rel_size = len(data_gen.rel_list) 138 | 139 | model = BertBiLSTMCRF.create(ner_size, 140 | rel_size, 141 | PRETRAINED_PATH, 142 | freeze=freeze, 143 | rnn_layers=2, 144 | input_dropout=0.1, 145 | use_cuda=use_cuda, 146 | use_extra=use_extra, 147 | hidden_size=64, 148 | label_embedding_size=32, 149 | enc_hidden_dim=64, 150 | activation="tanh") 151 | 152 | train(model=model, 153 | data_gen=data_gen, 154 | train_path=TRAIN_PATH, 155 | valid_path=VALID_PATH, 156 | start_epoch=start_epoch, 157 | num_epoches=num_epoches, 158 | save_path=SAVE_PATH, 159 | device=device, 160 | batch_size=batch_size, 161 | decay_rate=decay_rate, 162 | learning_rate=learning_rate, 163 | momentum=momentum, 164 | update_lr_epoches=update_lr_epoches, 165 | shuffle=shuffle) 166 | -------------------------------------------------------------------------------- /layers/crf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | # TODO: move to utils 6 | def log_sum_exp(tensor, dim=0): 7 | """LogSumExp operation.""" 8 | m, _ = torch.max(tensor, dim) 9 | m_exp = m.unsqueeze(-1).expand_as(tensor) 10 | return m + torch.log(torch.sum(torch.exp(tensor - m_exp), dim)) 11 | 12 | 13 | def sequence_mask(lens, max_len=None): 14 | batch_size = lens.size(0) 15 | 16 | if max_len is None: 17 | max_len = lens.max().item() 18 | 19 | ranges = torch.arange(0, max_len).long() 20 | ranges = ranges.unsqueeze(0).expand(batch_size, max_len) 21 | 22 | if lens.data.is_cuda: 23 | ranges = ranges.cuda() 24 | 25 | lens_exp = lens.unsqueeze(1).expand_as(ranges) 26 | mask = ranges < lens_exp 27 | 28 | return mask 29 | 30 | 31 | class CRF(nn.Module): 32 | def __init__(self, label_size): 33 | super(CRF, self).__init__() 34 | 35 | self.label_size = label_size 36 | self.start = self.label_size - 2 37 | self.end = self.label_size - 1 38 | transition = torch.randn(self.label_size, self.label_size) 39 | self.transition = nn.Parameter(transition) 40 | self.initialize() 41 | 42 | def initialize(self): 43 | self.transition.data[:, self.end] = -100.0 44 | self.transition.data[self.start, :] = -100.0 45 | 46 | def pad_logits(self, logits): 47 | #
lens = lens.data 48 | batch_size, seq_len, label_num = logits.size() 49 | # pads = Variable(logits.data.new(batch_size, seq_len, 2).fill_(-1000.0), 50 | # requires_grad=False) 51 | pads = logits.new_full((batch_size, seq_len, 2), -1000.0, 52 | requires_grad=False) 53 | logits = torch.cat([logits, pads], dim=2) 54 | return logits 55 | 56 | def calc_binary_score(self, labels, lens): 57 | batch_size, seq_len = labels.size() 58 | 59 | # labels_ext = Variable(labels.data.new(batch_size, seq_len + 2)) 60 | labels_ext = labels.new_empty((batch_size, seq_len + 2)) 61 | labels_ext[:, 0] = self.start 62 | labels_ext[:, 1:-1] = labels 63 | mask = sequence_mask(lens + 1, max_len=(seq_len + 2)).long() 64 | # pad_stop = Variable(labels.data.new(1).fill_(self.end)) 65 | pad_stop = labels.new_full((1,), self.end, requires_grad=False) 66 | pad_stop = pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2) 67 | labels_ext = (1 - mask) * pad_stop + mask * labels_ext 68 | labels = labels_ext 69 | 70 | trn = self.transition 71 | trn_exp = trn.unsqueeze(0).expand(batch_size, *trn.size()) 72 | lbl_r = labels[:, 1:] 73 | lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), trn.size(0)) 74 | trn_row = torch.gather(trn_exp, 1, lbl_rexp) 75 | 76 | 77 | 78 | lbl_lexp = labels[:, :-1].unsqueeze(-1) 79 | trn_scr = torch.gather(trn_row, 2, lbl_lexp) 80 | trn_scr = trn_scr.squeeze(-1) 81 | 82 | mask = sequence_mask(lens + 1).float() 83 | trn_scr = trn_scr * mask 84 | score = trn_scr 85 | 86 | return score 87 | 88 | def calc_unary_score(self, logits, labels, lens): 89 | labels = labels[:, :logits.size(1)] 90 | 91 | labels_exp = labels.unsqueeze(-1) 92 | 93 | scores = torch.gather(logits, 2, labels_exp).squeeze(-1) 94 | mask = sequence_mask(lens).float() 95 | scores = scores * mask 96 | return scores 97 | 98 | def calc_gold_score(self, logits, labels, lens): 99 | unary_score = self.calc_unary_score(logits, labels, lens).sum( 100 | 1).squeeze(-1) 101 | binary_score = self.calc_binary_score(labels, lens).sum(1).squeeze(-1) 102 | return unary_score + binary_score 103 | 104 | def calc_norm_score(self, logits, lens): 105 | batch_size, seq_len, feat_dim = logits.size() 106 | # alpha = logits.data.new(batch_size, self.label_size).fill_(-10000.0) 107 | alpha = logits.new_full((batch_size, self.label_size), -100.0) 108 | alpha[:, self.start] = 0 109 | # alpha = Variable(alpha) 110 | lens_ = lens.clone() 111 | 112 | logits_t = logits.transpose(1, 0) 113 | for logit in logits_t: 114 | logit_exp = logit.unsqueeze(-1).expand(batch_size, 115 | *self.transition.size()) 116 | alpha_exp = alpha.unsqueeze(1).expand(batch_size, 117 | *self.transition.size()) 118 | trans_exp = self.transition.unsqueeze(0).expand_as(alpha_exp) 119 | mat = logit_exp + alpha_exp + trans_exp 120 | alpha_nxt = log_sum_exp(mat, 2).squeeze(-1) 121 | 122 | mask = (lens_ > 0).float().unsqueeze(-1).expand_as(alpha) 123 | alpha = mask * alpha_nxt + (1 - mask) * alpha 124 | lens_ = lens_ - 1 125 | 126 | alpha = alpha + self.transition[self.end].unsqueeze(0).expand_as(alpha) 127 | norm = log_sum_exp(alpha, 1).squeeze(-1) 128 | 129 | return norm 130 | 131 | def viterbi_decode(self, logits, lens): 132 | """Borrowed from pytorch tutorial 133 | Arguments: 134 | logits: [batch_size, seq_len, n_labels] FloatTensor 135 | lens: [batch_size] LongTensor 136 | """ 137 | batch_size, seq_len, n_labels = logits.size() 138 | # vit = logits.data.new(batch_size, self.label_size).fill_(-10000) 139 | vit = logits.new_full((batch_size, self.label_size), -100.0) 140 | vit[:, self.start] = 0 
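# Viterbi forward pass: for every label keep the best score of any path ending there (new_score[label] = emission[label] + max over previous labels of score[prev] + transition), and stash each argmax as a backpointer. The masks below freeze sequences whose length is exhausted and add the end transition exactly once; the loop afterwards walks the backpointers in reverse to recover the best path.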
141 | # vit = Variable(vit) 142 | c_lens = lens.clone() 143 | 144 | logits_t = logits.transpose(1, 0) 145 | pointers = [] 146 | for logit in logits_t: 147 | vit_exp = vit.unsqueeze(1).expand(batch_size, n_labels, n_labels) 148 | trn_exp = self.transition.unsqueeze(0).expand_as(vit_exp) 149 | vit_trn_sum = vit_exp + trn_exp 150 | vt_max, vt_argmax = vit_trn_sum.max(2) 151 | 152 | vt_max = vt_max.squeeze(-1) 153 | vit_nxt = vt_max + logit 154 | pointers.append(vt_argmax.squeeze(-1).unsqueeze(0)) 155 | 156 | mask = (c_lens > 0).float().unsqueeze(-1).expand_as(vit_nxt) 157 | vit = mask * vit_nxt + (1 - mask) * vit 158 | 159 | mask = (c_lens == 1).float().unsqueeze(-1).expand_as(vit_nxt) 160 | vit += mask * self.transition[self.end].unsqueeze( 161 | 0).expand_as(vit_nxt) 162 | 163 | c_lens = c_lens - 1 164 | 165 | pointers = torch.cat(pointers) 166 | scores, idx = vit.max(1) 167 | # idx = idx.squeeze(-1) 168 | paths = [idx.unsqueeze(1)] 169 | for argmax in reversed(pointers): 170 | idx_exp = idx.unsqueeze(-1) 171 | idx = torch.gather(argmax, 1, idx_exp) 172 | idx = idx.squeeze(-1) 173 | 174 | paths.insert(0, idx.unsqueeze(1)) 175 | 176 | paths = torch.cat(paths[1:], 1) 177 | scores = scores.squeeze(-1) 178 | 179 | return scores, paths 180 |
-------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | 23 | def warmup_cosine(x, warmup=0.002): 24 | if x < warmup: 25 | return x/warmup 26 | return 0.5 * (1.0 + math.cos(math.pi * x))  # math.cos, not torch.cos: x is a plain float here 27 | 28 | def warmup_constant(x, warmup=0.002): 29 | if x < warmup: 30 | return x/warmup 31 | return 1.0 32 | 33 | def warmup_linear(x, warmup=0.002): 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 - x 37 | 38 | SCHEDULES = { 39 | 'warmup_cosine':warmup_cosine, 40 | 'warmup_constant':warmup_constant, 41 | 'warmup_linear':warmup_linear, 42 | } 43 | 44 | 45 | class BertAdam(Optimizer): 46 | """Implements BERT version of Adam algorithm with weight decay fix. 47 | Params: 48 | lr: learning rate 49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 50 | t_total: total number of training steps for the learning 51 | rate schedule, -1 means constant learning rate. Default: -1 52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 53 | b1: Adam's b1. Default: 0.9 54 | b2: Adam's b2. Default: 0.999 55 | e: Adam's epsilon. Default: 1e-6 56 | weight_decay: Weight decay. Default: 0.01 57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping).
Default: 1.0 58 | """ 59 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 60 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 61 | max_grad_norm=1.0): 62 | if lr is not required and lr < 0.0: 63 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 64 | if schedule not in SCHEDULES: 65 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 66 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 67 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 68 | if not 0.0 <= b1 < 1.0: 69 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 70 | if not 0.0 <= b2 < 1.0: 71 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 72 | if not e >= 0.0: 73 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 74 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 75 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 76 | max_grad_norm=max_grad_norm) 77 | super(BertAdam, self).__init__(params, defaults) 78 | 79 | def get_lr(self): 80 | lr = [] 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | state = self.state[p] 84 | if len(state) == 0: 85 | return [0] 86 | if group['t_total'] != -1: 87 | schedule_fct = SCHEDULES[group['schedule']] 88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 89 | else: 90 | lr_scheduled = group['lr'] 91 | lr.append(lr_scheduled) 92 | return lr 93 | 94 | def step(self, closure=None): 95 | """Performs a single optimization step. 96 | 97 | Arguments: 98 | closure (callable, optional): A closure that reevaluates the model 99 | and returns the loss. 100 | """ 101 | loss = None 102 | if closure is not None: 103 | loss = closure() 104 | 105 | for group in self.param_groups: 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data 110 | if grad.is_sparse: 111 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 112 | 113 | state = self.state[p] 114 | 115 | # State initialization 116 | if len(state) == 0: 117 | state['step'] = 0 118 | # Exponential moving average of gradient values 119 | state['next_m'] = torch.zeros_like(p.data) 120 | # Exponential moving average of squared gradient values 121 | state['next_v'] = torch.zeros_like(p.data) 122 | 123 | next_m, next_v = state['next_m'], state['next_v'] 124 | beta1, beta2 = group['b1'], group['b2'] 125 | 126 | # Add grad clipping 127 | if group['max_grad_norm'] > 0: 128 | clip_grad_norm_(p, group['max_grad_norm']) 129 | 130 | # Decay the first and second moment running average coefficient 131 | # In-place operations to update the averages at the same time 132 | next_m.mul_(beta1).add_(1 - beta1, grad) 133 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 134 | update = next_m / (next_v.sqrt() + group['e']) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want to decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 
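# Net effect of the lines below: p <- p - lr_scheduled * (m / (sqrt(v) + e) + weight_decay * p). Unlike textbook Adam there is no bias correction, as the commented-out lines at the end of step() note.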
143 | if group['weight_decay'] > 0.0: 144 | update += group['weight_decay'] * p.data 145 | 146 | if group['t_total'] != -1: 147 | schedule_fct = SCHEDULES[group['schedule']] 148 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 149 | else: 150 | lr_scheduled = group['lr'] 151 | 152 | update_with_lr = lr_scheduled * update 153 | p.data.add_(-update_with_lr) 154 | 155 | state['step'] += 1 156 | 157 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 158 | # No bias correction 159 | # bias_correction1 = 1 - beta1 ** state['step'] 160 | # bias_correction2 = 1 - beta2 ** state['step'] 161 | 162 | return loss 163 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from pytorch_pretrained_bert.modeling import * 2 | from torch.nn import BCEWithLogitsLoss 3 | from layers.embedding import BertEmbedder 4 | from layers.encoder import BertBiLSTMEncoder 5 | from layers.decoder import CRFDecoder 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | 9 | 10 | class BertForMultiHeadProblem(BertPreTrainedModel): 11 | def __init__(self, config, num_labels): 12 | super(BertForMultiHeadProblem, self).__init__(config) 13 | self.num_labels = num_labels 14 | self.bert = BertModel(config) 15 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 16 | self.classifier = nn.Linear(config.hidden_size, num_labels) 17 | self.apply(self.init_bert_weights) 18 | 19 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 20 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 21 | sequence_output = self.dropout(sequence_output) 22 | logits = self.classifier(sequence_output) 23 | 24 | if labels is not None: 25 | loss_bce = BCEWithLogitsLoss() 26 | # Only keep active parts of the loss 27 | if attention_mask is not None: 28 | active_loss = attention_mask.view(-1) == 1 29 | active_logits = logits.view(-1, self.num_labels)[active_loss] 30 | active_labels = labels.view(-1, self.num_labels)[active_loss] 31 | loss = loss_bce(active_logits, active_labels) 32 | else: 33 | loss = loss_bce(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))  # labels reshaped to match logits; BCEWithLogitsLoss requires identical shapes 34 | return loss 35 | else: 36 | return logits 37 | 38 | class MultiHeadLayers(nn.Module): 39 | def __init__(self, label_embedding_size, hidden_size, lstm_hidden_size, ner_size, rel_size, activation="tanh", dropout=0.1): 40 | super(MultiHeadLayers, self).__init__() 41 | self.dropout_1 = nn.Dropout(p=dropout) 42 | self.dropout_2 = nn.Dropout(p=dropout) 43 | 44 | self.activation = activation 45 | self.rel_size = rel_size 46 | 47 | self.label_embedding_size = label_embedding_size 48 | 49 | self.u_a = nn.Parameter(torch.randn(lstm_hidden_size + label_embedding_size, hidden_size)) 50 | self.w_a = nn.Parameter(torch.randn(lstm_hidden_size + label_embedding_size, hidden_size)) 51 | self.v = nn.Parameter(torch.randn(hidden_size, rel_size)) 52 | self.b_s = nn.Parameter(torch.randn(hidden_size)) 53 | 54 | if self.label_embedding_size != 0: 55 | self.label_embedding = nn.Embedding(ner_size, label_embedding_size) 56 | 57 | 58 | def broadcasting(self, left, right):  # pairwise sum of the head/tail projections for every (token_i, token_j) pair 59 | left = left.transpose(1, 0)  # [batch, seq, hidden] -> [seq, batch, hidden] 60 | left = left.unsqueeze(3)  # -> [seq, batch, hidden, 1] 61 | 62 | right = right.transpose(2, 1)  # [batch, seq, hidden] -> [batch, hidden, seq] 63 | right = right.unsqueeze(0)  # -> [1, batch, hidden, seq] 64 | 65 | B = left + right  # broadcasts to [seq_i, batch, hidden, seq_j] 66 | B = B.transpose(1, 0).transpose(3, 2)  # -> [batch, seq_i, seq_j, hidden] 67 | 68 | return B 69 | 70 | def
forward(self, lstm_output, pred_ner): 71 | lstm_output = self.dropout_1(lstm_output) 72 | if self.label_embedding_size != 0: 73 | embeded_label = self.label_embedding(pred_ner) 74 | z = torch.cat([lstm_output, embeded_label], dim=2) 75 | else: 76 | z = lstm_output 77 | left = torch.einsum('aij,jk->aik', z, self.u_a) 78 | right = torch.einsum('aij,jk->aik', z, self.w_a) 79 | 80 | outer_sum = self.broadcasting(left, right) 81 | 82 | outer_sum_bias = 1 * ( outer_sum + self.b_s ) 83 | 84 | if self.activation=="tanh": 85 | output = torch.tanh(outer_sum_bias) 86 | elif self.activation=="relu": 87 | output = torch.relu(outer_sum_bias) 88 | 89 | output = self.dropout_2(output) 90 | 91 | g = torch.einsum('aijk,kp->aijp', output, self.v) 92 | 93 | g = g.view(g.size(0), g.size(1), g.size(2) * self.rel_size) 94 | 95 | sigmoid = torch.nn.Sigmoid() 96 | probas = sigmoid(g) 97 | predictedRel = torch.round(probas) 98 | 99 | return predictedRel 100 | 101 | def score(self, lstm_output, gold_ner_labels, gold_rel_labels): 102 | lstm_output = self.dropout_1(lstm_output) 103 | 104 | if self.label_embedding_size != 0: 105 | embeded_label = self.label_embedding(gold_ner_labels)[:, :lstm_output.size(1), :] 106 | z = torch.cat([lstm_output, embeded_label], dim=2) 107 | else: 108 | z = lstm_output 109 | 110 | left = torch.einsum('aij,jk->aik', z, self.u_a) 111 | right = torch.einsum('aij,jk->aik', z, self.w_a) 112 | 113 | outer_sum = self.broadcasting(left, right) 114 | 115 | outer_sum_bias = 1 * ( outer_sum + self.b_s) 116 | 117 | if self.activation == "tanh": 118 | output = torch.tanh(outer_sum_bias) 119 | elif self.activation == "relu": 120 | output = torch.relu(outer_sum_bias) 121 | 122 | output = self.dropout_2(output) 123 | 124 | g = torch.einsum('aijk,kp->aijp', output, self.v) 125 | 126 | g = g.view(g.size(0), g.size(1), g.size(2) * self.rel_size) 127 | 128 | loss_bce = BCEWithLogitsLoss(reduction="mean") 129 | 130 | active_rel_labels = gold_rel_labels[:, :g.size(1), :g.size(2)] 131 | 132 | loss = loss_bce(g, active_rel_labels) 133 | 134 | return loss 135 | 136 | 137 | class BertBiLSTMCRF(nn.Module): 138 | def __init__(self, encoder, decoder, extra=None, use_cuda=True, use_extra=True): 139 | super(BertBiLSTMCRF, self).__init__() 140 | self.encoder = encoder 141 | self.extra = extra 142 | self.decoder = decoder 143 | self.use_cuda = use_cuda 144 | self.use_extra = use_extra 145 | if use_cuda: 146 | self.cuda() 147 | #print(list(self.parameters())) 148 | 149 | def forward(self, batch): 150 | output, hidden = self.encoder(batch) 151 | predictedNer = self.decoder(output, batch[-3]) 152 | 153 | if self.use_extra: 154 | predictedRel = self.extra(output, self.decoder(output, batch[-3])) 155 | return predictedNer, predictedRel 156 | else: 157 | return predictedNer, None 158 | 159 | def score(self, batch): 160 | output, _ = self.encoder(batch) 161 | lossNER = self.decoder.score(output, batch[-3], batch[-2].long()) 162 | if self.use_extra: 163 | lossREL = self.extra.score(output, batch[-2].long(), batch[-1]) 164 | return lossNER + lossREL 165 | else: 166 | return lossNER 167 | 168 | @classmethod 169 | def create(cls, 170 | ner_size, 171 | rel_size, 172 | bert_pretrained_path, embedding_dim=768, bert_mode="weighted", 173 | freeze=True, 174 | enc_hidden_dim=128, rnn_layers=1, 175 | input_dropout=0.1, 176 | use_cuda=True, 177 | use_extra=True, 178 | meta_dim=None, 179 | hidden_size=64, 180 | label_embedding_size=64, 181 | activation="tanh"): 182 | 183 | embedder = BertEmbedder.create(bert_pretrained_path, embedding_dim, 
use_cuda, bert_mode, freeze) 184 | encoder = BertBiLSTMEncoder.create(embedder, enc_hidden_dim, rnn_layers, use_cuda) 185 | 186 | extra = None 187 | if use_extra: 188 | extra = MultiHeadLayers(label_embedding_size, hidden_size, enc_hidden_dim, ner_size, rel_size, activation, input_dropout) 189 | 190 | decoder = CRFDecoder(ner_size, encoder.output_dim, input_dropout, activation=activation) 191 | 192 | return cls(encoder, decoder, extra, use_cuda, use_extra) 193 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import tempfile 13 | from functools import wraps 14 | from hashlib import sha256 15 | import sys 16 | from io import open 17 | 18 | import boto3 19 | import requests 20 | from botocore.exceptions import ClientError 21 | from tqdm import tqdm 22 | 23 | try: 24 | from urllib.parse import urlparse 25 | except ImportError: 26 | from urlparse import urlparse 27 | 28 | try: 29 | from pathlib import Path 30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 31 | Path.home() / '.pytorch_pretrained_bert')) 32 | except AttributeError: 33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 35 | 36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 37 | 38 | 39 | def url_to_filename(url, etag=None): 40 | """ 41 | Convert `url` into a hashed filename in a repeatable way. 42 | If `etag` is specified, append its hash to the url's, delimited 43 | by a period. 44 | """ 45 | url_bytes = url.encode('utf-8') 46 | url_hash = sha256(url_bytes) 47 | filename = url_hash.hexdigest() 48 | 49 | if etag: 50 | etag_bytes = etag.encode('utf-8') 51 | etag_hash = sha256(etag_bytes) 52 | filename += '.' + etag_hash.hexdigest() 53 | 54 | return filename 55 | 56 | 57 | def filename_to_url(filename, cache_dir=None): 58 | """ 59 | Return the url and etag (which may be ``None``) stored for `filename`. 60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 61 | """ 62 | if cache_dir is None: 63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 65 | cache_dir = str(cache_dir) 66 | 67 | cache_path = os.path.join(cache_dir, filename) 68 | if not os.path.exists(cache_path): 69 | raise EnvironmentError("file {} not found".format(cache_path)) 70 | 71 | meta_path = cache_path + '.json' 72 | if not os.path.exists(meta_path): 73 | raise EnvironmentError("file {} not found".format(meta_path)) 74 | 75 | with open(meta_path, encoding="utf-8") as meta_file: 76 | metadata = json.load(meta_file) 77 | url = metadata['url'] 78 | etag = metadata['etag'] 79 | 80 | return url, etag 81 | 82 | 83 | def cached_path(url_or_filename, cache_dir=None): 84 | """ 85 | Given something that might be a URL (or might be a local path), 86 | determine which. If it's a URL, download the file and cache it, and 87 | return the path to the cached file. 
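    Supported schemes are http, https and s3; anything else is treated as a local path.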
If it's already a local path, 88 | make sure the file exists and then return the path. 89 | """ 90 | if cache_dir is None: 91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 93 | url_or_filename = str(url_or_filename) 94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 95 | cache_dir = str(cache_dir) 96 | 97 | 98 | parsed = urlparse(url_or_filename) 99 | 100 | if parsed.scheme in ('http', 'https', 's3'): 101 | # URL, so get it from the cache (downloading if necessary) 102 | return get_from_cache(url_or_filename, cache_dir) 103 | elif os.path.exists(url_or_filename): 104 | # File, and it exists. 105 | return url_or_filename 106 | elif parsed.scheme == '': 107 | # File, but it doesn't exist. 108 | raise EnvironmentError("file {} not found".format(url_or_filename)) 109 | else: 110 | # Something unknown 111 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 112 | 113 | 114 | def split_s3_path(url): 115 | """Split a full s3 path into the bucket name and path.""" 116 | parsed = urlparse(url) 117 | if not parsed.netloc or not parsed.path: 118 | raise ValueError("bad s3 path {}".format(url)) 119 | bucket_name = parsed.netloc 120 | s3_path = parsed.path 121 | # Remove '/' at beginning of path. 122 | if s3_path.startswith("/"): 123 | s3_path = s3_path[1:] 124 | return bucket_name, s3_path 125 | 126 | 127 | def s3_request(func): 128 | """ 129 | Wrapper function for s3 requests in order to create more helpful error 130 | messages. 131 | """ 132 | 133 | @wraps(func) 134 | def wrapper(url, *args, **kwargs): 135 | try: 136 | return func(url, *args, **kwargs) 137 | except ClientError as exc: 138 | if int(exc.response["Error"]["Code"]) == 404: 139 | raise EnvironmentError("file {} not found".format(url)) 140 | else: 141 | raise 142 | 143 | return wrapper 144 | 145 | 146 | @s3_request 147 | def s3_etag(url): 148 | """Check ETag on S3 object.""" 149 | s3_resource = boto3.resource("s3") 150 | bucket_name, s3_path = split_s3_path(url) 151 | s3_object = s3_resource.Object(bucket_name, s3_path) 152 | return s3_object.e_tag 153 | 154 | 155 | @s3_request 156 | def s3_get(url, temp_file): 157 | """Pull a file directly from S3.""" 158 | s3_resource = boto3.resource("s3") 159 | bucket_name, s3_path = split_s3_path(url) 160 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 161 | 162 | 163 | def http_get(url, temp_file): 164 | req = requests.get(url, stream=True) 165 | content_length = req.headers.get('Content-Length') 166 | total = int(content_length) if content_length is not None else None 167 | progress = tqdm(unit="B", total=total) 168 | for chunk in req.iter_content(chunk_size=1024): 169 | if chunk: # filter out keep-alive new chunks 170 | progress.update(len(chunk)) 171 | temp_file.write(chunk) 172 | progress.close() 173 | 174 | 175 | def get_from_cache(url, cache_dir=None): 176 | """ 177 | Given a URL, look for the corresponding dataset in the local cache. 178 | If it's not there, download it. Then return the path to the cached file. 179 | """ 180 | if cache_dir is None: 181 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 182 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 183 | cache_dir = str(cache_dir) 184 | 185 | if not os.path.exists(cache_dir): 186 | os.makedirs(cache_dir) 187 | 188 | # Get eTag to add to filename, if it exists. 
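# The cache filename is sha256(url), with '.' + sha256(etag) appended when an ETag is available (see url_to_filename above), so a changed remote file hashes to a fresh cache entry.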
189 | if url.startswith("s3://"): 190 | etag = s3_etag(url) 191 | else: 192 | response = requests.head(url, allow_redirects=True) 193 | if response.status_code != 200: 194 | raise IOError("HEAD request failed for url {} with status code {}" 195 | .format(url, response.status_code)) 196 | etag = response.headers.get("ETag") 197 | 198 | filename = url_to_filename(url, etag) 199 | 200 | # get cache path to put the file 201 | cache_path = os.path.join(cache_dir, filename) 202 | 203 | if not os.path.exists(cache_path): 204 | # Download to temporary file, then copy to cache dir once finished. 205 | # Otherwise you get corrupt cache entries if the download gets interrupted. 206 | with tempfile.NamedTemporaryFile() as temp_file: 207 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 208 | 209 | # GET file object 210 | if url.startswith("s3://"): 211 | s3_get(url, temp_file) 212 | else: 213 | http_get(url, temp_file) 214 | 215 | # we are copying the file before closing it, so flush to avoid truncation 216 | temp_file.flush() 217 | # shutil.copyfileobj() starts at the current position, so go to the start 218 | temp_file.seek(0) 219 | 220 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 221 | with open(cache_path, 'wb') as cache_file: 222 | shutil.copyfileobj(temp_file, cache_file) 223 | 224 | logger.info("creating metadata file for %s", cache_path) 225 | meta = {'url': url, 'etag': etag} 226 | meta_path = cache_path + '.json' 227 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 228 | json.dump(meta, meta_file) 229 | 230 | logger.info("removing temp file %s", temp_file.name) 231 | 232 | return cache_path 233 | 234 | 235 | def read_set_from_file(filename): 236 | ''' 237 | Extract a de-duped collection (set) of text from a file. 238 | Expected file format is one item per line. 239 | ''' 240 | collection = set() 241 | with open(filename, 'r', encoding='utf-8') as file_: 242 | for line in file_: 243 | collection.add(line.rstrip()) 244 | return collection 245 | 246 | 247 | def get_file_extension(path, dot=True, lower=True): 248 | ext = os.path.splitext(path)[1] 249 | ext = ext if dot else ext[1:] 250 | return ext.lower() if lower else ext 251 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
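# Usage sketch for the byte-level BPE tokenizer defined below (a minimal example; the
# 'gpt2' shortcut resolves through the archive maps further down and assumes network access):
#   tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
#   ids = tokenizer.encode("Hello world")
#   assert tokenizer.decode(ids) == "Hello world"  # byte-level BPE round-trips losslessly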
15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import regex as re 23 | from io import open 24 | 25 | try: 26 | from functools import lru_cache 27 | except ImportError: 28 | # Just a dummy decorator to get the checks to run on python2 29 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 30 | def lru_cache(): 31 | return lambda func: func 32 | 33 | from .file_utils import cached_path 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 38 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 39 | } 40 | PRETRAINED_MERGES_ARCHIVE_MAP = { 41 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 42 | } 43 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 44 | 'gpt2': 1024, 45 | } 46 | VOCAB_NAME = 'vocab.json' 47 | MERGES_NAME = 'merges.txt' 48 | 49 | @lru_cache() 50 | def bytes_to_unicode(): 51 | """ 52 | Returns list of utf-8 byte and a corresponding list of unicode strings. 53 | The reversible bpe codes work on unicode strings. 54 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 55 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 56 | This is a signficant percentage of your normal, say, 32K bpe vocab. 57 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 58 | And avoids mapping to whitespace/control characters the bpe code barfs on. 59 | """ 60 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 61 | cs = bs[:] 62 | n = 0 63 | for b in range(2**8): 64 | if b not in bs: 65 | bs.append(b) 66 | cs.append(2**8+n) 67 | n += 1 68 | cs = [chr(n) for n in cs] 69 | return dict(zip(bs, cs)) 70 | 71 | def get_pairs(word): 72 | """Return set of symbol pairs in a word. 73 | 74 | Word is represented as tuple of symbols (symbols being variable-length strings). 75 | """ 76 | pairs = set() 77 | prev_char = word[0] 78 | for char in word[1:]: 79 | pairs.add((prev_char, char)) 80 | prev_char = char 81 | return pairs 82 | 83 | class GPT2Tokenizer(object): 84 | """ 85 | GPT-2 BPE tokenizer. Peculiarities: 86 | - Byte-level BPE 87 | """ 88 | @classmethod 89 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 90 | """ 91 | Instantiate a PreTrainedBertModel from a pre-trained model file. 92 | Download and cache the pre-trained model file if needed. 93 | """ 94 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 95 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 96 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 97 | else: 98 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 99 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 100 | # redirect to the cache, if necessary 101 | try: 102 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 103 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 104 | except EnvironmentError: 105 | logger.error( 106 | "Model name '{}' was not found in model name list ({}). 
" 107 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 108 | "at this path or url.".format( 109 | pretrained_model_name_or_path, 110 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 111 | pretrained_model_name_or_path, 112 | vocab_file, merges_file)) 113 | return None 114 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 115 | logger.info("loading vocabulary file {}".format(vocab_file)) 116 | logger.info("loading merges file {}".format(merges_file)) 117 | else: 118 | logger.info("loading vocabulary file {} from cache at {}".format( 119 | vocab_file, resolved_vocab_file)) 120 | logger.info("loading merges file {} from cache at {}".format( 121 | merges_file, resolved_merges_file)) 122 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 123 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 124 | # than the number of positional embeddings 125 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 126 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 127 | # Instantiate tokenizer. 128 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) 129 | return tokenizer 130 | 131 | def __init__(self, vocab_file, merges_file, errors='replace', max_len=None): 132 | self.max_len = max_len if max_len is not None else int(1e12) 133 | self.encoder = json.load(open(vocab_file)) 134 | self.decoder = {v:k for k,v in self.encoder.items()} 135 | self.errors = errors # how to handle errors in decoding 136 | self.byte_encoder = bytes_to_unicode() 137 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 138 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 139 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 140 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 141 | self.cache = {} 142 | 143 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 144 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 145 | 146 | def __len__(self): 147 | return len(self.encoder) 148 | 149 | def bpe(self, token): 150 | if token in self.cache: 151 | return self.cache[token] 152 | word = tuple(token) 153 | pairs = get_pairs(word) 154 | 155 | if not pairs: 156 | return token 157 | 158 | while True: 159 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 160 | if bigram not in self.bpe_ranks: 161 | break 162 | first, second = bigram 163 | new_word = [] 164 | i = 0 165 | while i < len(word): 166 | try: 167 | j = word.index(first, i) 168 | new_word.extend(word[i:j]) 169 | i = j 170 | except: 171 | new_word.extend(word[i:]) 172 | break 173 | 174 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 175 | new_word.append(first+second) 176 | i += 2 177 | else: 178 | new_word.append(word[i]) 179 | i += 1 180 | new_word = tuple(new_word) 181 | word = new_word 182 | if len(word) == 1: 183 | break 184 | else: 185 | pairs = get_pairs(word) 186 | word = ' '.join(word) 187 | self.cache[token] = word 188 | return word 189 | 190 | def encode(self, text): 191 | bpe_tokens = [] 192 | for token in re.findall(self.pat, text): 193 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 194 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 195 | if len(bpe_tokens) > self.max_len: 196 | raise 
ValueError( 197 | "Token indices sequence length is longer than the specified maximum " 198 | " sequence length for this OpenAI GPT-2 model ({} > {}). Running this" 199 | " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len) 200 | ) 201 | return bpe_tokens 202 | 203 | def decode(self, tokens): 204 | text = ''.join([self.decoder[token] for token in tokens]) 205 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 206 | return text 207 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | import sys 24 | from io import open 25 | 26 | from tqdm import tqdm 27 | 28 | from .file_utils import cached_path 29 | from .tokenization import BasicTokenizer 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 35 | } 36 | PRETRAINED_MERGES_ARCHIVE_MAP = { 37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'openai-gpt': 512, 41 | } 42 | VOCAB_NAME = 'vocab.json' 43 | MERGES_NAME = 'merges.txt' 44 | 45 | def get_pairs(word): 46 | """ 47 | Return set of symbol pairs in a word. 48 | word is represented as tuple of symbols (symbols being variable-length strings) 49 | """ 50 | pairs = set() 51 | prev_char = word[0] 52 | for char in word[1:]: 53 | pairs.add((prev_char, char)) 54 | prev_char = char 55 | return pairs 56 | 57 | def text_standardize(text): 58 | """ 59 | fixes some issues the spacy tokenizer had on books corpus 60 | also does some whitespace standardization 61 | """ 62 | text = text.replace('—', '-') 63 | text = text.replace('–', '-') 64 | text = text.replace('―', '-') 65 | text = text.replace('…', '...') 66 | text = text.replace('´', "'") 67 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 68 | text = re.sub(r'\s*\n\s*', ' \n ', text) 69 | text = re.sub(r'[^\S\n]+', ' ', text) 70 | return text.strip() 71 | 72 | class OpenAIGPTTokenizer(object): 73 | """ 74 | BPE tokenizer. Peculiarities: 75 | - lower case all inputs 76 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 
77 | - argument special_tokens and function set_special_tokens: 78 | can be used to add additional symbols (ex: "__classify__") to a vocabulary. 79 | """ 80 | @classmethod 81 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 82 | """ 83 | Instantiate a PreTrainedBertModel from a pre-trained model file. 84 | Download and cache the pre-trained model file if needed. 85 | """ 86 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 87 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 88 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 89 | else: 90 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 91 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 92 | # redirect to the cache, if necessary 93 | try: 94 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 95 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 96 | except EnvironmentError: 97 | logger.error( 98 | "Model name '{}' was not found in model name list ({}). " 99 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 100 | "at this path or url.".format( 101 | pretrained_model_name_or_path, 102 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 103 | pretrained_model_name_or_path, 104 | vocab_file, merges_file)) 105 | return None 106 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 107 | logger.info("loading vocabulary file {}".format(vocab_file)) 108 | logger.info("loading merges file {}".format(merges_file)) 109 | else: 110 | logger.info("loading vocabulary file {} from cache at {}".format( 111 | vocab_file, resolved_vocab_file)) 112 | logger.info("loading merges file {} from cache at {}".format( 113 | merges_file, resolved_merges_file)) 114 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 115 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 116 | # than the number of positional embeddings 117 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 118 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 119 | # Instantiate tokenizer. 
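# (max_len was clamped just above to the positional-embedding budget, 512 for 'openai-gpt', so convert_tokens_to_ids() below can fail fast instead of triggering an indexing error inside the model.)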
120 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) 121 | return tokenizer 122 | 123 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): 124 | try: 125 | import ftfy 126 | import spacy 127 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) 128 | self.fix_text = ftfy.fix_text 129 | except ImportError: 130 | logger.warning("ftfy or spacy is not installed; using BERT BasicTokenizer instead of SpaCy & ftfy.") 131 | self.nlp = BasicTokenizer(do_lower_case=True, 132 | never_split=special_tokens if special_tokens is not None else []) 133 | self.fix_text = None 134 | 135 | self.max_len = max_len if max_len is not None else int(1e12) 136 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 137 | self.decoder = {v:k for k,v in self.encoder.items()} 138 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 139 | merges = [tuple(merge.split()) for merge in merges] 140 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 141 | self.cache = {} 142 | self.set_special_tokens(special_tokens) 143 | 144 | def __len__(self): 145 | return len(self.encoder) + len(self.special_tokens) 146 | 147 | def set_special_tokens(self, special_tokens): 148 | """ Add a list of additional tokens to the encoder. 149 | The additional tokens are indexed starting from the last index of the 150 | current vocabulary in the order of the `special_tokens` list. 151 | """ 152 | if not special_tokens: 153 | self.special_tokens = {} 154 | self.special_tokens_decoder = {} 155 | return 156 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 157 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 158 | if self.fix_text is None: 159 | # Using BERT's BasicTokenizer: we can update the tokenizer 160 | self.nlp.never_split = special_tokens 161 | logger.info("Special tokens {}".format(self.special_tokens)) 162 | 163 | def bpe(self, token): 164 | word = tuple(token[:-1]) + (token[-1] + '</w>',)  # '</w>' marks end-of-word in the OpenAI GPT BPE vocab 165 | if token in self.cache: 166 | return self.cache[token] 167 | pairs = get_pairs(word) 168 | 169 | if not pairs: 170 | return token+'</w>' 171 | 172 | while True: 173 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 174 | if bigram not in self.bpe_ranks: 175 | break 176 | first, second = bigram 177 | new_word = [] 178 | i = 0 179 | while i < len(word): 180 | try: 181 | j = word.index(first, i) 182 | new_word.extend(word[i:j]) 183 | i = j 184 | except ValueError: 185 | new_word.extend(word[i:]) 186 | break 187 | 188 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 189 | new_word.append(first+second) 190 | i += 2 191 | else: 192 | new_word.append(word[i]) 193 | i += 1 194 | new_word = tuple(new_word) 195 | word = new_word 196 | if len(word) == 1: 197 | break 198 | else: 199 | pairs = get_pairs(word) 200 | word = ' '.join(word) 201 | if word == '\n  </w>': 202 | word = '\n</w>' 203 | self.cache[token] = word 204 | return word 205 | 206 | def tokenize(self, text): 207 | """ Tokenize a string.
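        Lower-cases the text, pre-tokenizes it with SpaCy/ftfy when available (BERT's BasicTokenizer otherwise), then applies BPE; returns a list of sub-word strings ready for convert_tokens_to_ids().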
""" 208 | split_tokens = [] 209 | if self.fix_text is None: 210 | # Using BERT's BasicTokenizer 211 | text = self.nlp.tokenize(text) 212 | for token in text: 213 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 214 | else: 215 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 216 | text = self.nlp(text_standardize(self.fix_text(text))) 217 | for token in text: 218 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 219 | return split_tokens 220 | 221 | def convert_tokens_to_ids(self, tokens): 222 | """ Converts a sequence of tokens into ids using the vocab. """ 223 | ids = [] 224 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 225 | if tokens in self.special_tokens: 226 | return self.special_tokens[tokens] 227 | else: 228 | return self.encoder.get(tokens, 0) 229 | for token in tokens: 230 | if token in self.special_tokens: 231 | ids.append(self.special_tokens[token]) 232 | else: 233 | ids.append(self.encoder.get(token, 0)) 234 | if len(ids) > self.max_len: 235 | raise ValueError( 236 | "Token indices sequence length is longer than the specified maximum " 237 | " sequence length for this OpenAI GPT model ({} > {}). Running this" 238 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 239 | ) 240 | return ids 241 | 242 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 243 | """Converts a sequence of ids in BPE tokens using the vocab.""" 244 | tokens = [] 245 | for i in ids: 246 | if i in self.special_tokens_decoder: 247 | if not skip_special_tokens: 248 | tokens.append(self.special_tokens_decoder[i]) 249 | else: 250 | tokens.append(self.decoder[i]) 251 | return tokens 252 | 253 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False): 254 | """Converts a sequence of ids in a string.""" 255 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) 256 | out_string = ''.join(tokens).replace('', ' ').strip() 257 | if clean_up_tokenization_spaces: 258 | out_string = out_string.replace('', '') 259 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ',' 260 | ).replace(" n't", "n't").replace(" 'm", "'m").replace(" 're", "'re").replace(" do not", " don't" 261 | ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m " 262 | ).replace(" 've", "'ve") 263 | return out_string 264 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .file_utils import cached_path 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 37 | } 38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 39 | 'bert-base-uncased': 512, 40 | 'bert-large-uncased': 512, 41 | 'bert-base-cased': 512, 42 | 'bert-large-cased': 512, 43 | 'bert-base-multilingual-uncased': 512, 44 | 'bert-base-multilingual-cased': 512, 45 | 'bert-base-chinese': 512, 46 | } 47 | VOCAB_NAME = 'vocab.txt' 48 | 49 | 50 | def load_vocab(vocab_file): 51 | """Loads a vocabulary file into a dictionary.""" 52 | vocab = collections.OrderedDict() 53 | index = 0 54 | with open(vocab_file, "r", encoding="utf-8") as reader: 55 | while True: 56 | token = reader.readline() 57 | if not token: 58 | break 59 | token = token.strip() 60 | vocab[token] = index 61 | index += 1 62 | return vocab 63 | 64 | 65 | def whitespace_tokenize(text): 66 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 67 | text = text.strip() 68 | if not text: 69 | return [] 70 | tokens = text.split() 71 | return tokens 72 | 73 | 74 | class BertTokenizer(object): 75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 76 | 77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, 78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 79 | if not os.path.isfile(vocab_file): 80 | raise ValueError( 81 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 82 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 83 | self.vocab = load_vocab(vocab_file) 84 | self.ids_to_tokens = collections.OrderedDict( 85 | [(ids, tok) for tok, ids in self.vocab.items()]) 86 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 87 | never_split=never_split) 88 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 89 | self.max_len = max_len if max_len is not None else int(1e12) 90 | 91 | def tokenize(self, text): 92 | split_tokens = [] 93 | for token in self.basic_tokenizer.tokenize(text): 94 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 95 | split_tokens.append(sub_token) 96 | return split_tokens 97 | 98 | def convert_tokens_to_ids(self, tokens): 99 | """Converts a sequence of tokens into ids using the vocab.""" 100 | ids = [] 101 | for token in tokens: 102 | ids.append(self.vocab[token]) 103 | if len(ids) > self.max_len: 104 | raise ValueError( 105 | "Token indices sequence length is longer than the specified maximum " 106 | " sequence length for this BERT model ({} > {}). Running this" 107 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 108 | ) 109 | return ids 110 | 111 | def convert_ids_to_tokens(self, ids): 112 | """Converts a sequence of ids into wordpiece tokens using the vocab.""" 113 | tokens = [] 114 | for i in ids: 115 | tokens.append(self.ids_to_tokens[i]) 116 | return tokens 117 | 118 | @classmethod 119 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 120 | """ 121 | Instantiate a PreTrainedBertModel from a pre-trained model file. 122 | Download and cache the pre-trained model file if needed. 123 | """ 124 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 125 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 126 | else: 127 | vocab_file = pretrained_model_name_or_path 128 | if os.path.isdir(vocab_file): 129 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 130 | # redirect to the cache, if necessary 131 | try: 132 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 133 | except EnvironmentError: 134 | logger.error( 135 | "Model name '{}' was not found in model name list ({}). " 136 | "We assumed '{}' was a path or url but couldn't find any file " 137 | "associated to this path or url.".format( 138 | pretrained_model_name_or_path, 139 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 140 | vocab_file)) 141 | return None 142 | if resolved_vocab_file == vocab_file: 143 | logger.info("loading vocabulary file {}".format(vocab_file)) 144 | else: 145 | logger.info("loading vocabulary file {} from cache at {}".format( 146 | vocab_file, resolved_vocab_file)) 147 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 148 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer 149 | # than the number of positional embeddings 150 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 151 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 152 | # Instantiate tokenizer.
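# Illustrative (commented) calls for the resolution logic above, assuming network
# access for the S3 vocab URLs; any string outside the archive map is treated as a
# local vocab file, or a directory containing vocab.txt:
#
#   tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
#   tokenizer = BertTokenizer.from_pretrained('./my_bert_dir/')  # hypothetical directory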
153 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 154 | return tokenizer 155 | 156 | 157 | class BasicTokenizer(object): 158 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 159 | 160 | def __init__(self, 161 | do_lower_case=True, 162 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 163 | """Constructs a BasicTokenizer. 164 | 165 | Args: 166 | do_lower_case: Whether to lower case the input. 167 | """ 168 | self.do_lower_case = do_lower_case 169 | self.never_split = never_split 170 | 171 | def tokenize(self, text): 172 | """Tokenizes a piece of text.""" 173 | text = self._clean_text(text) 174 | # This was added on November 1st, 2018 for the multilingual and Chinese 175 | # models. This is also applied to the English models now, but it doesn't 176 | # matter since the English models were not trained on any Chinese data 177 | # and generally don't have any Chinese data in them (there are Chinese 178 | # characters in the vocabulary because Wikipedia does have some Chinese 179 | # words in the English Wikipedia.). 180 | text = self._tokenize_chinese_chars(text) 181 | orig_tokens = whitespace_tokenize(text) 182 | split_tokens = [] 183 | for token in orig_tokens: 184 | if self.do_lower_case and token not in self.never_split: 185 | token = token.lower() 186 | token = self._run_strip_accents(token) 187 | split_tokens.extend(self._run_split_on_punc(token)) 188 | 189 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 190 | return output_tokens 191 | 192 | def _run_strip_accents(self, text): 193 | """Strips accents from a piece of text.""" 194 | text = unicodedata.normalize("NFD", text) 195 | output = [] 196 | for char in text: 197 | cat = unicodedata.category(char) 198 | if cat == "Mn": 199 | continue 200 | output.append(char) 201 | return "".join(output) 202 | 203 | def _run_split_on_punc(self, text): 204 | """Splits punctuation on a piece of text.""" 205 | if text in self.never_split: 206 | return [text] 207 | chars = list(text) 208 | i = 0 209 | start_new_word = True 210 | output = [] 211 | while i < len(chars): 212 | char = chars[i] 213 | if _is_punctuation(char): 214 | output.append([char]) 215 | start_new_word = True 216 | else: 217 | if start_new_word: 218 | output.append([]) 219 | start_new_word = False 220 | output[-1].append(char) 221 | i += 1 222 | 223 | return ["".join(x) for x in output] 224 | 225 | def _tokenize_chinese_chars(self, text): 226 | """Adds whitespace around any CJK character.""" 227 | output = [] 228 | for char in text: 229 | cp = ord(char) 230 | if self._is_chinese_char(cp): 231 | output.append(" ") 232 | output.append(char) 233 | output.append(" ") 234 | else: 235 | output.append(char) 236 | return "".join(output) 237 | 238 | def _is_chinese_char(self, cp): 239 | """Checks whether CP is the codepoint of a CJK character.""" 240 | # This defines a "chinese character" as anything in the CJK Unicode block: 241 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 242 | # 243 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 244 | # despite its name. The modern Korean Hangul alphabet is a different block, 245 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 246 | # space-separated words, so they are not treated specially and handled 247 | # like all of the other languages.
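# Illustrative (commented) example of the whitespace padding this enables,
# with bt = BasicTokenizer() as a hypothetical instance:
#
#   bt._tokenize_chinese_chars(u"ab\u4e2dcd")  # -> u"ab \u4e2d cd"
#   bt.tokenize(u"ab\u4e2dcd")                 # -> [u"ab", u"\u4e2d", u"cd"]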
248 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 249 | (cp >= 0x3400 and cp <= 0x4DBF) or # 250 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 251 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 252 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 253 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 254 | (cp >= 0xF900 and cp <= 0xFAFF) or # 255 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 256 | return True 257 | 258 | return False 259 | 260 | def _clean_text(self, text): 261 | """Performs invalid character removal and whitespace cleanup on text.""" 262 | output = [] 263 | for char in text: 264 | cp = ord(char) 265 | if cp == 0 or cp == 0xfffd or _is_control(char): 266 | continue 267 | if _is_whitespace(char): 268 | output.append(" ") 269 | else: 270 | output.append(char) 271 | return "".join(output) 272 | 273 | 274 | class WordpieceTokenizer(object): 275 | """Runs WordPiece tokenization.""" 276 | 277 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 278 | self.vocab = vocab 279 | self.unk_token = unk_token 280 | self.max_input_chars_per_word = max_input_chars_per_word 281 | 282 | def tokenize(self, text): 283 | """Tokenizes a piece of text into its word pieces. 284 | 285 | This uses a greedy longest-match-first algorithm to perform tokenization 286 | using the given vocabulary. 287 | 288 | For example: 289 | input = "unaffable" 290 | output = ["un", "##aff", "##able"] 291 | 292 | Args: 293 | text: A single token or whitespace separated tokens. This should have 294 | already been passed through `BasicTokenizer`. 295 | 296 | Returns: 297 | A list of wordpiece tokens. 298 | """ 299 | 300 | output_tokens = [] 301 | for token in whitespace_tokenize(text): 302 | chars = list(token) 303 | if len(chars) > self.max_input_chars_per_word: 304 | output_tokens.append(self.unk_token) 305 | continue 306 | 307 | is_bad = False 308 | start = 0 309 | sub_tokens = [] 310 | while start < len(chars): 311 | end = len(chars) 312 | cur_substr = None 313 | while start < end: 314 | substr = "".join(chars[start:end]) 315 | if start > 0: 316 | substr = "##" + substr 317 | if substr in self.vocab: 318 | cur_substr = substr 319 | break 320 | end -= 1 321 | if cur_substr is None: 322 | is_bad = True 323 | break 324 | sub_tokens.append(cur_substr) 325 | start = end 326 | 327 | if is_bad: 328 | output_tokens.append(self.unk_token) 329 | else: 330 | output_tokens.extend(sub_tokens) 331 | return output_tokens 332 | 333 | 334 | def _is_whitespace(char): 335 | """Checks whether `chars` is a whitespace character.""" 336 | # \t, \n, and \r are technically control characters but we treat them 337 | # as whitespace since they are generally considered as such. 338 | if char == " " or char == "\t" or char == "\n" or char == "\r": 339 | return True 340 | cat = unicodedata.category(char) 341 | if cat == "Zs": 342 | return True 343 | return False 344 | 345 | 346 | def _is_control(char): 347 | """Checks whether `chars` is a control character.""" 348 | # These are technically control characters but we count them as whitespace 349 | # characters. 350 | if char == "\t" or char == "\n" or char == "\r": 351 | return False 352 | cat = unicodedata.category(char) 353 | if cat.startswith("C"): 354 | return True 355 | return False 356 | 357 | 358 | def _is_punctuation(char): 359 | """Checks whether `chars` is a punctuation character.""" 360 | cp = ord(char) 361 | # We treat all non-letter/number ASCII as punctuation.
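# For instance (illustrative): "+" has ord 43 and Unicode category "Sm"
# (math symbol, not a "P" punctuation class), so it is caught only by the
# 33-47 ASCII range in the check below.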
362 | # Characters such as "^", "$", and "`" are not in the Unicode 363 | # Punctuation class but we treat them as punctuation anyways, for 364 | # consistency. 365 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 366 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat.startswith("P"): 370 | return True 371 | return False 372 | -------------------------------------------------------------------------------- /DataGen.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | from torch import nn 4 | 5 | import pytorch_pretrained_bert as _bert 6 | from random import shuffle 7 | 8 | import numpy as np 9 | 10 | #from dougu import flatten, lines 11 | 12 | 13 | _device = torch.device("cpu") 14 | 15 | class DataGenerator: 16 | 17 | MASK = "[MASK]" 18 | CLS = "[CLS]" 19 | SEP = "[SEP]" 20 | ner_list = ['O', 'I-CAT', 'B-TIM', 'I-ETT', 'B-RLT', 'B-ETT', 'B-VAR', 'I-RLT', 'I-TIM', 'B-CAT', 'I-VAR', 'X', '[CLS]', '[SEP]'] 21 | rel_list = ["SP", "N", "ST", "PO", "SC"] 22 | 23 | def __init__(self, model, model_name, device=None, half_precision=False, rel_max_len=32): 24 | self.model_name = model_name 25 | self.device = device or _device 26 | do_lower_case = "uncased" in model_name 27 | self.tokenizer = _bert.BertTokenizer.from_pretrained(self.model_name, do_lower_case=do_lower_case) 28 | 29 | maybe_model_wrapper = model.from_pretrained(model_name).to(device=self.device) 30 | try: 31 | self.model = maybe_model_wrapper.bert 32 | except AttributeError: 33 | self.model = maybe_model_wrapper 34 | if half_precision: 35 | self.model.half() 36 | self.max_len = \ 37 | self.model.embeddings.position_embeddings.weight.size(0) 38 | self.dim = self.model.embeddings.position_embeddings.weight.size(1) 39 | self.rel_max_len = rel_max_len 40 | def tokenize(self, text, masked_idxs=None): 41 | tokenized_text = self.tokenizer.tokenize(text) 42 | if masked_idxs is not None: 43 | for idx in masked_idxs: 44 | tokenized_text[idx] = self.MASK 45 | tokenized = [self.CLS] + tokenized_text + [self.SEP] 46 | return tokenized 47 | 48 | def tokenize_to_ids(self, text, masked_idxs=None, pad=True): 49 | tokens = self.tokenize(text, masked_idxs) 50 | return self.convert_tokens_to_ids(tokens, pad=pad) 51 | 52 | def convert_tokens_to_ids(self, tokens, pad=True): 53 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 54 | ids = torch.tensor([token_ids]).to(device=self.device) 55 | assert ids.size(1) < self.max_len 56 | if pad: 57 | padded_ids = torch.zeros(1, self.max_len).to(ids) 58 | padded_ids[0, :ids.size(1)] = ids 59 | mask = torch.zeros(1, self.max_len).to(ids) 60 | mask[0, :ids.size(1)] = 1 61 | 62 | return padded_ids, mask 63 | else: 64 | return ids 65 | 66 | def padding_labels(self, true_labels, pad=True): 67 | true_labels = torch.tensor([true_labels]).to(device=self.device) 68 | assert true_labels.size(1) < self.max_len 69 | if pad: 70 | 71 | padded_labels = torch.zeros(1, self.max_len).to(true_labels) 72 | padded_labels[0, :true_labels.size(1)] = true_labels 73 | return padded_labels 74 | else: 75 | return_true_labels = torch.zeros(1, true_labels.size(1)) 76 | return_true_labels[0,:] = true_labels 77 | return return_true_labels 78 | 79 | def parseData(self, file_name): 80 | data = {} 81 | with open(file_name, "r") as f: 82 | current_id = "" 83 | for line in f: 84 | if len(line) <= 2: 85 | continue 86 | if "#" in line: 87 | current_id = line.strip() 88 | 
data[current_id] = {} 89 | data[current_id]["tokens"] = [] 90 | data[current_id]["label_types"] = [] 91 | data[current_id]["index_rels"] = [] 92 | data[current_id]["label_rels"] = [] 93 | else: 94 | info = line.split("\t") 95 | 96 | data[current_id]["tokens"].append(info[1]) 97 | data[current_id]["label_types"].append(info[2]) 98 | 99 | index_rels_info = info[4].strip().replace("[","").replace("]","").replace(" ", "").split(",") 100 | data[current_id]["index_rels"].append([int(i) for i in index_rels_info]) 101 | 102 | label_rels_info = info[3].strip().replace("[","").replace("]","").replace(" ", "").split(",") 103 | data[current_id]["label_rels"].append([str(i).replace("'", "") for i in label_rels_info]) 104 | return data 105 | 106 | def convertData(self, original_data, padding=True): 107 | ver_1_data = {} 108 | 109 | for id in original_data: 110 | ver_1_item = {} 111 | ver_1_item["tokens"] = [] 112 | ver_1_item["label_types"] = original_data[id]["label_types"] 113 | ver_1_item["index_rels"] = original_data[id]["index_rels"] 114 | ver_1_item["label_rels"] = original_data[id]["label_rels"] 115 | shift_idx = 0 116 | for idx, orig_token in enumerate(original_data[id]["tokens"]): 117 | sub_orig_tokens = orig_token.split("_") 118 | ## tokens 119 | ver_1_item["tokens"] += sub_orig_tokens 120 | if len(sub_orig_tokens) > 1: 121 | ## label_types 122 | orig_item_label_type = ver_1_item["label_types"][idx + shift_idx].replace("B-","").replace("I-", "") 123 | if ver_1_item["label_types"][idx + shift_idx] == "O": 124 | ver_1_item["label_types"] = ver_1_item["label_types"][:idx + shift_idx + 1] + ["O"] * (len(sub_orig_tokens) - 1) + ver_1_item["label_types"][idx+shift_idx + 1:] 125 | else: 126 | ver_1_item["label_types"] = ver_1_item["label_types"][:idx + shift_idx + 1] + ["I-" + orig_item_label_type] * (len(sub_orig_tokens) - 1) + ver_1_item["label_types"][idx+shift_idx + 1:] 127 | ## label_rels 128 | ver_1_item["label_rels"] = ver_1_item["label_rels"][:idx + shift_idx] + [["N"]] * (len(sub_orig_tokens) - 1) + ver_1_item["label_rels"][idx + shift_idx:] 129 | 130 | ## index_rels 131 | prev_index_rels = [] # ver_1_item["label_rels"][:idx + shift_idx] 132 | next_index_rels = [] # ver_1_item["label_rels"][idx+shift_idx:] 133 | 134 | for index_rels in ver_1_item["index_rels"][:idx + shift_idx]: 135 | prev_index_rel = [] 136 | for index_rel in index_rels: 137 | if index_rel >= idx + shift_idx: 138 | prev_index_rel.append(index_rel + len(sub_orig_tokens) - 1) 139 | else: 140 | prev_index_rel.append(index_rel) 141 | prev_index_rels.append(prev_index_rel) 142 | 143 | for index_rels in ver_1_item["index_rels"][idx+shift_idx:]: 144 | next_index_rel = [] 145 | for index_rel in index_rels: 146 | if index_rel >= idx + shift_idx: 147 | next_index_rel.append(index_rel + len(sub_orig_tokens) - 1) 148 | else: 149 | next_index_rel.append(index_rel) 150 | 151 | next_index_rels.append(next_index_rel) 152 | 153 | ver_1_item["index_rels"] = prev_index_rels + [[(idx + shift_idx + i)] for i in range(len(sub_orig_tokens) - 1)] + next_index_rels 154 | shift_idx += len(sub_orig_tokens) - 1 155 | ver_1_data[id] = ver_1_item 156 | 157 | 158 | bert_data = {} 159 | for id in ver_1_data: 160 | bert_item = {} 161 | bert_item["subword_ids"] = [] 162 | bert_item["mask"] = [] 163 | bert_item["token_starts"] = [] 164 | 165 | bert_item["label_types"] = ver_1_data[id]["label_types"] 166 | bert_item["index_rels"] = ver_1_data[id]["index_rels"] 167 | bert_item["label_rels"] = ver_1_data[id]["label_rels"] 168 | 169 | 170 | shift_idx = 0 171 | 
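# Illustrative (commented) walk-through of the shift-and-pad scheme applied
# above for "_"-joined tokens and repeated below for BERT subwords, with
# hypothetical values:
#
#   tokens:      ["New_York", "is"]  ->  ["New", "York", "is"]
#   label_types: ["B-ETT", "O"]      ->  ["B-ETT", "I-ETT", "O"]
#   label_rels:  [["SP"], ["N"]]     ->  [["N"], ["SP"], ["N"]]
#   index_rels:  [[1], [1]]          ->  [[0], [2], [2]]
#
# Every stored relation index at or past the split point is shifted by the
# number of extra pieces; filler entries get label "N" and an index pointing
# at their own position.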
172 | bert_tokens = [] 173 | bert_token_starts = [] 174 | for idx, ver_1_token in enumerate(ver_1_data[id]["tokens"]): 175 | sub_ver_1_tokens = self.tokenizer.tokenize(ver_1_token) 176 | bert_token_starts.append(1 + len(sub_ver_1_tokens)) 177 | ## tokens 178 | bert_tokens += sub_ver_1_tokens 179 | if len(sub_ver_1_tokens) > 1: 180 | ## label_types 181 | bert_item["label_types"] = bert_item["label_types"][:idx + shift_idx + 1] + ["X"] * (len(sub_ver_1_tokens) - 1) + bert_item["label_types"][idx+shift_idx + 1:] 182 | 183 | ## label_rels 184 | bert_item["label_rels"] = bert_item["label_rels"][:idx + shift_idx] + [["N"]] * (len(sub_ver_1_tokens) - 1) + bert_item["label_rels"][idx + shift_idx:] 185 | ## index_rels 186 | prev_index_rels = [] # bert_item["label_rels"][:idx + shift_idx] 187 | next_index_rels = [] # bert_item["label_rels"][idx+shift_idx:] 188 | for index_rels in bert_item["index_rels"][:idx + shift_idx]: 189 | prev_index_rel = [] 190 | for index_rel in index_rels: 191 | if index_rel >= idx + shift_idx: 192 | prev_index_rel.append(index_rel + len(sub_ver_1_tokens) - 1) 193 | else: 194 | prev_index_rel.append(index_rel) 195 | prev_index_rels.append(prev_index_rel) 196 | 197 | for index_rels in bert_item["index_rels"][idx+shift_idx:]: 198 | next_index_rel = [] 199 | for index_rel in index_rels: 200 | if index_rel >= idx + shift_idx: 201 | next_index_rel.append(index_rel + len(sub_ver_1_tokens) - 1) 202 | else: 203 | next_index_rel.append(index_rel) 204 | 205 | next_index_rels.append(next_index_rel) 206 | 207 | bert_item["index_rels"] = prev_index_rels + [[(idx + shift_idx + i)] for i in range(len(sub_ver_1_tokens) - 1)] + next_index_rels 208 | shift_idx += len(sub_ver_1_tokens) - 1 209 | 210 | ###################################################### 211 | bert_tokens = [self.CLS] + bert_tokens + [self.SEP] 212 | bert_item["subword_ids"], bert_item["mask"] = self.convert_tokens_to_ids(bert_tokens) 213 | 214 | bert_item["label_types"] = [self.CLS] + bert_item["label_types"] + [self.SEP] 215 | 216 | bert_item["label_rels"] = [["N"]] + bert_item["label_rels"] + [["N"]] 217 | 218 | bert_item_index_rels_ = [] 219 | for i in bert_item["index_rels"]: 220 | row = [] 221 | for j in i: 222 | row.append(j+1) 223 | bert_item_index_rels_.append(row) 224 | bert_item["index_rels"] = [[0]] + bert_item_index_rels_ + [[len(bert_item_index_rels_) + 1]] 225 | 226 | bert_item["token_starts"] = torch.zeros(1, self.max_len).to(bert_item["subword_ids"]) 227 | bert_item["token_starts"][0, bert_token_starts] = 1  # mark the collected start positions (the original indexed the tensor with itself, which only ever set position 0) 228 | 229 | bert_item_label_types = [] 230 | for label_type in bert_item["label_types"]: 231 | bert_item_label_types.append(self.ner_list.index(label_type)) 232 | 233 | bert_item["label_types"] = self.padding_labels(bert_item_label_types, padding) 234 | 235 | bert_item["rel_matrix"] = self.convert_rel_ids(bert_item["index_rels"], bert_item["label_rels"]) 236 | 237 | bert_item["token_type_ids"] = torch.zeros(bert_item["mask"].size(0), bert_item["mask"].size(1)) 238 | bert_item["token_type_ids"][:, :len(bert_tokens)] = 1 239 | bert_item["token_type_ids"] = bert_item["token_type_ids"].long() 240 | ################################################################# 241 | bert_data[id] = bert_item 242 | 243 | 244 | return bert_data 245 | 246 | def convert_rel_ids(self, index_rels, label_rels): 247 | num_rels = len(self.rel_list) 248 | rel_matrix = torch.zeros(1, self.max_len, self.rel_max_len * num_rels) 249 | 250 | for i, (idxs, labels) in enumerate(zip(index_rels, label_rels)): 251 | for idx, label in
zip(idxs, labels): 252 | rel_matrix[0, i, idx * num_rels + self.rel_list.index(label)] = 1 253 | return rel_matrix 254 | 255 | def subword_tokenize_to_ids(self, tokens): 256 | subwords, token_start_idxs = self.subword_tokenize(tokens) 257 | subword_ids, mask = self.convert_tokens_to_ids(subwords) 258 | token_starts = torch.zeros(1, self.max_len).to(subword_ids) 259 | token_starts[0, token_start_idxs] = 1 260 | return subword_ids, mask, token_starts 261 | 262 | def segment_ids(self, segment1_len, segment2_len): 263 | ids = [0] * segment1_len + [1] * segment2_len 264 | return torch.tensor([ids]).to(device=self.device) 265 | 266 | def get_featurized_sentence(self, sentence, padding=True): 267 | tokens = self.tokenize(sentence) 268 | if padding: 269 | input_ids, mask = self.convert_tokens_to_ids (tokens, padding) 270 | else: 271 | input_ids = self.convert_tokens_to_ids (tokens, padding) 272 | 273 | input_types = torch.zeros(input_ids.size(0), input_ids.size(1)) 274 | input_types[:, :len(tokens)] = 1 275 | 276 | empty_ners = torch.zeros(input_ids.size(0), input_ids.size(1)) 277 | empty_rels = torch.zeros(input_ids.size(0), input_ids.size(1), len(self.rel_list)) 278 | return input_ids, mask, input_types, empty_ners, empty_rels 279 | 280 | def get_featurized_sentences(self, file_name, padding=True): 281 | original_data = self.parseData(file_name) 282 | bertData = self.convertData(original_data, padding) 283 | 284 | featurized_sentences = [] 285 | for id in bertData: 286 | features = {} 287 | features["bert_ids"], features["bert_mask"], features["bert_token_starts"], features["ett_tags"], features["rel_matrix"], features["token_type_ids"] = \ 288 | bertData[id]["subword_ids"], bertData[id]["mask"], bertData[id]["token_starts"], bertData[id]["label_types"], bertData[id]["rel_matrix"], bertData[id]["token_type_ids"] 289 | featurized_sentences.append(features) 290 | 291 | self.num_samples = len(featurized_sentences) 292 | return featurized_sentences 293 | 294 | def get_generator(self, file_name, batch_size=32, padding=True, is_shuffle=True): 295 | batch_input_ids = [] 296 | batch_mask_ids = [] 297 | batch_ner_ids = [] 298 | batch_input_type_ids = [] 299 | batch_rel_matrix = [] 300 | featurized_sentences = self.get_featurized_sentences(file_name, padding) 301 | if is_shuffle: 302 | shuffle(featurized_sentences) 303 | for idx, features in enumerate(featurized_sentences): 304 | batch_input_ids.append(features["bert_ids"]) 305 | batch_input_type_ids.append(features["token_type_ids"]) 306 | batch_mask_ids.append(features["bert_mask"]) 307 | batch_ner_ids.append(features["ett_tags"]) 308 | batch_rel_matrix.append(features["rel_matrix"]) 309 | if idx % batch_size == batch_size-1: 310 | return_batch_input_ids = torch.cat(batch_input_ids, dim=0) 311 | return_batch_mask_ids = torch.cat(batch_mask_ids, dim=0) 312 | return_batch_input_type_ids = torch.cat(batch_input_type_ids, dim=0) 313 | return_batch_label_ids = torch.cat(batch_ner_ids, dim=0) 314 | return_batch_rel_matrix = torch.cat(batch_rel_matrix, dim=0) 315 | 316 | batch_input_ids = [] 317 | batch_mask_ids = [] 318 | batch_ner_ids = [] 319 | batch_input_type_ids = [] 320 | batch_rel_matrix = [] 321 | 322 | yield return_batch_input_ids, return_batch_mask_ids, return_batch_input_type_ids, return_batch_label_ids, return_batch_rel_matrix 323 | 324 | if len(batch_input_ids) != 0: 325 | yield torch.cat(batch_input_ids, dim=0), torch.cat(batch_mask_ids, dim=0), torch.cat(batch_input_type_ids, dim=0), torch.cat(batch_ner_ids, dim=0), torch.cat(batch_rel_matrix, 
dim=0) 326 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/modeling_transfo_xl_utilities.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Utilities for PyTorch Transformer XL model. 17 | Directly adapted from https://github.com/kimiyoung/transformer-xl. 18 | """ 19 | 20 | from collections import defaultdict 21 | 22 | import numpy as np 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | # CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) 29 | # CUDA_MINOR = int(torch.version.cuda.split('.')[1]) 30 | 31 | class ProjectedAdaptiveLogSoftmax(nn.Module): 32 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 33 | keep_order=False): 34 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 35 | 36 | self.n_token = n_token 37 | self.d_embed = d_embed 38 | self.d_proj = d_proj 39 | 40 | self.cutoffs = cutoffs + [n_token] 41 | self.cutoff_ends = [0] + self.cutoffs 42 | self.div_val = div_val 43 | 44 | self.shortlist_size = self.cutoffs[0] 45 | self.n_clusters = len(self.cutoffs) - 1 46 | self.head_size = self.shortlist_size + self.n_clusters 47 | 48 | if self.n_clusters > 0: 49 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 50 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 51 | 52 | self.out_layers = nn.ModuleList() 53 | self.out_projs = nn.ParameterList() 54 | 55 | if div_val == 1: 56 | for i in range(len(self.cutoffs)): 57 | if d_proj != d_embed: 58 | self.out_projs.append( 59 | nn.Parameter(torch.Tensor(d_proj, d_embed)) 60 | ) 61 | else: 62 | self.out_projs.append(None) 63 | 64 | self.out_layers.append(nn.Linear(d_embed, n_token)) 65 | else: 66 | for i in range(len(self.cutoffs)): 67 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 68 | d_emb_i = d_embed // (div_val ** i) 69 | 70 | self.out_projs.append( 71 | nn.Parameter(torch.Tensor(d_proj, d_emb_i)) 72 | ) 73 | 74 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 75 | 76 | self.keep_order = keep_order 77 | 78 | def _compute_logit(self, hidden, weight, bias, proj): 79 | if proj is None: 80 | logit = F.linear(hidden, weight, bias=bias) 81 | else: 82 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 83 | proj_hid = F.linear(hidden, proj.t().contiguous()) 84 | logit = F.linear(proj_hid, weight, bias=bias) 85 | # else: 86 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 87 | # if bias is not None: 88 | # logit = logit + bias 89 | 90 | return logit 91 | 92 | def forward(self, hidden, target=None, keep_order=False): 93 | ''' 94 | Params: 95 | hidden :: [len*bsz x d_proj] 96 | target :: [len*bsz] 97 | Return: 98 | if target 
is None: 99 | out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary 100 | else: 101 | out :: [len*bsz] Negative log likelihood 102 | We could replace this implementation by the native PyTorch one 103 | if theirs had an option to set bias on all clusters in the native one. 104 | here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138 105 | ''' 106 | 107 | if target is not None: 108 | target = target.view(-1) 109 | if hidden.size(0) != target.size(0): 110 | raise RuntimeError('Input and target should have the same size ' 111 | 'in the batch dimension.') 112 | 113 | if self.n_clusters == 0: 114 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 115 | self.out_layers[0].bias, self.out_projs[0]) 116 | if target is not None: 117 | output = -F.log_softmax(logit, dim=-1) \ 118 | .gather(1, target.unsqueeze(1)).squeeze(1) 119 | else: 120 | output = F.log_softmax(logit, dim=-1) 121 | else: 122 | # construct weights and biases 123 | weights, biases = [], [] 124 | for i in range(len(self.cutoffs)): 125 | if self.div_val == 1: 126 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 127 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 128 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 129 | else: 130 | weight_i = self.out_layers[i].weight 131 | bias_i = self.out_layers[i].bias 132 | 133 | if i == 0: 134 | weight_i = torch.cat( 135 | [weight_i, self.cluster_weight], dim=0) 136 | bias_i = torch.cat( 137 | [bias_i, self.cluster_bias], dim=0) 138 | 139 | weights.append(weight_i) 140 | biases.append(bias_i) 141 | 142 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 143 | 144 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 145 | head_logprob = F.log_softmax(head_logit, dim=1) 146 | 147 | if target is None: 148 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 149 | else: 150 | out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device) 151 | 152 | offset = 0 153 | cutoff_values = [0] + self.cutoffs 154 | for i in range(len(cutoff_values) - 1): 155 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 156 | 157 | if target is not None: 158 | mask_i = (target >= l_idx) & (target < r_idx) 159 | indices_i = mask_i.nonzero().squeeze() 160 | 161 | if indices_i.numel() == 0: 162 | continue 163 | 164 | target_i = target.index_select(0, indices_i) - l_idx 165 | head_logprob_i = head_logprob.index_select(0, indices_i) 166 | hidden_i = hidden.index_select(0, indices_i) 167 | else: 168 | hidden_i = hidden 169 | 170 | if i == 0: 171 | if target is not None: 172 | logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) 173 | else: 174 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 175 | else: 176 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 177 | 178 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 179 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 180 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster 181 | if target is not None: 182 | logprob_i = head_logprob_i[:, cluster_prob_idx] \ 183 | + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1) 184 | else: 185 | logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i 186 | out[:, l_idx:r_idx] = logprob_i 187 | 188 | if target is not None: 189 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 190 | out.index_copy_(0, indices_i,
-logprob_i) 191 | else: 192 | out[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 193 | offset += logprob_i.size(0) 194 | 195 | return out 196 | 197 | 198 | def log_prob(self, hidden): 199 | r""" Computes log probabilities for all :math:`n\_classes` 200 | From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py 201 | Args: 202 | hidden (Tensor): a minibatch of examples 203 | Returns: 204 | log-probabilities for each class :math:`c` 205 | in range :math:`0 <= c <= n\_classes`, where :math:`n\_classes` is a 206 | parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor. 207 | Shape: 208 | - Input: :math:`(N, in\_features)` 209 | - Output: :math:`(N, n\_classes)` 210 | """ 211 | if self.n_clusters == 0: 212 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 213 | self.out_layers[0].bias, self.out_projs[0]) 214 | return F.log_softmax(logit, dim=-1) 215 | else: 216 | # construct weights and biases 217 | weights, biases = [], [] 218 | for i in range(len(self.cutoffs)): 219 | if self.div_val == 1: 220 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 221 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 222 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 223 | else: 224 | weight_i = self.out_layers[i].weight 225 | bias_i = self.out_layers[i].bias 226 | 227 | if i == 0: 228 | weight_i = torch.cat( 229 | [weight_i, self.cluster_weight], dim=0) 230 | bias_i = torch.cat( 231 | [bias_i, self.cluster_bias], dim=0) 232 | 233 | weights.append(weight_i) 234 | biases.append(bias_i) 235 | 236 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 237 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 238 | 239 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 240 | head_logprob = F.log_softmax(head_logit, dim=1) 241 | 242 | cutoff_values = [0] + self.cutoffs 243 | for i in range(len(cutoff_values) - 1): 244 | start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1] 245 | 246 | if i == 0: 247 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 248 | else: 249 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 250 | 251 | tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i) 252 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 253 | 254 | logprob_i = head_logprob[:, self.cutoffs[0] + i - 1, None] + tail_logprob_i  # same cluster column as in forward(), broadcast over the tail vocabulary 255 | out[:, start_idx:stop_idx] = logprob_i  # fill the vocabulary slice (was out[:, start_idx, stop_idx], an indexing bug) 256 | 257 | return out 258 | 259 | 260 | class LogUniformSampler(object): 261 | def __init__(self, range_max, n_sample): 262 | """ 263 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 264 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 265 | 266 | expected count can be approximated by 1 - (1 - p)^n 267 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 268 | 269 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 270 | """ 271 | with torch.no_grad(): 272 | self.range_max = range_max 273 | log_indices = torch.arange(1., range_max+2., 1.).log_() 274 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 275 | # print('P', self.dist.numpy().tolist()[-30:]) 276 | 277 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 278 | 279 | self.n_sample = n_sample 280 | 281 | def sample(self, labels): 282 | """ 283 | labels: [b1, b2] 284 | Return 285 | true_log_probs: [b1, b2] 286 | samp_log_probs: [n_sample]
287 | neg_samples: [n_sample] 288 | """ 289 | 290 | # neg_samples = torch.empty(0).long() 291 | n_sample = self.n_sample 292 | n_tries = 2 * n_sample 293 | 294 | with torch.no_grad(): 295 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 296 | device = labels.device 297 | neg_samples = neg_samples.to(device) 298 | true_log_probs = self.log_q[labels].to(device) 299 | samp_log_probs = self.log_q[neg_samples].to(device) 300 | return true_log_probs, samp_log_probs, neg_samples 301 | 302 | def sample_logits(embedding, bias, labels, inputs, sampler): 303 | """ 304 | embedding: an nn.Embedding layer 305 | bias: [n_vocab] 306 | labels: [b1, b2] 307 | inputs: [b1, b2, n_emb] 308 | sampler: you may use a LogUniformSampler 309 | Return 310 | logits: [b1, b2, 1 + n_sample] 311 | """ 312 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 313 | n_sample = neg_samples.size(0) 314 | b1, b2 = labels.size(0), labels.size(1) 315 | all_ids = torch.cat([labels.view(-1), neg_samples]) 316 | all_w = embedding(all_ids) 317 | true_w = all_w[: -n_sample].view(b1, b2, -1) 318 | sample_w = all_w[- n_sample:].view(n_sample, -1) 319 | 320 | all_b = bias[all_ids] 321 | true_b = all_b[: -n_sample].view(b1, b2) 322 | sample_b = all_b[- n_sample:] 323 | 324 | hit = (labels[:, :, None] == neg_samples).detach() 325 | 326 | true_logits = torch.einsum('ijk,ijk->ij', 327 | [true_w, inputs]) + true_b - true_log_probs 328 | sample_logits = torch.einsum('lk,ijk->ijl', 329 | [sample_w, inputs]) + sample_b - samp_log_probs 330 | sample_logits.masked_fill_(hit, -1e30) 331 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 332 | 333 | return logits 334 | 335 | 336 | # class LogUniformSampler(object): 337 | # def __init__(self, range_max, unique=False): 338 | # """ 339 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 340 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 341 | # """ 342 | # self.range_max = range_max 343 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 344 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 345 | 346 | # self.unique = unique 347 | 348 | # if self.unique: 349 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 350 | 351 | # def sample(self, n_sample, labels): 352 | # pos_sample, new_labels = labels.unique(return_inverse=True) 353 | # n_pos_sample = pos_sample.size(0) 354 | # n_neg_sample = n_sample - n_pos_sample 355 | 356 | # if self.unique: 357 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 358 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 359 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 360 | # else: 361 | # sample_dist = self.dist 362 | 363 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 364 | 365 | # sample = torch.cat([pos_sample, neg_sample]) 366 | # sample_prob = self.dist[sample] 367 | 368 | # return new_labels, sample, sample_prob 369 | 370 | 371 | if __name__ == '__main__': 372 | S, B = 3, 4 373 | n_vocab = 10000 374 | n_sample = 5 375 | H = 32 376 | 377 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 378 | 379 | # sampler = LogUniformSampler(n_vocab, unique=False) 380 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 381 | 382 | sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True) 383 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 384 | 385 | # print('true_probs', 
true_probs.numpy().tolist()) 386 | # print('samp_probs', samp_probs.numpy().tolist()) 387 | # print('neg_samples', neg_samples.numpy().tolist()) 388 | 389 | # print('sum', torch.sum(sampler.dist).item()) 390 | 391 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 392 | 393 | embedding = nn.Embedding(n_vocab, H) 394 | bias = torch.zeros(n_vocab) 395 | inputs = torch.Tensor(S, B, H).normal_() 396 | 397 | logits = sample_logits(embedding, bias, labels, inputs, sampler)  # sample_logits takes five arguments and returns only the logits 398 | print('logits', logits.detach().numpy().tolist()) 399 | print('logits shape', logits.size()) 400 | # sample_logits does not return sampled labels, so there are no out_labels to print 401 | 402 | 403 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization_transfo_xl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Tokenization classes for Transformer XL model. 17 | Adapted from https://github.com/kimiyoung/transformer-xl. 18 | """ 19 | from __future__ import (absolute_import, division, print_function, 20 | unicode_literals) 21 | 22 | import glob 23 | import logging 24 | import os 25 | import sys 26 | from collections import Counter, OrderedDict 27 | from io import open 28 | import unicodedata 29 | 30 | import torch 31 | import numpy as np 32 | 33 | from .file_utils import cached_path 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | 41 | logger = logging.getLogger(__name__) 42 | 43 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 44 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin", 45 | } 46 | VOCAB_NAME = 'vocab.bin' 47 | 48 | PRETRAINED_CORPUS_ARCHIVE_MAP = { 49 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin", 50 | } 51 | CORPUS_NAME = 'corpus.bin' 52 | 53 | class TransfoXLTokenizer(object): 54 | """ 55 | Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl 56 | """ 57 | @classmethod 58 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 59 | """ 60 | Instantiate a TransfoXLTokenizer from a pre-trained vocabulary file. 61 | Download and cache the vocabulary file if needed.
62 | """ 63 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 64 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 65 | else: 66 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 67 | # redirect to the cache, if necessary 68 | try: 69 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 70 | except EnvironmentError: 71 | logger.error( 72 | "Model name '{}' was not found in model name list ({}). " 73 | "We assumed '{}' was a path or url but couldn't find files {} " 74 | "at this path or url.".format( 75 | pretrained_model_name_or_path, 76 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 77 | pretrained_model_name_or_path, 78 | vocab_file)) 79 | return None 80 | if resolved_vocab_file == vocab_file: 81 | logger.info("loading vocabulary file {}".format(vocab_file)) 82 | else: 83 | logger.info("loading vocabulary file {} from cache at {}".format( 84 | vocab_file, resolved_vocab_file)) 85 | 86 | # Instantiate tokenizer. 87 | tokenizer = cls(*inputs, **kwargs) 88 | vocab_dict = torch.load(resolved_vocab_file) 89 | for key, value in vocab_dict.items(): 90 | tokenizer.__dict__[key] = value 91 | return tokenizer 92 | 93 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False, 94 | delimiter=None, vocab_file=None, never_split=("", "", "")): 95 | self.counter = Counter() 96 | self.special = special 97 | self.min_freq = min_freq 98 | self.max_size = max_size 99 | self.lower_case = lower_case 100 | self.delimiter = delimiter 101 | self.vocab_file = vocab_file 102 | self.never_split = never_split 103 | 104 | def count_file(self, path, verbose=False, add_eos=False): 105 | if verbose: print('counting file {} ...'.format(path)) 106 | assert os.path.exists(path) 107 | 108 | sents = [] 109 | with open(path, 'r', encoding='utf-8') as f: 110 | for idx, line in enumerate(f): 111 | if verbose and idx > 0 and idx % 500000 == 0: 112 | print(' line {}'.format(idx)) 113 | symbols = self.tokenize(line, add_eos=add_eos) 114 | self.counter.update(symbols) 115 | sents.append(symbols) 116 | 117 | return sents 118 | 119 | def count_sents(self, sents, verbose=False): 120 | """ 121 | sents : a list of sentences, each a list of tokenized symbols 122 | """ 123 | if verbose: print('counting {} sents ...'.format(len(sents))) 124 | for idx, symbols in enumerate(sents): 125 | if verbose and idx > 0 and idx % 500000 == 0: 126 | print(' line {}'.format(idx)) 127 | self.counter.update(symbols) 128 | 129 | def _build_from_file(self, vocab_file): 130 | self.idx2sym = [] 131 | self.sym2idx = OrderedDict() 132 | 133 | with open(vocab_file, 'r', encoding='utf-8') as f: 134 | for line in f: 135 | symb = line.strip().split()[0] 136 | self.add_symbol(symb) 137 | if '' in self.sym2idx: 138 | self.unk_idx = self.sym2idx[''] 139 | elif '' in self.sym2idx: 140 | self.unk_idx = self.sym2idx[''] 141 | else: 142 | raise ValueError('No token in vocabulary') 143 | 144 | def build_vocab(self): 145 | if self.vocab_file: 146 | print('building vocab from {}'.format(self.vocab_file)) 147 | self._build_from_file(self.vocab_file) 148 | print('final vocab size {}'.format(len(self))) 149 | else: 150 | print('building vocab with min_freq={}, max_size={}'.format( 151 | self.min_freq, self.max_size)) 152 | self.idx2sym = [] 153 | self.sym2idx = OrderedDict() 154 | 155 | for sym in self.special: 156 | self.add_special(sym) 157 | 158 | for sym, cnt in self.counter.most_common(self.max_size): 159 | if cnt < self.min_freq: break 160 | self.add_symbol(sym) 161 | 162 
| print('final vocab size {} from {} unique tokens'.format( 163 | len(self), len(self.counter))) 164 | 165 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True, 166 | add_double_eos=False): 167 | if verbose: print('encoding file {} ...'.format(path)) 168 | assert os.path.exists(path) 169 | encoded = [] 170 | with open(path, 'r', encoding='utf-8') as f: 171 | for idx, line in enumerate(f): 172 | if verbose and idx > 0 and idx % 500000 == 0: 173 | print(' line {}'.format(idx)) 174 | symbols = self.tokenize(line, add_eos=add_eos, 175 | add_double_eos=add_double_eos) 176 | encoded.append(self.convert_to_tensor(symbols)) 177 | 178 | if ordered: 179 | encoded = torch.cat(encoded) 180 | 181 | return encoded 182 | 183 | def encode_sents(self, sents, ordered=False, verbose=False): 184 | if verbose: print('encoding {} sents ...'.format(len(sents))) 185 | encoded = [] 186 | for idx, symbols in enumerate(sents): 187 | if verbose and idx > 0 and idx % 500000 == 0: 188 | print(' line {}'.format(idx)) 189 | encoded.append(self.convert_to_tensor(symbols)) 190 | 191 | if ordered: 192 | encoded = torch.cat(encoded) 193 | 194 | return encoded 195 | 196 | def add_special(self, sym): 197 | if sym not in self.sym2idx: 198 | self.idx2sym.append(sym) 199 | self.sym2idx[sym] = len(self.idx2sym) - 1 200 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) 201 | 202 | def add_symbol(self, sym): 203 | if sym not in self.sym2idx: 204 | self.idx2sym.append(sym) 205 | self.sym2idx[sym] = len(self.idx2sym) - 1 206 | 207 | def get_sym(self, idx): 208 | assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx) 209 | return self.idx2sym[idx] 210 | 211 | def get_idx(self, sym): 212 | if sym in self.sym2idx: 213 | return self.sym2idx[sym] 214 | else: 215 | # print('encounter unk {}'.format(sym)) 216 | # assert '<eos>' not in sym 217 | if hasattr(self, 'unk_idx'): 218 | return self.sym2idx.get(sym, self.unk_idx) 219 | # Backward compatibility with pre-trained models 220 | elif '<unk>' in self.sym2idx: 221 | return self.sym2idx['<unk>'] 222 | elif '<UNK>' in self.sym2idx: 223 | return self.sym2idx['<UNK>'] 224 | else: 225 | raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement') 226 | 227 | def convert_ids_to_tokens(self, indices): 228 | """Converts a sequence of indices into symbols using the vocab.""" 229 | return [self.get_sym(idx) for idx in indices] 230 | 231 | def convert_tokens_to_ids(self, symbols): 232 | """Converts a sequence of symbols into ids using the vocab.""" 233 | return [self.get_idx(sym) for sym in symbols] 234 | 235 | def convert_to_tensor(self, symbols): 236 | return torch.LongTensor(self.convert_tokens_to_ids(symbols)) 237 | 238 | def decode(self, indices, exclude=None): 239 | """Converts a sequence of indices into a string.""" 240 | if exclude is None: 241 | return ' '.join([self.get_sym(idx) for idx in indices]) 242 | else: 243 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) 244 | 245 | def __len__(self): 246 | return len(self.idx2sym) 247 | 248 | def _run_split_on_punc(self, text): 249 | """Splits punctuation on a piece of text.""" 250 | if text in self.never_split: 251 | return [text] 252 | chars = list(text) 253 | i = 0 254 | start_new_word = True 255 | output = [] 256 | while i < len(chars): 257 | char = chars[i] 258 | if _is_punctuation(char): 259 | output.append([char]) 260 | start_new_word = True 261 | else: 262 | if start_new_word: 263 | output.append([]) 264 | start_new_word = False 265 |
output[-1].append(char) 266 | i += 1 267 | 268 | return ["".join(x) for x in output] 269 | 270 | def _run_strip_accents(self, text): 271 | """Strips accents from a piece of text.""" 272 | text = unicodedata.normalize("NFD", text) 273 | output = [] 274 | for char in text: 275 | cat = unicodedata.category(char) 276 | if cat == "Mn": 277 | continue 278 | output.append(char) 279 | return "".join(output) 280 | 281 | def _clean_text(self, text): 282 | """Performs invalid character removal and whitespace cleanup on text.""" 283 | output = [] 284 | for char in text: 285 | cp = ord(char) 286 | if cp == 0 or cp == 0xfffd or _is_control(char): 287 | continue 288 | if _is_whitespace(char): 289 | output.append(" ") 290 | else: 291 | output.append(char) 292 | return "".join(output) 293 | 294 | def whitespace_tokenize(self, text): 295 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 296 | text = text.strip() 297 | if not text: 298 | return [] 299 | if self.delimiter == '': 300 | tokens = text 301 | else: 302 | tokens = text.split(self.delimiter) 303 | return tokens 304 | 305 | def tokenize(self, line, add_eos=False, add_double_eos=False): 306 | line = self._clean_text(line) 307 | line = line.strip() 308 | 309 | symbols = self.whitespace_tokenize(line) 310 | 311 | split_symbols = [] 312 | for symbol in symbols: 313 | if self.lower_case and symbol not in self.never_split: 314 | symbol = symbol.lower() 315 | symbol = self._run_strip_accents(symbol) 316 | split_symbols.extend(self._run_split_on_punc(symbol)) 317 | 318 | if add_double_eos: # lm1b 319 | return ['<S>'] + split_symbols + ['<S>'] 320 | elif add_eos: 321 | return split_symbols + ['<eos>'] 322 | else: 323 | return split_symbols 324 | 325 | 326 | class LMOrderedIterator(object): 327 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None): 328 | """ 329 | data -- LongTensor -- the LongTensor is strictly ordered 330 | """ 331 | self.bsz = bsz 332 | self.bptt = bptt 333 | self.ext_len = ext_len if ext_len is not None else 0 334 | 335 | self.device = device 336 | 337 | # Work out how cleanly we can divide the dataset into bsz parts. 338 | self.n_step = data.size(0) // bsz 339 | 340 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 341 | data = data.narrow(0, 0, self.n_step * bsz) 342 | 343 | # Evenly divide the data across the bsz batches. 344 | self.data = data.view(bsz, -1).t().contiguous().to(device) 345 | 346 | # Number of mini-batches 347 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt 348 | 349 | def get_batch(self, i, bptt=None): 350 | if bptt is None: bptt = self.bptt 351 | seq_len = min(bptt, self.data.size(0) - 1 - i) 352 | 353 | end_idx = i + seq_len 354 | beg_idx = max(0, i - self.ext_len) 355 | 356 | data = self.data[beg_idx:end_idx] 357 | target = self.data[i+1:i+1+seq_len] 358 | 359 | data_out = data.transpose(0, 1).contiguous().to(self.device) 360 | target_out = target.transpose(0, 1).contiguous().to(self.device) 361 | 362 | return data_out, target_out, seq_len 363 | 364 | def get_fixlen_iter(self, start=0): 365 | for i in range(start, self.data.size(0) - 1, self.bptt): 366 | yield self.get_batch(i) 367 | 368 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 369 | max_len = self.bptt + max_deviation * std 370 | i = start 371 | while True: 372 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
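# Illustrative: with bptt=70, std=5, min_len=5, max_deviation=3, roughly 95%
# of steps draw a length around a mean of 70 and the rest around 35, and the
# sampled value is clamped to [5, 70 + 3*5] = [5, 85], mirroring the
# variable-length schedule of the original transformer-xl training code.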
373 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 374 | data, target, seq_len = self.get_batch(i, bptt) 375 | i += seq_len 376 | yield data, target, seq_len 377 | if i >= self.data.size(0) - 2: 378 | break 379 | 380 | def __iter__(self): 381 | return self.get_fixlen_iter() 382 | 383 | 384 | class LMShuffledIterator(object): 385 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False): 386 | """ 387 | data -- list[LongTensor] -- there is no order among the LongTensors 388 | """ 389 | self.data = data 390 | 391 | self.bsz = bsz 392 | self.bptt = bptt 393 | self.ext_len = ext_len if ext_len is not None else 0 394 | 395 | self.device = device 396 | self.shuffle = shuffle 397 | 398 | def get_sent_stream(self): 399 | # index iterator 400 | epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \ 401 | else np.array(range(len(self.data))) 402 | 403 | # sentence iterator 404 | for idx in epoch_indices: 405 | yield self.data[idx] 406 | 407 | def stream_iterator(self, sent_stream): 408 | # streams for each data in the batch 409 | streams = [None] * self.bsz 410 | 411 | data = torch.LongTensor(self.bptt, self.bsz) 412 | target = torch.LongTensor(self.bptt, self.bsz) 413 | 414 | n_retain = 0 415 | 416 | while True: 417 | # data : [n_retain+bptt x bsz] 418 | # target : [bptt x bsz] 419 | data[n_retain:].fill_(-1) 420 | target.fill_(-1) 421 | 422 | valid_batch = True 423 | 424 | for i in range(self.bsz): 425 | n_filled = 0 426 | try: 427 | while n_filled < self.bptt: 428 | if streams[i] is None or len(streams[i]) <= 1: 429 | streams[i] = next(sent_stream) 430 | # number of new tokens to fill in 431 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled) 432 | # first n_retain tokens are retained from last batch 433 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \ 434 | streams[i][:n_new] 435 | target[n_filled:n_filled+n_new, i] = \ 436 | streams[i][1:n_new+1] 437 | streams[i] = streams[i][n_new:] 438 | n_filled += n_new 439 | except StopIteration: 440 | valid_batch = False 441 | break 442 | 443 | if not valid_batch: 444 | return 445 | 446 | data_out = data.transpose(0, 1).contiguous().to(self.device) 447 | target_out = target.transpose(0, 1).contiguous().to(self.device) 448 | 449 | yield data_out, target_out, self.bptt 450 | 451 | n_retain = min(data.size(0), self.ext_len) 452 | if n_retain > 0: 453 | data[:n_retain] = data[-n_retain:] 454 | data.resize_(n_retain + self.bptt, data.size(1)) 455 | 456 | def __iter__(self): 457 | # sent_stream is an iterator 458 | sent_stream = self.get_sent_stream() 459 | 460 | for batch in self.stream_iterator(sent_stream): 461 | yield batch 462 | 463 | 464 | class LMMultiFileIterator(LMShuffledIterator): 465 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None, 466 | shuffle=False): 467 | 468 | self.paths = paths 469 | self.vocab = vocab 470 | 471 | self.bsz = bsz 472 | self.bptt = bptt 473 | self.ext_len = ext_len if ext_len is not None else 0 474 | 475 | self.device = device 476 | self.shuffle = shuffle 477 | 478 | def get_sent_stream(self, path): 479 | sents = self.vocab.encode_file(path, add_double_eos=True) 480 | if self.shuffle: 481 | np.random.shuffle(sents) 482 | sent_stream = iter(sents) 483 | 484 | return sent_stream 485 | 486 | def __iter__(self): 487 | if self.shuffle: 488 | np.random.shuffle(self.paths) 489 | 490 | for path in self.paths: 491 | # sent_stream is an iterator 492 | sent_stream = self.get_sent_stream(path) 493 | for batch in 
class LMMultiFileIterator(LMShuffledIterator):
    def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
                 shuffle=False):

        self.paths = paths
        self.vocab = vocab

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self, path):
        sents = self.vocab.encode_file(path, add_double_eos=True)
        if self.shuffle:
            np.random.shuffle(sents)
        sent_stream = iter(sents)

        return sent_stream

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.paths)

        for path in self.paths:
            # sent_stream is an iterator
            sent_stream = self.get_sent_stream(path)
            for batch in self.stream_iterator(sent_stream):
                yield batch


class TransfoXLCorpus(object):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a pre-processed corpus.
        """
        vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
            corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Corpus '{}' was not found in corpus list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    corpus_file))
            return None
        if resolved_corpus_file == corpus_file:
            logger.info("loading corpus file {}".format(corpus_file))
        else:
            logger.info("loading corpus file {} from cache at {}".format(
                corpus_file, resolved_corpus_file))

        # Instantiate the corpus and restore its cached state.
        corpus = cls(*inputs, **kwargs)
        corpus_dict = torch.load(resolved_corpus_file)
        for key, value in corpus_dict.items():
            corpus.__dict__[key] = value
        corpus.vocab = vocab
        if corpus.train is not None:
            corpus.train = torch.tensor(corpus.train, dtype=torch.long)
        if corpus.valid is not None:
            corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
        if corpus.test is not None:
            corpus.test = torch.tensor(corpus.test, dtype=torch.long)
        return corpus

    def __init__(self, *args, **kwargs):
        self.vocab = TransfoXLTokenizer(*args, **kwargs)
        self.dataset = None
        self.train = None
        self.valid = None
        self.test = None
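
    # Per-dataset preprocessing: word-level corpora (ptb/wt2/wt103) are
    # encoded as one ordered stream with an <eos> appended to every line,
    # character-level corpora (enwik8/text8) omit <eos>, and lm1b keeps a
    # list of training shard paths plus <S>-delimited valid/test sentences.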
    def build_corpus(self, path, dataset):
        self.dataset = dataset

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        if split == 'train':
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == 'lm1b':
                kwargs['shuffle'] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == 'lm1b':
                data_iter = LMShuffledIterator(data, *args, **kwargs)

        return data_iter


def get_lm_corpus(datadir, dataset):
    fn = os.path.join(datadir, 'cache.pt')
    fn_pickle = os.path.join(datadir, 'cache.pkl')
    if os.path.exists(fn):
        print('Loading cached dataset...')
        corpus = torch.load(fn)
    elif os.path.exists(fn_pickle):
        print('Loading cached dataset from pickle...')
        with open(fn_pickle, "rb") as fp:
            corpus = pickle.load(fp)
    else:
        print('Producing dataset {}...'.format(dataset))
        kwargs = {}
        if dataset in ['wt103', 'wt2']:
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = False
        elif dataset == 'ptb':
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = True
        elif dataset == 'lm1b':
            kwargs['special'] = []
            kwargs['lower_case'] = False
            kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
        elif dataset in ['enwik8', 'text8']:
            pass

        corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
        torch.save(corpus, fn)

    return corpus


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False
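

# For reference, the ASCII ranges tested in _is_punctuation below are
# !"#$%&'()*+,-./ (33-47), :;<=>?@ (58-64), [\]^_` (91-96) and {|}~ (123-126).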
def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
--------------------------------------------------------------------------------