├── .circleci └── config.yml ├── .gitattributes ├── .gitignore ├── README.md ├── examples └── run_squad.py ├── pyproject.toml ├── pytorch_bert ├── __init__.py ├── feature.py ├── modeling.py ├── tokenizer.py └── weight_converter.py ├── requirements-dev.txt ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── test_converter.py ├── test_feature.py ├── test_modeling.py └── test_tokenizer.py └── tox.ini /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | executors: 4 | pytorch-bert-executor-36: 5 | working_directory: ~/pytorch-bert 6 | docker: 7 | - image: circleci/python:3.6 8 | 9 | pytorch-bert-executor-37: 10 | working_directory: ~/pytorch-bert 11 | docker: 12 | - image: circleci/python:3.7 13 | 14 | commands: 15 | test-with-tox: 16 | steps: 17 | - checkout 18 | 19 | - run: 20 | name: install tox and codecov 21 | command: pip install --user tox codecov 22 | 23 | - restore_cache: 24 | key: deps-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}-{{ checksum "tox.ini" }} 25 | 26 | - restore_cache: 27 | key: google-pretrained-weight 28 | 29 | - run: 30 | name: test tox 31 | command: ~/.local/bin/tox 32 | 33 | - save_cache: 34 | key: deps-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }} 35 | paths: 36 | - .tox 37 | 38 | - save_cache: 39 | key: google-pretrained-weight 40 | paths: 41 | - /tmp/bert-base 42 | - /tmp/bert-large 43 | 44 | - run: 45 | name: upload report to codecov 46 | command: ~/.local/bin/codecov 47 | 48 | jobs: 49 | run-test-36: 50 | executor: pytorch-bert-executor-36 51 | environment: 52 | TOXENV: py36 53 | steps: 54 | - test-with-tox 55 | 56 | run-test-37: 57 | executor: pytorch-bert-executor-37 58 | environment: 59 | TOXENV: py37 60 | steps: 61 | - test-with-tox 62 | 63 | workflows: 64 | version: 2 65 | pytorch-bert-workflow: 66 | jobs: 67 | - run-test-36 68 | - run-test-37 69 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # celery beat schedule file
95 | celerybeat-schedule
96 | 
97 | # SageMath parsed files
98 | *.sage.py
99 | 
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 | 
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 | 
113 | # Rope project settings
114 | .ropeproject
115 | 
116 | # mkdocs documentation
117 | /site
118 | 
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 | 
124 | # Pyre type checker
125 | .pyre/
126 | 
127 | data/
128 | .DS_Store
129 | .vscode/settings.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-bert
2 | 
3 | [![codecov](https://codecov.io/gh/jeongukjae/pytorch-bert/branch/master/graph/badge.svg)](https://codecov.io/gh/jeongukjae/pytorch-bert)
4 | [![CircleCI](https://circleci.com/gh/jeongukjae/pytorch-bert.svg?style=shield)](https://circleci.com/gh/jeongukjae/pytorch-bert)
5 | [![PyPI](https://img.shields.io/pypi/v/pytorch-bert)](https://pypi.org/project/pytorch-bert/)
6 | [![PyPI Pyversion](https://img.shields.io/pypi/pyversions/pytorch-bert)](https://pypi.org/project/pytorch-bert/)
7 | 
8 | An implementation of BERT using the PyTorch `TransformerEncoder`, with the pre-trained models of [google-research/bert](https://github.com/google-research/bert).
9 | 
10 | ## Installation
11 | 
12 | ```sh
13 | pip install pytorch-bert
14 | ```
15 | 
16 | ## Usage
17 | 
18 | - [**Usage**](https://github.com/jeongukjae/pytorch-bert/blob/8c276c222e721bc725049599f6b46dfedbc63340/tests/test_converter.py#L32)
19 | - [**Input, Output Shape**](https://github.com/jeongukjae/pytorch-bert/blob/master/tests/test_modeling.py)
20 | 
21 | ```python
22 | config = BertConfig.from_json("path-to-pretrained-weights/bert_config.json")
23 | model = Bert(config)
24 | load_tf_weight_to_pytorch_bert(model, config, "path-to-pretrained-weights/bert_model.ckpt")
25 | ```
26 | 
27 | Download the pre-trained model files from the [google-research/bert](https://github.com/google-research/bert) repository.
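For a fuller picture, the snippet below sketches the whole pipeline: converting a checkpoint, tokenizing a sequence pair, building a padded feature, and running the model. The paths and example sentences are placeholders; any checkpoint unzipped from google-research/bert works the same way. Note that `load_tf_weight_to_pytorch_bert` needs TensorFlow installed (`pip install pytorch-bert[with-tf]` pulls it in).

```python
import torch

from pytorch_bert import Bert, BertConfig, SubWordTokenizer, convert_sequences_to_feature
from pytorch_bert.weight_converter import load_tf_weight_to_pytorch_bert

# placeholder paths: point them at an unzipped checkpoint from google-research/bert
config = BertConfig.from_json("path-to-pretrained-weights/bert_config.json")
model = Bert(config)
load_tf_weight_to_pytorch_bert(model, config, "path-to-pretrained-weights/bert_model.ckpt")

tokenizer = SubWordTokenizer("path-to-pretrained-weights/vocab.txt")
feature = convert_sequences_to_feature(
    tokenizer,
    ("the dog is hairy.", "is this jacksonville?"),  # a (sequence1, sequence2) pair; a 1-tuple also works
    config.max_position_embeddings,
)

# batch of one example; input_mask holds 0.0 for real tokens and -inf for padding
input_ids = torch.tensor([feature.input_ids])
input_type_ids = torch.tensor([feature.input_type_ids])
input_mask = torch.tensor([feature.input_mask])

# encoder_outputs: (sequence, batch, hidden), pooled_output: (batch, hidden)
encoder_outputs, pooled_output = model(input_ids, input_type_ids, input_mask)
```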
28 | 
--------------------------------------------------------------------------------
/examples/run_squad.py:
--------------------------------------------------------------------------------
1 | # WIP
2 | import json
3 | import argparse
4 | import logging
5 | import sys
6 | from typing import NamedTuple, List
7 | 
8 | import torch
9 | from torch import nn, optim
10 | from torch.utils.data import TensorDataset, RandomSampler, DataLoader
11 | 
12 | from pytorch_bert import Bert, BertConfig, SubWordTokenizer
13 | from pytorch_bert.tokenizer import clean_text
14 | from pytorch_bert.weight_converter import load_tf_weight_to_pytorch_bert
15 | from pytorch_bert.feature import convert_sequences_to_feature
16 | 
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--model-path", required=True)
19 | parser.add_argument("--config-path", required=True)
20 | parser.add_argument("--vocab-path", required=True)
21 | parser.add_argument("--squad-train-path", required=True)
22 | parser.add_argument("--epoch", default=3, type=int)
23 | parser.add_argument("--batch-size", default=32, type=int)
24 | parser.add_argument("--learning-rate", default=2e-5, type=float)
25 | parser.add_argument("--logging-step", default=200, type=int)
26 | parser.add_argument("--eval-step", default=500, type=int)
27 | 
28 | logger = logging.getLogger()
29 | logger.setLevel(logging.DEBUG)
30 | 
31 | handler = logging.StreamHandler(sys.stdout)
32 | handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
33 | logger.addHandler(handler)
34 | 
35 | 
36 | class BertForSquad(nn.Module):
37 |     def __init__(self, config: BertConfig):
38 |         super(BertForSquad, self).__init__()
39 |         self.bert = Bert(config)
40 |         self.squad_layer = nn.Linear(config.hidden_size, 2)
41 | 
42 |     def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor):
43 |         encoder_outputs, _ = self.bert(input_ids, token_type_ids, attention_mask)
44 |         encoder_outputs = encoder_outputs.permute(1, 0, 2)
45 |         logits = self.squad_layer(encoder_outputs)
46 |         return logits
47 | 
48 | 
49 | class SquadExample(NamedTuple):
50 |     context_text: str
51 |     question_text: str
52 |     answer_text: str
53 |     start_position: int
54 |     end_position: int
55 | 
56 | 
57 | def main():
58 |     args = parser.parse_args()
59 | 
60 |     logger.info(f"initialize model and convert weights from {args.model_path}")
61 |     config = BertConfig.from_json(args.config_path)
62 |     model = BertForSquad(config)
63 |     load_tf_weight_to_pytorch_bert(model.bert, config, args.model_path)
64 | 
65 |     logger.info(f"initialize tokenizer using vocab {args.vocab_path}")
66 |     tokenizer = SubWordTokenizer(args.vocab_path)
67 | 
68 |     logger.info(f"read squad dataset from {args.squad_train_path}")
69 | 
70 |     with open(args.squad_train_path, "r") as f:
71 |         paragraphs = [paragraph for data in json.load(f)["data"] for paragraph in data["paragraphs"]]
72 | 
73 |     logger.info("convert squad dataset to features")
74 |     examples: List[SquadExample] = []
75 |     for paragraph in paragraphs:
76 |         original_context_text = paragraph["context"]
77 |         context_text = clean_text(original_context_text)
78 | 
79 |         for qa in paragraph["qas"]:
80 |             question_text = clean_text(qa["question"])
81 |             answer = qa["answers"][0]
82 |             answer_text = clean_text(answer["text"])
83 | 
84 |             start_position = context_text[: context_text.index(answer_text)].count(" ")
85 |             end_position = start_position + answer_text.count(" ")
86 | 
87 |             examples.append(SquadExample(context_text, question_text, answer_text, start_position, end_position))
88 | 
89 |     features = [
90 |         convert_sequences_to_feature(
91 |             tokenizer, (example.context_text, example.question_text), config.max_position_embeddings
92 |         )
93 |         for example in examples
94 |     ]
95 | 
96 |     logger.info("create dataloader from squad features")
97 |     input_ids = torch.tensor([feature.input_ids for feature in features])
98 |     input_type_ids = torch.tensor([feature.input_type_ids for feature in features])
99 |     input_mask = torch.tensor([feature.input_mask for feature in features])
100 |     label = torch.tensor([[example.start_position, example.end_position] for example in examples])
101 | 
102 |     dataset = TensorDataset(input_ids, input_type_ids, input_mask, label)
103 |     sampler = RandomSampler(dataset)
104 |     train_loader = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size)
105 | 
106 |     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
107 |     criterion = nn.CrossEntropyLoss()
108 | 
109 |     logger.info("start training")
110 |     logger.info(f"epoch: {args.epoch}")
111 |     logger.info(f"batch size: {args.batch_size}")
112 |     logger.info(f"length of dataset: {len(sampler)}")
113 |     logger.info(f"length of steps per epoch: {len(train_loader)}")
114 |     logger.info(f"learning rate: {args.learning_rate}")
115 |     logger.info(f"logging steps: {args.logging_step}")
116 |     logger.info(f"eval steps: {args.eval_step}")
117 | 
118 |     for epoch_index in range(args.epoch):
119 |         model.train()
120 |         running_loss = 0.0
121 |         for batch_index, batch in enumerate(train_loader):
122 |             # batch
123 |             # 0: input ids
124 |             # 1: input type ids
125 |             # 2: input mask
126 |             # 3: label (start position, end position)
127 |             optimizer.zero_grad()
128 | 
129 |             output = model(batch[0], batch[1], batch[2])
130 | 
131 |             # output has shape (batch, sequence, 2): per-token logits for the start and end positions
132 |             start_loss = criterion(output[:, :, 0], batch[3][:, 0])
133 |             end_loss = criterion(output[:, :, 1], batch[3][:, 1])
134 |             loss = (start_loss + end_loss) / 2
135 |             loss.backward()
136 |             optimizer.step()
137 | 
138 |             running_loss += loss.item()
139 | 
140 |             if batch_index % args.logging_step == args.logging_step - 1:
141 |                 logger.info(f"[Step {batch_index + 1}] loss: {running_loss / args.logging_step}")
142 |                 running_loss = 0.0
143 | 
144 | 
145 | if __name__ == "__main__":
146 |     main()
147 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 | target-version = ['py37']
4 | include = '\.py$'
--------------------------------------------------------------------------------
/pytorch_bert/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0a3"
2 | 
3 | __all__ = ["convert_sequences_to_feature", "Bert", "BertConfig", "PretrainingBert", "SubWordTokenizer", "Vocab"]
4 | 
5 | from .feature import convert_sequences_to_feature
6 | from .modeling import Bert, BertConfig, PretrainingBert
7 | from .tokenizer import SubWordTokenizer, Vocab
--------------------------------------------------------------------------------
/pytorch_bert/feature.py:
--------------------------------------------------------------------------------
1 | from typing import List, NamedTuple, Tuple, Union, cast
2 | 
3 | from .tokenizer import SpecialToken, SubWordTokenizer
4 | 
5 | SequencePair = Tuple[str, str]
6 | Sequences = Union[Tuple[str], SequencePair]
7 | 
8 | 
9 | class Feature(NamedTuple):
10 |     tokens: List[str]
11 |     input_ids: List[int]
12 |     input_type_ids: List[int]
13 |     input_mask: List[float]
14 | 
15 | 
16 | def convert_sequences_to_feature(
17 |     tokenizer: SubWordTokenizer, sequences: Sequences, max_sequence_length: int
18 | ) -> Feature:
19 |     tokenized_sequences =
tuple(tokenizer.tokenize(sequence) for sequence in sequences) 20 | is_sequence_pair = _is_sequence_pair(tokenized_sequences) 21 | 22 | if is_sequence_pair: 23 | # [CLS], sequence1, [SEP], sequence2, [SEP] 24 | tokenized_sequences = cast(Tuple[List[str], List[str]], tokenized_sequences) 25 | tokenized_sequences = _truncate_sequence_pair(tokenized_sequences, max_sequence_length - 3) 26 | else: 27 | # [CLS], sequence1, [SEP] 28 | if len(tokenized_sequences[0]) > max_sequence_length - 2: 29 | tokenized_sequences = tuple(tokenized_sequences[0][0 : max_sequence_length - 2]) 30 | 31 | tokens = [SpecialToken.cls_] + tokenized_sequences[0] + [SpecialToken.sep] 32 | input_type_ids = [0] * (len(tokenized_sequences[0]) + 2) 33 | 34 | if is_sequence_pair: 35 | tokens.extend(tokenized_sequences[1]) 36 | tokens.append(SpecialToken.sep) 37 | 38 | input_type_ids.extend([1] * (len(tokenized_sequences[1]) + 1)) 39 | 40 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 41 | input_mask = [0.0] * len(input_ids) 42 | 43 | if len(input_ids) < max_sequence_length: 44 | padding_list = [0] * (max_sequence_length - len(input_ids)) 45 | padding_list_for_mask = [float("-inf")] * len(padding_list) 46 | 47 | input_type_ids.extend(padding_list) 48 | input_ids.extend(padding_list) 49 | input_mask.extend(padding_list_for_mask) 50 | 51 | return Feature(tokens, input_ids, input_type_ids, input_mask) 52 | 53 | 54 | def _is_sequence_pair(sequences: Tuple) -> bool: 55 | return len(sequences) == 2 56 | 57 | 58 | def _truncate_sequence_pair( 59 | tokenized_sequences: Tuple[List[str], List[str]], max_length: int 60 | ) -> Tuple[List[str], List[str]]: 61 | sequence1, sequence2 = tokenized_sequences 62 | 63 | while True: 64 | total_length = len(sequence1) + len(sequence2) 65 | if total_length <= max_length: 66 | return (sequence1, sequence2) 67 | 68 | if len(sequence1) > len(sequence2): 69 | sequence1.pop() 70 | else: 71 | sequence2.pop() 72 | -------------------------------------------------------------------------------- /pytorch_bert/modeling.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class BertConfig: 9 | def __init__( 10 | self, 11 | vocab_size: int, 12 | hidden_size: int = 768, 13 | num_hidden_layers: int = 12, 14 | num_attention_heads: int = 12, 15 | intermediate_size: int = 3072, 16 | hidden_act: str = "gelu", 17 | hidden_dropout_prob: float = 0.1, 18 | attention_probs_dropout_prob: float = 0.1, 19 | max_position_embeddings: int = 512, 20 | type_vocab_size: int = 16, 21 | initializer_range: float = 0.0, 22 | **kwargs, # unused 23 | ): 24 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 25 | self.hidden_act = hidden_act 26 | self.hidden_dropout_prob = hidden_dropout_prob 27 | self.hidden_size = hidden_size 28 | self.initializer_range = initializer_range 29 | self.intermediate_size = intermediate_size 30 | self.max_position_embeddings = max_position_embeddings 31 | self.num_attention_heads = num_attention_heads 32 | self.num_hidden_layers = num_hidden_layers 33 | self.type_vocab_size = type_vocab_size 34 | self.vocab_size = vocab_size 35 | 36 | @staticmethod 37 | def from_json(path: str) -> "BertConfig": 38 | with open(path, "r") as f: 39 | file_content = json.load(f) 40 | 41 | return BertConfig(**file_content) 42 | 43 | 44 | def init_bert_weight(init_range: float): 45 | def fn_to_apply(module: nn.Module): 46 | if isinstance(module, (nn.Linear, 
nn.Embedding)): 47 | module.weight.data.normal_(mean=0.0, std=init_range) 48 | 49 | if isinstance(module, nn.Linear) and module.bias is not None: 50 | module.bias.data.zero_() 51 | elif isinstance(module, nn.LayerNorm): 52 | module.bias.data.zero_() 53 | module.weight.data.fill_(1.0) 54 | 55 | return fn_to_apply 56 | 57 | 58 | class Bert(nn.Module): 59 | def __init__(self, config: BertConfig): 60 | super(Bert, self).__init__() 61 | self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 62 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 63 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 64 | self.embedding_layer_norm = nn.LayerNorm(config.hidden_size) 65 | self.embedding_dropout = nn.Dropout(p=config.hidden_dropout_prob) 66 | 67 | self.encoders = nn.TransformerEncoder( 68 | encoder_layer=nn.TransformerEncoderLayer( 69 | d_model=config.hidden_size, 70 | nhead=config.num_attention_heads, 71 | dim_feedforward=config.intermediate_size, 72 | dropout=config.attention_probs_dropout_prob, 73 | activation=config.hidden_act, 74 | ), 75 | num_layers=config.num_hidden_layers, 76 | ) 77 | 78 | self.pooler_layer = nn.Linear(config.hidden_size, config.hidden_size) 79 | self.pooled_output_activate = nn.Tanh() 80 | 81 | def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor): 82 | seq_length = input_ids.size(1) 83 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 84 | 85 | words_embeddings = self.token_embeddings(input_ids) 86 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 87 | position_embeddings = self.position_embeddings(position_ids) 88 | 89 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 90 | embeddings = self.embedding_layer_norm(embeddings) 91 | embeddings = self.embedding_dropout(embeddings) 92 | 93 | encoder_outputs = self.encoders(embeddings.permute(1, 0, 2), src_key_padding_mask=attention_mask) 94 | pooled_output = self.pooled_output_activate(self.pooler_layer(encoder_outputs[0])) 95 | 96 | return encoder_outputs, pooled_output 97 | 98 | 99 | class PretrainingBert(nn.Module): 100 | def __init__(self, config: BertConfig): 101 | super(PretrainingBert, self).__init__() 102 | 103 | self.bert = Bert(config) 104 | self.mlm = BertMLM(config) 105 | self.nsp = BertNSP(config) 106 | 107 | def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor): 108 | encoder_outputs, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 109 | mlm_output = self.mlm(encoder_outputs) 110 | nsp_output = self.nsp(pooled_output) 111 | 112 | return encoder_outputs, pooled_output, mlm_output, nsp_output 113 | 114 | 115 | class BertMLM(nn.Module): 116 | def __init__(self, config: BertConfig): 117 | super(BertMLM, self).__init__() 118 | 119 | self.transform = nn.Linear(config.hidden_size, config.hidden_size) 120 | self.transform_layer_norm = nn.LayerNorm(config.hidden_size) 121 | 122 | self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 123 | self.output_bias = nn.Parameter(torch.zeros(config.vocab_size)) 124 | self.output_logit = nn.Softmax(dim=2) 125 | 126 | def forward(self, encoder_outputs: torch.Tensor) -> torch.Tensor: 127 | transformed = F.gelu(self.transform(encoder_outputs)) 128 | transformed = self.transform_layer_norm(transformed) 129 | 130 | logits = self.output_layer(transformed) 131 | return 
self.output_logit(logits + self.output_bias) 132 | 133 | 134 | class BertNSP(nn.Module): 135 | def __init__(self, config: BertConfig): 136 | super(BertNSP, self).__init__() 137 | 138 | self.output_layer = nn.Linear(config.hidden_size, 2) 139 | self.output_softmax = nn.Softmax(dim=-1) 140 | 141 | def forward(self, pooled_output: torch.Tensor) -> torch.Tensor: 142 | return self.output_softmax(self.output_layer(pooled_output)) 143 | -------------------------------------------------------------------------------- /pytorch_bert/tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer & Vocab 2 | 3 | Check general category value of unicode in https://www.unicode.org/reports/tr44/#General_Category_Values 4 | """ 5 | import unicodedata 6 | from collections import OrderedDict 7 | from typing import Dict, List, Optional, Tuple, Union, cast 8 | 9 | 10 | class SpecialToken: 11 | unk = "[UNK]" 12 | sep = "[SEP]" 13 | cls_ = "[CLS]" 14 | mask = "[MASK]" 15 | 16 | 17 | class SubWordTokenizer: 18 | def __init__(self, vocab: Union["Vocab", str], do_lower_case: bool = True): 19 | if isinstance(vocab, str): 20 | vocab = Vocab(vocab) 21 | 22 | self.vocab = vocab 23 | 24 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 25 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 26 | 27 | def tokenize(self, text: str) -> List[str]: 28 | return [ 29 | sub_token 30 | for token in self.basic_tokenizer.tokenize(text) 31 | for sub_token in self.wordpiece_tokenizer.tokenize(token) 32 | ] 33 | 34 | def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: 35 | return self.vocab.convert_tokens_to_ids(tokens) 36 | 37 | def convert_ids_to_tokens(self, ids: List[int]) -> List[str]: 38 | return self.vocab.convert_ids_to_tokens(ids) 39 | 40 | 41 | class Vocab: 42 | def __init__(self, vocab_path: str): 43 | self.__vocab = self._load_vocab(vocab_path) 44 | self.__inv_vocab = {v: k for k, v in self.__vocab.items()} 45 | 46 | def __contains__(self, key: str) -> bool: 47 | return key in self.__vocab 48 | 49 | def convert_token_to_id(self, token: str) -> int: 50 | return self.__vocab[token] 51 | 52 | def convert_id_to_token(self, id: int) -> str: 53 | return self.__inv_vocab[id] 54 | 55 | def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: 56 | return cast(List[int], self._convert_by_vocab(self.__vocab, tokens)) 57 | 58 | def convert_ids_to_tokens(self, ids: List[int]) -> List[str]: 59 | return cast(List[str], self._convert_by_vocab(self.__inv_vocab, ids)) 60 | 61 | @staticmethod 62 | def _load_vocab(vocab_path: str) -> OrderedDict: 63 | vocab = OrderedDict() 64 | index = 0 65 | with open(vocab_path, "r") as f: 66 | for line in f: 67 | token = _convert_to_str(line).strip() 68 | vocab[token] = index 69 | index += 1 70 | 71 | return vocab 72 | 73 | @staticmethod 74 | def _convert_by_vocab(vocab: Dict, items: List[Union[int, str]]) -> List[Union[int, str]]: 75 | return [vocab[item] for item in items] 76 | 77 | 78 | class BasicTokenizer: 79 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 80 | 81 | def __init__(self, do_lower_case=True): 82 | """Constructs a BasicTokenizer. 83 | Args: 84 | do_lower_case: Whether to lower case the input. 85 | """ 86 | self.do_lower_case = do_lower_case 87 | 88 | def tokenize(self, text: str) -> List[str]: 89 | """Tokenizes a piece of text.""" 90 | text = clean_text(text) 91 | 92 | # This was added on November 1st, 2018 for the multilingual and Chinese 93 | # models. 
This is also applied to the English models now, but it doesn't 94 | # matter since the English models were not trained on any Chinese data 95 | # and generally don't have any Chinese data in them (there are Chinese 96 | # characters in the vocabulary because Wikipedia does have some Chinese 97 | # words in the English Wikipedia.). 98 | text = _tokenize_chinese_chars(text) 99 | 100 | original_tokens = tokenize_whitespace(text) 101 | splitted_tokens = [] 102 | for token in original_tokens: 103 | if self.do_lower_case: 104 | token = token.lower() 105 | token = self._strip_accents(token) 106 | splitted_tokens.extend(self._split_on_punc(token)) 107 | 108 | output_tokens = tokenize_whitespace(" ".join(splitted_tokens)) 109 | return output_tokens 110 | 111 | @staticmethod 112 | def _strip_accents(text: str) -> str: 113 | """Strips accents from a piece of text.""" 114 | text = unicodedata.normalize("NFD", text) 115 | output = [char for char in text if unicodedata.category(char) != "Mn"] 116 | 117 | return "".join(output) 118 | 119 | @staticmethod 120 | def _split_on_punc(text: str) -> List[str]: 121 | """Splits punctuation on a piece of text.""" 122 | start_new_word = True 123 | output = [] 124 | 125 | for char in text: 126 | if _is_punctuation(char): 127 | output.append([char]) 128 | start_new_word = True 129 | else: 130 | if start_new_word: 131 | output.append([]) 132 | start_new_word = False 133 | output[-1].append(char) 134 | 135 | return ["".join(x) for x in output] 136 | 137 | 138 | class WordpieceTokenizer: 139 | """Runs WordPiece tokenziation.""" 140 | 141 | __PREFIX_OF_SUBWORD = "##" 142 | 143 | def __init__(self, vocab: Vocab, unknown_token: str = SpecialToken.unk, max_length_of_word: int = 200): 144 | self.vocab = vocab 145 | self.unknown_token = unknown_token 146 | self.max_length_of_word = max_length_of_word 147 | 148 | def tokenize(self, text: str) -> List[str]: 149 | """Tokenizes a piece of text into its word pieces. 150 | This uses a greedy longest-match-first algorithm to perform tokenization 151 | using the given vocabulary. 152 | For example: 153 | input = "unaffable" 154 | output = ["un", "##aff", "##able"] 155 | Args: 156 | text: A single token or whitespace separated tokens. This should have 157 | already been passed through `BasicTokenizer. 158 | Returns: 159 | A list of wordpiece tokens. 
160 | """ 161 | return [subword for token in tokenize_whitespace(text) for subword in self._split_to_subwords(token)] 162 | 163 | def _split_to_subwords(self, token: str) -> List[str]: 164 | if len(token) > self.max_length_of_word: 165 | return [self.unknown_token] 166 | 167 | start = 0 168 | subwords = [] 169 | 170 | while start < len(token): 171 | subword, end = self._find_subword_in_token(token, start) 172 | if subword is None: 173 | return [self.unknown_token] 174 | subwords.append(subword) 175 | start = end 176 | 177 | return subwords 178 | 179 | def _find_subword_in_token(self, token: str, start_position: int) -> Tuple[Optional[str], int]: 180 | end_position = len(token) 181 | while end_position > start_position: 182 | subword = token[start_position:end_position] 183 | if start_position > 0: 184 | subword = self.__PREFIX_OF_SUBWORD + subword 185 | 186 | if subword in self.vocab: 187 | return subword, end_position 188 | 189 | end_position -= 1 190 | 191 | return None, end_position 192 | 193 | 194 | def clean_text(text: str) -> str: 195 | """Performs invalid character removal and whitespace cleanup on text.""" 196 | output = [" " if _is_whitespace(char) else char for char in text if not _is_invalid_char(char)] 197 | return "".join(output) 198 | 199 | 200 | def tokenize_whitespace(text: str) -> List[str]: 201 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 202 | return text.strip().split() 203 | 204 | 205 | def _convert_to_str(text: Union[str, bytes]) -> str: 206 | if isinstance(text, str): 207 | return text 208 | elif isinstance(text, bytes): 209 | return text.decode("utf-8", "ignore") 210 | else: 211 | raise ValueError("Unsupported string type: %s" % (type(text))) 212 | 213 | 214 | def _is_punctuation(char: str) -> bool: 215 | """Checks whether `chars` is a punctuation character.""" 216 | cp = ord(char) 217 | # We treat all non-letter/number ASCII as punctuation. 218 | # Characters such as "^", "$", and "`" are not in the Unicode 219 | # Punctuation class but we treat them as punctuation anyways, for 220 | # consistency. 221 | if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): 222 | return True 223 | cat = unicodedata.category(char) 224 | if cat.startswith("P"): 225 | return True 226 | return False 227 | 228 | 229 | def _is_whitespace(char: str) -> bool: 230 | """Returns True if a character is a space character 231 | 232 | .. note:: ``\t``, ``\n``, and ``\r`` are technically contorl characters but we treat them 233 | as whitespace since they are generally considered as such. 234 | """ 235 | if char in (" ", "\t", "\n", "\r"): 236 | return True 237 | if unicodedata.category(char) == "Zs": 238 | return True 239 | return False 240 | 241 | 242 | def _is_invalid_char(char: str) -> bool: 243 | char_code = ord(char) 244 | 245 | return char_code == 0 or char_code == 0xFFFD or _is_control(char) 246 | 247 | 248 | def _is_control(char): 249 | """Checks whether `chars` is a control character.""" 250 | # These are technically control characters but we count them as whitespace 251 | # characters. 
252 | if char == "\t" or char == "\n" or char == "\r": 253 | return False 254 | cat = unicodedata.category(char) 255 | if cat in ("Cc", "Cf"): 256 | return True 257 | return False 258 | 259 | 260 | def _tokenize_chinese_chars(text): 261 | """Adds whitespace around any CJK character.""" 262 | output = [] 263 | for char in text: 264 | char_code = ord(char) 265 | if _is_chinese_char(char_code): 266 | output.extend([" ", char, " "]) 267 | else: 268 | output.append(char) 269 | return "".join(output) 270 | 271 | 272 | def _is_chinese_char(char_code: int): 273 | """Checks whether char_code is the code of a Chinese character.""" 274 | # https://en.wikipedia.org/wiki/List_of_CJK_Unified_Ideographs,_part_1_of_4 275 | if ( 276 | (char_code >= 0x4E00 and char_code <= 0x9FFF) # CJK Unified Ideographs 277 | or (char_code >= 0x3400 and char_code <= 0x4DBF) # CJK Unified Ideographs Extension A 278 | or (char_code >= 0x20000 and char_code <= 0x2A6DF) # CJK Unified Ideographs Extension B 279 | or (char_code >= 0x2A700 and char_code <= 0x2B73F) # CJK Unified Ideographs Extension C 280 | or (char_code >= 0x2B740 and char_code <= 0x2B81F) # CJK Unified Ideographs Extension D 281 | or (char_code >= 0x2B820 and char_code <= 0x2CEAF) # CJK Unified Ideographs Extension E 282 | or (char_code >= 0xF900 and char_code <= 0xFAFF) # CJK Compatibility Ideographs 283 | or (char_code >= 0x2F800 and char_code <= 0x2FA1F) # CJK Compatibility Ideographs Supplement 284 | ): 285 | return True 286 | 287 | return False 288 | -------------------------------------------------------------------------------- /pytorch_bert/weight_converter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | from .modeling import Bert, BertConfig, PretrainingBert 6 | 7 | try: 8 | import tensorflow as tf 9 | 10 | _is_tf_imported = True 11 | except ImportError: 12 | _is_tf_imported = False 13 | 14 | 15 | def load_tf_weight_to_pytorch_bert(bert: Bert, config: BertConfig, tf_model_path: str): 16 | if not _is_tf_imported: 17 | raise ImportError("cannot import tensorflow, please install tensorflow first") 18 | 19 | # load embedding layer 20 | _load_embedding(bert.token_embeddings, tf_model_path, "bert/embeddings/word_embeddings") 21 | _load_embedding(bert.token_type_embeddings, tf_model_path, "bert/embeddings/token_type_embeddings") 22 | _load_embedding(bert.position_embeddings, tf_model_path, "bert/embeddings/position_embeddings") 23 | _load_layer_norm(bert.embedding_layer_norm, tf_model_path, "bert/embeddings/LayerNorm") 24 | 25 | # load transformer encoders 26 | for layer_num in range(config.num_hidden_layers): 27 | encoder = bert.encoders.layers[layer_num] 28 | encoder_path = f"bert/encoder/layer_{layer_num}" 29 | 30 | _load_self_attention(encoder.self_attn, tf_model_path, f"{encoder_path}/attention") 31 | _load_layer_norm(encoder.norm1, tf_model_path, f"{encoder_path}/attention/output/LayerNorm") 32 | _load_layer_norm(encoder.norm2, tf_model_path, f"{encoder_path}/output/LayerNorm") 33 | 34 | _load_linear(encoder.self_attn.out_proj, tf_model_path, f"{encoder_path}/attention/output/dense") 35 | _load_linear(encoder.linear1, tf_model_path, f"{encoder_path}/intermediate/dense") 36 | _load_linear(encoder.linear2, tf_model_path, f"{encoder_path}/output/dense") 37 | 38 | # load pooler layer 39 | _load_linear(bert.pooler_layer, tf_model_path, f"bert/pooler/dense") 40 | 41 | 42 | def load_tf_weight_to_pytorch_pretraining_bert( 43 | bert: PretrainingBert, 
config: BertConfig, tf_model_path: str, share_parameters: bool = False 44 | ): 45 | load_tf_weight_to_pytorch_bert(bert.bert, config, tf_model_path) 46 | 47 | # load mlm 48 | _load_linear(bert.mlm.transform, tf_model_path, "cls/predictions/transform/dense", load_bias=False) 49 | _load_layer_norm(bert.mlm.transform_layer_norm, tf_model_path, "cls/predictions/transform/LayerNorm") 50 | if share_parameters: 51 | bert.mlm.output_layer.weight = bert.bert.token_embeddings.weight 52 | else: 53 | bert.mlm.output_layer.weight = nn.Parameter(bert.bert.token_embeddings.weight.clone()) 54 | _load_raw(bert.mlm.output_bias, tf_model_path, "cls/predictions/output_bias") 55 | 56 | # load nsp 57 | _load_raw(bert.nsp.output_layer.weight, tf_model_path, f"cls/seq_relationship/output_weights") 58 | _load_raw(bert.nsp.output_layer.bias, tf_model_path, f"cls/seq_relationship/output_bias") 59 | 60 | 61 | def _load_embedding(embedding: nn.Embedding, tf_model_path: str, embedding_path: str): 62 | embedding_weight = _load_tf_variable(tf_model_path, embedding_path) 63 | _load_torch_weight(embedding.weight, embedding_weight) 64 | 65 | 66 | def _load_layer_norm(layer_norm: torch.nn.LayerNorm, tf_model_path: str, layer_norm_base: str): 67 | layer_norm_gamma = _load_tf_variable(tf_model_path, f"{layer_norm_base}/gamma") 68 | layer_norm_beta = _load_tf_variable(tf_model_path, f"{layer_norm_base}/beta") 69 | 70 | _load_torch_weight(layer_norm.weight, layer_norm_gamma) 71 | _load_torch_weight(layer_norm.bias, layer_norm_beta) 72 | 73 | 74 | def _load_linear(linear: torch.nn.Linear, tf_model_path: str, linear_path: str, load_bias: bool = True): 75 | linear_weight = _load_tf_variable(tf_model_path, f"{linear_path}/kernel") 76 | linear_weight = np.transpose(linear_weight) 77 | _load_torch_weight(linear.weight, linear_weight) 78 | 79 | if load_bias: 80 | linear_bias = _load_tf_variable(tf_model_path, f"{linear_path}/bias") 81 | _load_torch_weight(linear.bias, linear_bias) 82 | 83 | 84 | def _load_self_attention(param: torch.nn.MultiheadAttention, tf_model_path: str, attention_path: str): 85 | query_weight = _load_tf_variable(tf_model_path, f"{attention_path}/self/query/kernel") 86 | key_weight = _load_tf_variable(tf_model_path, f"{attention_path}/self/key/kernel") 87 | value_weight = _load_tf_variable(tf_model_path, f"{attention_path}/self/value/kernel") 88 | 89 | query_weight = np.transpose(query_weight) 90 | key_weight = np.transpose(key_weight) 91 | value_weight = np.transpose(value_weight) 92 | 93 | query_bias = _load_tf_variable(tf_model_path, f"{attention_path}/self/query/bias") 94 | key_bias = _load_tf_variable(tf_model_path, f"{attention_path}/self/key/bias") 95 | value_bias = _load_tf_variable(tf_model_path, f"{attention_path}/self/value/bias") 96 | 97 | in_proj_weight = np.concatenate((query_weight, key_weight, value_weight)) 98 | in_proj_bias = np.concatenate((query_bias, key_bias, value_bias)) 99 | 100 | _load_torch_weight(param.in_proj_weight, in_proj_weight) 101 | _load_torch_weight(param.in_proj_bias, in_proj_bias) 102 | 103 | 104 | def _load_raw(param: torch.Tensor, tf_model_path: str, path: str): 105 | w = _load_tf_variable(tf_model_path, path) 106 | _load_torch_weight(param, w) 107 | 108 | 109 | def _load_tf_variable(model_path: str, key: str): 110 | return tf.train.load_variable(model_path, key).squeeze() 111 | 112 | 113 | def _load_torch_weight(param: torch.Tensor, data): 114 | assert param.shape == data.shape 115 | param.data = torch.from_numpy(data) 116 | 
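117 | 
118 | # Example usage (the paths below are placeholders for an unzipped google-research/bert
119 | # checkpoint). A convenient pattern is to run the conversion once while TensorFlow is
120 | # installed and persist the result with plain torch.save, so later runs only need PyTorch:
121 | #
122 | #     config = BertConfig.from_json("path-to-pretrained-weights/bert_config.json")
123 | #     model = Bert(config)
124 | #     load_tf_weight_to_pytorch_bert(model, config, "path-to-pretrained-weights/bert_model.ckpt")
125 | #     torch.save(model.state_dict(), "converted_bert.pt")
126 | #
127 | #     # reloading later does not require TensorFlow
128 | #     model = Bert(config)
129 | #     model.load_state_dict(torch.load("converted_bert.pt"))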
-------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # linters, formatters, analyzers 2 | flake8 3 | isort 4 | black 5 | 6 | # testing 7 | pytest 8 | pytest-cov 9 | 10 | # optional 11 | tensorflow 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203, E501, W503, E241 4 | 5 | [tool:isort] 6 | line_length = 120 7 | multi_line_output = 3 8 | include_trailing_comma = True 9 | 10 | [tool:pytest] 11 | addopts = -ra -v -l 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="pytorch-bert", 5 | version="1.0.0a3", 6 | install_requires=["torch>=1.3.0"], 7 | extras_require={"with-tf": ["tensorflow>=2.0.0", "numpy"]}, 8 | packages=find_packages(exclude=["tests"]), 9 | python_requires=">=3.6, <3.8", 10 | # 11 | description="bert implementation", 12 | author="Jeong Ukjae", 13 | author_email="jeongukjae@gmail.com", 14 | url="https://github.com/jeongukjae/pytorch-bert", 15 | classifiers=[ 16 | "Development Status :: 3 - Alpha", 17 | "Topic :: Scientific/Engineering", 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python", 20 | "Programming Language :: Python :: 3", 21 | "Programming Language :: Python :: 3.6", 22 | "Programming Language :: Python :: 3.7", 23 | ], 24 | ) 25 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeongukjae/pytorch-bert/71cf8c9a9a4ae1585ae8e733d73d95e1b5c7ffc1/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_converter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | import zipfile 4 | 5 | import pytest 6 | 7 | from pytorch_bert.modeling import Bert, BertConfig, PretrainingBert 8 | from pytorch_bert.weight_converter import load_tf_weight_to_pytorch_bert, load_tf_weight_to_pytorch_pretraining_bert 9 | 10 | google_bert_model_parameters = pytest.mark.parametrize( 11 | "url,directory,unzipped_path", 12 | [ 13 | pytest.param( 14 | "https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip", 15 | "/tmp/bert-base", 16 | "/tmp/bert-base/multi_cased_L-12_H-768_A-12", 17 | ), 18 | pytest.param( 19 | "https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip", 20 | "/tmp/bert-large", 21 | "/tmp/bert-large/wwm_uncased_L-24_H-1024_A-16", 22 | ), 23 | ], 24 | ) 25 | 26 | 27 | @google_bert_model_parameters 28 | def test_convert_pretrained_weight_bert(url: str, directory: str, unzipped_path: str): 29 | if not os.path.isdir(directory): 30 | os.mkdir(directory) 31 | 32 | download_model_file(url, directory) 33 | 34 | config = BertConfig.from_json(f"{unzipped_path}/bert_config.json") 35 | model = Bert(config) 36 | 
load_tf_weight_to_pytorch_bert(model, config, f"{unzipped_path}/bert_model.ckpt") 37 | 38 | 39 | @google_bert_model_parameters 40 | def test_convert_pretrained_weight_of_pretraining_bert(url: str, directory: str, unzipped_path: str): 41 | if not os.path.isdir(directory): 42 | os.mkdir(directory) 43 | 44 | download_model_file(url, directory) 45 | 46 | config = BertConfig.from_json(f"{unzipped_path}/bert_config.json") 47 | model = PretrainingBert(config) 48 | load_tf_weight_to_pytorch_pretraining_bert(model, config, f"{unzipped_path}/bert_model.ckpt") 49 | 50 | assert id(model.bert.token_embeddings.weight) != id(model.mlm.output_layer.weight) 51 | 52 | 53 | @google_bert_model_parameters 54 | def test_convert_pretrained_weight_of_pretraining_bert_sharing_parameters(url: str, directory: str, unzipped_path: str): 55 | if not os.path.isdir(directory): 56 | os.mkdir(directory) 57 | 58 | download_model_file(url, directory) 59 | 60 | config = BertConfig.from_json(f"{unzipped_path}/bert_config.json") 61 | model = PretrainingBert(config) 62 | load_tf_weight_to_pytorch_pretraining_bert(model, config, f"{unzipped_path}/bert_model.ckpt", share_parameters=True) 63 | 64 | assert id(model.bert.token_embeddings.weight) == id(model.mlm.output_layer.weight) 65 | 66 | 67 | def download_model_file(url: str, cache_directory: str = "/tmp", force_download: bool = False): 68 | filename = url.split("/")[-1] 69 | 70 | if not os.path.isdir(cache_directory): 71 | raise ValueError(f"{cache_directory} is not a directory") 72 | 73 | download_path = os.path.join(cache_directory, filename) 74 | if not force_download and os.path.exists(download_path): 75 | return 76 | 77 | urllib.request.urlretrieve(url, download_path) 78 | 79 | model_zip = zipfile.ZipFile(download_path) 80 | model_zip.extractall(cache_directory) 81 | model_zip.close() 82 | -------------------------------------------------------------------------------- /tests/test_feature.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import pytest 4 | 5 | import pytorch_bert.feature as F 6 | import pytorch_bert.tokenizer as T 7 | 8 | VOCAB_WORDS = [ 9 | "[PAD]", 10 | "[UNK]", 11 | "[SEP]", 12 | "[CLS]", 13 | "[MASK]", 14 | "the", 15 | "dog", 16 | "is", 17 | "hairy", 18 | ",", 19 | "this", 20 | "jack", 21 | "##son", 22 | "##ville", 23 | "?", 24 | "no", 25 | "it", 26 | "not", 27 | ] 28 | 29 | 30 | @pytest.mark.parametrize( 31 | "tokenized_sequences,max_length,expected_output", 32 | [ 33 | pytest.param( 34 | ( 35 | ["the", "dog", "is", "hairy", ",", "is", "this", "jack", "##son", "##ville", "?"], 36 | ["no", "it", "is", "not"], 37 | ), 38 | 12, 39 | (["the", "dog", "is", "hairy", ",", "is", "this", "jack"], ["no", "it", "is", "not"]), 40 | ), 41 | pytest.param( 42 | (["the", "dog", "is", "hairy"], ["is", "this", "jack", "##son", "##ville", "?", "no", "it", "is", "not"]), 43 | 12, 44 | (["the", "dog", "is", "hairy"], ["is", "this", "jack", "##son", "##ville", "?", "no", "it"]), 45 | ), 46 | ], 47 | ) 48 | def test_truncate_sequence_pair(tokenized_sequences, max_length, expected_output): 49 | assert F._truncate_sequence_pair(tokenized_sequences, max_length) == expected_output 50 | 51 | 52 | @pytest.mark.parametrize( 53 | "sequence,max_sequence_length,expected_output", 54 | [ 55 | # fmt: off 56 | pytest.param( 57 | ("the dog is hairy, is this jacksonville?", "no it is not"), 58 | 18, 59 | ( 60 | ["[CLS]", "the", "dog", "is", "hairy", ",", "is", "this", "jack", "##son", "##ville", "?", "[SEP]", "no", "it", "is", "not", 
"[SEP]"], 61 | [3, 5, 6, 7, 8, 9, 7, 10, 11, 12, 13, 14, 2, 15, 16, 7, 17, 2], 62 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], 63 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 64 | ) 65 | ), 66 | pytest.param( 67 | ("the dog is hairy, is this jacksonville?", "no it is not"), 68 | 20, 69 | ( 70 | ["[CLS]", "the", "dog", "is", "hairy", ",", "is", "this", "jack", "##son", "##ville", "?", "[SEP]", "no", "it", "is", "not", "[SEP]"], 71 | [3, 5, 6, 7, 8, 9, 7, 10, 11, 12, 13, 14, 2, 15, 16, 7, 17, 2, 0, 0], 72 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], 73 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -math.inf, -math.inf] 74 | ), 75 | ), 76 | pytest.param( 77 | ("the dog is hairy, is this jacksonville?",), 78 | 20, 79 | ( 80 | ["[CLS]", "the", "dog", "is", "hairy", ",", "is", "this", "jack", "##son", "##ville", "?", "[SEP]"], 81 | [3, 5, 6, 7, 8, 9, 7, 10, 11, 12, 13, 14, 2, 0, 0, 0, 0, 0, 0, 0], 82 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 83 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -math.inf, -math.inf, -math.inf, -math.inf, -math.inf, -math.inf, -math.inf] 84 | ) 85 | ) 86 | # fmt: on 87 | ], 88 | ) 89 | def test_convert_sequences_to_feature(tmpdir, sequence, max_sequence_length, expected_output): 90 | vocab_path = tmpdir.join("test-vocab-file.txt") 91 | vocab_path.write("\n".join(VOCAB_WORDS)) 92 | 93 | vocab = T.Vocab(vocab_path) 94 | tokenizer = T.SubWordTokenizer(vocab, True) 95 | output = F.convert_sequences_to_feature(tokenizer, sequence, max_sequence_length) 96 | 97 | assert output == expected_output 98 | -------------------------------------------------------------------------------- /tests/test_modeling.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import pytorch_bert.modeling as M 5 | 6 | 7 | @pytest.fixture 8 | def config(): 9 | return M.BertConfig(300, num_hidden_layers=3) 10 | 11 | 12 | def test_bert_with_random_input(config: M.BertConfig): 13 | model = M.Bert(config) 14 | model.apply(M.init_bert_weight(config.initializer_range)) 15 | batch_size = 1 16 | 17 | encoder_ouputs, pooled_output = model( 18 | torch.randint(300, (batch_size, config.max_position_embeddings)), 19 | torch.randint(2, (batch_size, config.max_position_embeddings)), 20 | torch.randint(2, (batch_size, config.max_position_embeddings), dtype=torch.bool), 21 | ) 22 | 23 | assert encoder_ouputs.size() == (config.max_position_embeddings, batch_size, config.hidden_size) 24 | assert pooled_output.size() == (batch_size, config.hidden_size) 25 | 26 | 27 | def test_pretraining_bert_with_random_input(config: M.BertConfig): 28 | model = M.PretrainingBert(config) 29 | model.apply(M.init_bert_weight(config.initializer_range)) 30 | batch_size = 1 31 | 32 | encoder_ouputs, pooled_output, mlm_output, nsp_output = model( 33 | torch.randint(300, (batch_size, config.max_position_embeddings)), 34 | torch.randint(2, (batch_size, config.max_position_embeddings)), 35 | torch.randint(2, (batch_size, config.max_position_embeddings), dtype=torch.bool), 36 | ) 37 | 38 | assert encoder_ouputs.size() == (config.max_position_embeddings, batch_size, config.hidden_size) 39 | assert pooled_output.size() == (batch_size, config.hidden_size) 40 | assert mlm_output.size() == (config.max_position_embeddings, batch_size, config.vocab_size) 41 | assert nsp_output.size() == (batch_size, 2) 42 | -------------------------------------------------------------------------------- 
/tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | import pytorch_bert.tokenizer as T 2 | 3 | 4 | def test_load_vocab(tmpdir): 5 | path = tmpdir.mkdir("test").join("vocab.txt") 6 | path.write("\n".join(["word1", "word2", "word3"])) 7 | 8 | vocab = T.Vocab(str(path)) 9 | 10 | assert vocab.convert_id_to_token(0) == "word1" 11 | assert vocab.convert_token_to_id("word2") == 1 12 | assert vocab.convert_ids_to_tokens([1, 0]) == ["word2", "word1"] 13 | assert vocab.convert_tokens_to_ids(["word3", "word1"]) == [2, 0] 14 | 15 | 16 | def test_full_tokenizer(tmpdir): 17 | path = tmpdir.mkdir("test").join("full_vocab.txt") 18 | path.write("\n".join(["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ","])) 19 | 20 | vocab = T.Vocab(path) 21 | tokenizer = T.SubWordTokenizer(vocab) 22 | 23 | tokens = tokenizer.tokenize("UNwant\u00E9d,running") 24 | 25 | assert tokens == ["un", "##want", "##ed", ",", "runn", "##ing"] 26 | assert tokenizer.convert_tokens_to_ids(tokens) == [7, 4, 5, 10, 8, 9] 27 | assert tokenizer.convert_ids_to_tokens([7, 4, 5, 10, 8, 9]) == tokens 28 | 29 | 30 | def test_basic_tokenizer_no_lower(): 31 | tokenizer = T.BasicTokenizer(do_lower_case=False) 32 | 33 | assert tokenizer.tokenize(" \tHeLLo!how \n Are yoU? ") == ["HeLLo", "!", "how", "Are", "yoU", "?"] 34 | 35 | 36 | def test_basic_tokenizer_do_lower(): 37 | lowered_tokenizer = T.BasicTokenizer(do_lower_case=True) 38 | 39 | assert lowered_tokenizer.tokenize(" \tHeLLo!how \n Are yoU? ") == ["hello", "!", "how", "are", "you", "?"] 40 | assert lowered_tokenizer.tokenize("H\u00E9llo") == ["hello"] 41 | 42 | 43 | def test_basic_tokenizer_with_chinese_character(): 44 | tokenizer = T.BasicTokenizer() 45 | 46 | assert tokenizer.tokenize("ah\u535A\u63A8zz") == ["ah", "\u535A", "\u63A8", "zz"] 47 | 48 | 49 | def test_wordpiece_tokenizer(tmpdir): 50 | path = tmpdir.mkdir("test").join("vocab.txt") 51 | path.write("\n".join(["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"])) 52 | 53 | vocab = T.Vocab(str(path)) 54 | tokenizer = T.WordpieceTokenizer(vocab=vocab) 55 | 56 | assert tokenizer.tokenize("") == [] 57 | assert tokenizer.tokenize("unwanted running") == ["un", "##want", "##ed", "runn", "##ing"] 58 | assert tokenizer.tokenize("unwantedX running") == ["[UNK]", "runn", "##ing"] 59 | 60 | 61 | def test_clean_text(): 62 | assert T.clean_text("\tHello\n안녕안녕 mm") == " Hello 안녕안녕 mm" 63 | 64 | 65 | def test_is_whitespace(): 66 | assert T._is_whitespace(" ") 67 | assert T._is_whitespace("\t") 68 | assert T._is_whitespace("\r") 69 | assert T._is_whitespace("\n") 70 | assert T._is_whitespace("\u00A0") 71 | 72 | assert not T._is_whitespace("A") 73 | assert not T._is_whitespace("-") 74 | 75 | 76 | def test_is_control(): 77 | assert T._is_control("\u0005") 78 | 79 | assert not T._is_control("A") 80 | assert not T._is_control(" ") 81 | assert not T._is_control("\t") 82 | assert not T._is_control("\r") 83 | assert not T._is_control("\U0001F4A9") 84 | 85 | 86 | def test_is_punctuation(): 87 | assert T._is_punctuation("-") 88 | assert T._is_punctuation("$") 89 | assert T._is_punctuation("`") 90 | assert T._is_punctuation(".") 91 | 92 | assert not T._is_punctuation("A") 93 | assert not T._is_punctuation(" ") 94 | 95 | 96 | def test_tokenize_chinese_chars(): 97 | assert T._tokenize_chinese_chars("This is a Chinese character 一") == "This is a Chinese character 一 " 98 | assert T._tokenize_chinese_chars("no Chinese characters.") == "no 
Chinese characters." 99 | assert T._tokenize_chinese_chars("Some喥Rando喩mChi噟neseCharacter") == "Some 喥 Rando 喩 mChi 噟 neseCharacter" 100 | 101 | 102 | def test_is_chinese_char(): 103 | assert T._is_chinese_char(ord("一")) 104 | assert T._is_chinese_char(ord("壚")) 105 | assert not T._is_chinese_char(ord("a")) 106 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py36,py37 3 | 4 | [testenv] 5 | deps = 6 | -r requirements.txt 7 | -r requirements-dev.txt 8 | commands = 9 | flake8 pytorch_bert tests 10 | isort -rc pytorch_bert tests 11 | black pytorch_bert tests 12 | pytest --cov pytorch_bert 13 | --------------------------------------------------------------------------------