├── .gitignore
├── README.md
├── albert_tiny
│   ├── checkpoint
│   ├── config.json
│   ├── pytorch_model.bin
│   └── vocab.txt
├── albert_zh
│   ├── __init__.py
│   ├── configuration_bert.py
│   ├── configuration_utils.py
│   ├── file_utils.py
│   ├── modeling_albert.py
│   ├── modeling_utils.py
│   ├── optimization.py
│   ├── tokenization_bert.py
│   └── tokenization_utils.py
└── usage_example.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # The original reference author now provides model downloads and test results
2 | https://github.com/lonePatient/albert_pytorch/blob/master/README_zh.md
3 | 
4 | # Albert-zh for pytorch-transformers
5 | - **No longer updated**
6 | - This is merely a conversion based on the **references** below, plus some trial and error
7 | - Albert zh for [pytorch-transformers](https://github.com/huggingface/transformers)
8 | - Tested to support Traditional Chinese
9 | 
10 | ## Available models
11 | - [albert_tiny_zh](https://github.com/p208p2002/albert-zh-for-pytorch-transformers/releases/download/am_v1.1/albert_tiny.zip)
12 | - [albert_base_zh](https://github.com/p208p2002/albert-zh-for-pytorch-transformers/releases/download/am_v1.1/albert_base.zip)
13 | - [albert_large_zh](https://github.com/p208p2002/albert-zh-for-pytorch-transformers/releases/download/am_v1.1/albert_large.zip)
14 | - [albert_xlarge_zh](https://github.com/p208p2002/albert-zh-for-pytorch-transformers/releases/download/am_v1.1/albert_xlarge.zip)
15 | 
16 | ## API
17 | First place the `albert_zh` directory from this repo under your project, then import from it:
18 | 
19 | `from albert_zh import ...`
20 | ```
21 | AlbertConfig
22 | AlbertTokenizer
23 | AlbertModel
24 | AlbertForMaskedLM
25 | AlbertForQuestionAnswering
26 | AlbertForSequenceClassification
27 | ```
28 | > https://huggingface.co/transformers/v2.3.0/model_doc/albert.html
29 | 
30 | ## Usage
31 | - See `usage_example.py` (a minimal loading sketch is also appended at the end of this README)
32 | > Or see [p208p2002/taipei-QA-BERT](https://github.com/p208p2002/taipei-QA-BERT) for a real-world usage example
33 | - Tested to run correctly with transformers 2.3.0
34 | 
35 | ## FAQ
36 | #### I want to import this in Jupyter or Colab but run into problems
37 | The repo name does not follow Python module naming conventions, and Jupyter has poor support for custom modules; see the workaround below. Publishing to PyPI may be considered later.
38 | ```jupyter
39 | # This snippet only applies to Jupyter / Colab
40 | !git clone https://github.com/p208p2002/albert-zh-for-pytorch-transformers.git albert
41 | import sys
42 | sys.path.append('.')
43 | from albert.albert_zh import AlbertConfig, AlbertTokenizer, AlbertForSequenceClassification
44 | ```
45 | #### The loss will not go down and the trained model is garbage
46 | Make sure the model class and model config are imported from albert_zh, not from transformers
47 | > https://github.com/lonePatient/albert_pytorch/issues/35
48 | 
49 | #### AttributeError: 'BertConfig' object has no attribute 'share_type'
50 | Add `"share_type":"all"` to config.json
51 | 
52 | #### The model prints random things during training
53 | Use `log()` instead of `print()`, and call `blockPrint()` once at the start of the program
54 | ```python
55 | import os,sys
56 | def log(*logs):
57 |     enablePrint()
58 |     print(*logs)
59 |     blockPrint()
60 | 
61 | # Disable
62 | def blockPrint():
63 |     sys.stdout = open(os.devnull, 'w')
64 | 
65 | # Restore
66 | def enablePrint():
67 |     sys.stdout = sys.__stdout__
68 | ```
69 | 
70 | ## Test environment
71 | - python 3.6.4
72 | - pytorch 1.3 (with cuda 10)
73 | - transformers 2.3.0
74 | 
75 | ## References
76 | ### albert zh
77 | - https://github.com/brightmart/albert_zh
78 | ### albert tf to pytorch
79 | - https://github.com/lonePatient/albert_pytorch
80 | 
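Below is a minimal loading sketch of the workflow described in the Usage section above. It is an illustration, not a copy of `usage_example.py`: it assumes the unzipped `albert_tiny` folder (with `config.json`, `pytorch_model.bin` and `vocab.txt`) and the `albert_zh` package both sit next to the script, that `AlbertForSequenceClassification` follows the usual BERT-style forward signature, and that the sample sentence and `num_labels=2` are placeholder values.

```python
import torch
from albert_zh import AlbertConfig, AlbertTokenizer, AlbertForSequenceClassification

# Config, tokenizer and model classes must all come from albert_zh,
# not from transformers (see the FAQ above).
config = AlbertConfig.from_pretrained('./albert_tiny', num_labels=2)  # num_labels is a placeholder
tokenizer = AlbertTokenizer.from_pretrained('./albert_tiny')
model = AlbertForSequenceClassification.from_pretrained('./albert_tiny', config=config)
model.eval()

# Encode a placeholder sentence and run a single forward pass.
input_ids = torch.tensor([tokenizer.encode('今天天氣很好', add_special_tokens=True)])
with torch.no_grad():
    logits = model(input_ids)[0]  # shape: (batch_size, num_labels)
print(logits)
```

For fine-tuning (as in the taipei-QA-BERT example linked above), the same objects are typically reused; passing a `labels` tensor to the model additionally returns a classification loss as the first output.

--------------------------------------------------------------------------------
/albert_tiny/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "albert_model.ckpt"
2 | 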
all_model_checkpoint_paths: "albert_model.ckpt" 3 | -------------------------------------------------------------------------------- /albert_tiny/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.0, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.0, 6 | "hidden_size": 312, 7 | "embedding_size": 128, 8 | "initializer_range": 0.02, 9 | "intermediate_size": 1248 , 10 | "max_position_embeddings": 512, 11 | "num_attention_heads": 12, 12 | "num_hidden_layers": 4, 13 | 14 | "pooler_fc_size": 768, 15 | "pooler_num_attention_heads": 12, 16 | "pooler_num_fc_layers": 3, 17 | "pooler_size_per_head": 128, 18 | "pooler_type": "first_token_transform", 19 | "type_vocab_size": 2, 20 | "vocab_size": 21128, 21 | "ln_type":"postln", 22 | "share_type":"all" 23 | } 24 | -------------------------------------------------------------------------------- /albert_tiny/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p208p2002/albert-zh-for-pytorch-transformers/9eceaf9796188b0a16971228bf41ad27c4a458f8/albert_tiny/pytorch_model.bin -------------------------------------------------------------------------------- /albert_zh/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_albert import AlbertModel,AlbertForMaskedLM,AlbertForQuestionAnswering,AlbertForSequenceClassification 2 | from .configuration_bert import BertConfig as AlbertConfig 3 | from .tokenization_bert import BertTokenizer as AlbertTokenizer -------------------------------------------------------------------------------- /albert_zh/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ BERT model configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 37 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 38 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 39 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 40 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 41 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 42 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 43 | } 44 | 45 | 46 | class BertConfig(PretrainedConfig): 47 | r""" 48 | :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a 49 | `BertModel`. 50 | 51 | 52 | Arguments: 53 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 54 | hidden_size: Size of the encoder layers and the pooler layer. 55 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 56 | num_attention_heads: Number of attention heads for each attention layer in 57 | the Transformer encoder. 58 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 59 | layer in the Transformer encoder. 60 | hidden_act: The non-linear activation function (function or string) in the 61 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 62 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 63 | layers in the embeddings, encoder, and pooler. 64 | attention_probs_dropout_prob: The dropout ratio for the attention 65 | probabilities. 66 | max_position_embeddings: The maximum sequence length that this model might 67 | ever be used with. Typically set this to something large just in case 68 | (e.g., 512 or 1024 or 2048). 69 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 70 | `BertModel`. 71 | initializer_range: The sttdev of the truncated_normal_initializer for 72 | initializing all weight matrices. 
73 | layer_norm_eps: The epsilon used by LayerNorm. 74 | """ 75 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 76 | 77 | def __init__(self, 78 | vocab_size_or_config_json_file=30522, 79 | hidden_size=768, 80 | num_hidden_layers=12, 81 | num_attention_heads=12, 82 | intermediate_size=3072, 83 | hidden_act="gelu", 84 | hidden_dropout_prob=0.1, 85 | attention_probs_dropout_prob=0.1, 86 | max_position_embeddings=512, 87 | type_vocab_size=2, 88 | initializer_range=0.02, 89 | layer_norm_eps=1e-12, 90 | **kwargs): 91 | super(BertConfig, self).__init__(**kwargs) 92 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 93 | and isinstance(vocab_size_or_config_json_file, unicode)): 94 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 95 | json_config = json.loads(reader.read()) 96 | for key, value in json_config.items(): 97 | self.__dict__[key] = value 98 | elif isinstance(vocab_size_or_config_json_file, int): 99 | self.vocab_size = vocab_size_or_config_json_file 100 | self.hidden_size = hidden_size 101 | self.num_hidden_layers = num_hidden_layers 102 | self.num_attention_heads = num_attention_heads 103 | self.hidden_act = hidden_act 104 | self.intermediate_size = intermediate_size 105 | self.hidden_dropout_prob = hidden_dropout_prob 106 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 107 | self.max_position_embeddings = max_position_embeddings 108 | self.type_vocab_size = type_vocab_size 109 | self.initializer_range = initializer_range 110 | self.layer_norm_eps = layer_norm_eps 111 | else: 112 | raise ValueError("First argument must be either a vocabulary size (int)" 113 | " or the path to a pretrained model config file (str)") 114 | -------------------------------------------------------------------------------- /albert_zh/configuration_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Configuration base class and utilities.""" 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | import copy 21 | import json 22 | import logging 23 | import os 24 | from io import open 25 | 26 | from .file_utils import cached_path, CONFIG_NAME 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | class PretrainedConfig(object): 31 | r""" Base class for all configuration classes. 32 | Handles a few parameters tools to all models' configurations as well as methods for loading/downloading/saving configurations. 33 | 34 | Note: 35 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights. 36 | It only affects the model's configuration. 
37 | 38 | Class attributes (overridden by derived classes): 39 | - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values. 40 | 41 | Parameters: 42 | ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. 43 | ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens) 44 | ``output_attentions``: boolean, default `False`. Should the model returns attentions weights. 45 | ``output_hidden_states``: string, default `False`. Should the model returns all hidden-states. 46 | ``torchscript``: string, default `False`. Is the model used with Torchscript. 47 | """ 48 | pretrained_config_archive_map = {} 49 | 50 | def __init__(self, **kwargs): 51 | self.finetuning_task = kwargs.pop('finetuning_task', None) 52 | self.num_labels = kwargs.pop('num_labels', 2) 53 | self.output_attentions = kwargs.pop('output_attentions', False) 54 | self.output_hidden_states = kwargs.pop('output_hidden_states', False) 55 | self.torchscript = kwargs.pop('torchscript', False) 56 | self.pruned_heads = kwargs.pop('pruned_heads', {}) 57 | 58 | def save_pretrained(self, save_directory): 59 | """ Save a configuration object to the directory `save_directory`, so that it 60 | can be re-loaded using the :func:`~pytorch_transformers.PretrainedConfig.from_pretrained` class method. 61 | """ 62 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" 63 | 64 | # If we save using the predefined names, we can load using `from_pretrained` 65 | output_config_file = os.path.join(save_directory, CONFIG_NAME) 66 | 67 | self.to_json_file(output_config_file) 68 | 69 | @classmethod 70 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 71 | r""" Instantiate a :class:`~pytorch_transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. 72 | 73 | Parameters: 74 | pretrained_model_name_or_path: either: 75 | 76 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. 77 | - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. 78 | - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. 79 | 80 | cache_dir: (`optional`) string: 81 | Path to a directory in which a downloaded pre-trained model 82 | configuration should be cached if the standard cache should not be used. 83 | 84 | kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. 85 | 86 | - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. 87 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. 88 | 89 | force_download: (`optional`) boolean, default False: 90 | Force to (re-)download the model weights and configuration files and override the cached versions if they exists. 
91 | 92 | proxies: (`optional`) dict, default None: 93 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 94 | The proxies are used on each request. 95 | 96 | return_unused_kwargs: (`optional`) bool: 97 | 98 | - If False, then this function returns just the final configuration object. 99 | - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored. 100 | 101 | Examples:: 102 | 103 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a 104 | # derived class: BertConfig 105 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. 106 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')` 107 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') 108 | config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) 109 | assert config.output_attention == True 110 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, 111 | foo=False, return_unused_kwargs=True) 112 | assert config.output_attention == True 113 | assert unused_kwargs == {'foo': False} 114 | 115 | """ 116 | cache_dir = kwargs.pop('cache_dir', None) 117 | force_download = kwargs.pop('force_download', False) 118 | proxies = kwargs.pop('proxies', None) 119 | return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) 120 | 121 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map: 122 | config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] 123 | elif os.path.isdir(pretrained_model_name_or_path): 124 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 125 | else: 126 | config_file = pretrained_model_name_or_path 127 | # redirect to the cache, if necessary 128 | try: 129 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 130 | except EnvironmentError as e: 131 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map: 132 | logger.error( 133 | "Couldn't reach server at '{}' to download pretrained model configuration file.".format( 134 | config_file)) 135 | else: 136 | logger.error( 137 | "Model name '{}' was not found in model name list ({}). 
" 138 | "We assumed '{}' was a path or url but couldn't find any file " 139 | "associated to this path or url.".format( 140 | pretrained_model_name_or_path, 141 | ', '.join(cls.pretrained_config_archive_map.keys()), 142 | config_file)) 143 | raise e 144 | if resolved_config_file == config_file: 145 | logger.info("loading configuration file {}".format(config_file)) 146 | else: 147 | logger.info("loading configuration file {} from cache at {}".format( 148 | config_file, resolved_config_file)) 149 | 150 | # Load config 151 | config = cls.from_json_file(resolved_config_file) 152 | 153 | if hasattr(config, 'pruned_heads'): 154 | config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items()) 155 | 156 | # Update config with kwargs if needed 157 | to_remove = [] 158 | for key, value in kwargs.items(): 159 | if hasattr(config, key): 160 | setattr(config, key, value) 161 | to_remove.append(key) 162 | else: 163 | setattr(config,key,value) 164 | for key in to_remove: 165 | kwargs.pop(key, None) 166 | 167 | logger.info("Model config %s", config) 168 | if return_unused_kwargs: 169 | return config, kwargs 170 | else: 171 | return config 172 | 173 | @classmethod 174 | def from_dict(cls, json_object): 175 | """Constructs a `Config` from a Python dictionary of parameters.""" 176 | config = cls(vocab_size_or_config_json_file=-1) 177 | for key, value in json_object.items(): 178 | config.__dict__[key] = value 179 | return config 180 | 181 | @classmethod 182 | def from_json_file(cls, json_file): 183 | """Constructs a `BertConfig` from a json file of parameters.""" 184 | with open(json_file, "r", encoding='utf-8') as reader: 185 | text = reader.read() 186 | return cls.from_dict(json.loads(text)) 187 | 188 | def __eq__(self, other): 189 | return self.__dict__ == other.__dict__ 190 | 191 | def __repr__(self): 192 | return str(self.to_json_string()) 193 | 194 | def to_dict(self): 195 | """Serializes this instance to a Python dictionary.""" 196 | output = copy.deepcopy(self.__dict__) 197 | return output 198 | 199 | def to_json_string(self): 200 | """Serializes this instance to a JSON string.""" 201 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 202 | 203 | def to_json_file(self, json_file_path): 204 | """ Save this instance to a json file.""" 205 | with open(json_file_path, "w", encoding='utf-8') as writer: 206 | writer.write(self.to_json_string()) 207 | -------------------------------------------------------------------------------- /albert_zh/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import six 13 | import shutil 14 | import tempfile 15 | import fnmatch 16 | from functools import wraps 17 | from hashlib import sha256 18 | from io import open 19 | 20 | import boto3 21 | from botocore.config import Config 22 | from botocore.exceptions import ClientError 23 | import requests 24 | from tqdm import tqdm 25 | 26 | try: 27 | from torch.hub import _get_torch_home 28 | torch_cache_home = _get_torch_home() 29 | except ImportError: 30 | torch_cache_home = os.path.expanduser( 31 | os.getenv('TORCH_HOME', os.path.join( 32 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 33 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers') 34 | 35 | try: 36 | from urllib.parse import urlparse 37 | except ImportError: 38 | from urlparse import urlparse 39 | 40 | try: 41 | from pathlib import Path 42 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 43 | os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))) 44 | except (AttributeError, ImportError): 45 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE', 46 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 47 | default_cache_path)) 48 | 49 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 50 | 51 | WEIGHTS_NAME = "pytorch_model.bin" 52 | TF_WEIGHTS_NAME = 'model.ckpt' 53 | CONFIG_NAME = "config.json" 54 | 55 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 56 | 57 | if not six.PY2: 58 | def add_start_docstrings(*docstr): 59 | def docstring_decorator(fn): 60 | fn.__doc__ = ''.join(docstr) + fn.__doc__ 61 | return fn 62 | return docstring_decorator 63 | 64 | def add_end_docstrings(*docstr): 65 | def docstring_decorator(fn): 66 | fn.__doc__ = fn.__doc__ + ''.join(docstr) 67 | return fn 68 | return docstring_decorator 69 | else: 70 | # Not possible to update class docstrings on python2 71 | def add_start_docstrings(*docstr): 72 | def docstring_decorator(fn): 73 | return fn 74 | return docstring_decorator 75 | 76 | def add_end_docstrings(*docstr): 77 | def docstring_decorator(fn): 78 | return fn 79 | return docstring_decorator 80 | 81 | def url_to_filename(url, etag=None): 82 | """ 83 | Convert `url` into a hashed filename in a repeatable way. 84 | If `etag` is specified, append its hash to the url's, delimited 85 | by a period. 86 | """ 87 | url_bytes = url.encode('utf-8') 88 | url_hash = sha256(url_bytes) 89 | filename = url_hash.hexdigest() 90 | 91 | if etag: 92 | etag_bytes = etag.encode('utf-8') 93 | etag_hash = sha256(etag_bytes) 94 | filename += '.' + etag_hash.hexdigest() 95 | 96 | return filename 97 | 98 | 99 | def filename_to_url(filename, cache_dir=None): 100 | """ 101 | Return the url and etag (which may be ``None``) stored for `filename`. 102 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
103 | """ 104 | if cache_dir is None: 105 | cache_dir = PYTORCH_TRANSFORMERS_CACHE 106 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 107 | cache_dir = str(cache_dir) 108 | 109 | cache_path = os.path.join(cache_dir, filename) 110 | if not os.path.exists(cache_path): 111 | raise EnvironmentError("file {} not found".format(cache_path)) 112 | 113 | meta_path = cache_path + '.json' 114 | if not os.path.exists(meta_path): 115 | raise EnvironmentError("file {} not found".format(meta_path)) 116 | 117 | with open(meta_path, encoding="utf-8") as meta_file: 118 | metadata = json.load(meta_file) 119 | url = metadata['url'] 120 | etag = metadata['etag'] 121 | 122 | return url, etag 123 | 124 | 125 | def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None): 126 | """ 127 | Given something that might be a URL (or might be a local path), 128 | determine which. If it's a URL, download the file and cache it, and 129 | return the path to the cached file. If it's already a local path, 130 | make sure the file exists and then return the path. 131 | Args: 132 | cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). 133 | force_download: if True, re-dowload the file even if it's already cached in the cache dir. 134 | """ 135 | if cache_dir is None: 136 | cache_dir = PYTORCH_TRANSFORMERS_CACHE 137 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 138 | url_or_filename = str(url_or_filename) 139 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 140 | cache_dir = str(cache_dir) 141 | 142 | parsed = urlparse(url_or_filename) 143 | 144 | if parsed.scheme in ('http', 'https', 's3'): 145 | # URL, so get it from the cache (downloading if necessary) 146 | return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 147 | elif os.path.exists(url_or_filename): 148 | # File, and it exists. 149 | return url_or_filename 150 | elif parsed.scheme == '': 151 | # File, but it doesn't exist. 152 | raise EnvironmentError("file {} not found".format(url_or_filename)) 153 | else: 154 | # Something unknown 155 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 156 | 157 | 158 | def split_s3_path(url): 159 | """Split a full s3 path into the bucket name and path.""" 160 | parsed = urlparse(url) 161 | if not parsed.netloc or not parsed.path: 162 | raise ValueError("bad s3 path {}".format(url)) 163 | bucket_name = parsed.netloc 164 | s3_path = parsed.path 165 | # Remove '/' at beginning of path. 166 | if s3_path.startswith("/"): 167 | s3_path = s3_path[1:] 168 | return bucket_name, s3_path 169 | 170 | 171 | def s3_request(func): 172 | """ 173 | Wrapper function for s3 requests in order to create more helpful error 174 | messages. 
175 | """ 176 | 177 | @wraps(func) 178 | def wrapper(url, *args, **kwargs): 179 | try: 180 | return func(url, *args, **kwargs) 181 | except ClientError as exc: 182 | if int(exc.response["Error"]["Code"]) == 404: 183 | raise EnvironmentError("file {} not found".format(url)) 184 | else: 185 | raise 186 | 187 | return wrapper 188 | 189 | 190 | @s3_request 191 | def s3_etag(url, proxies=None): 192 | """Check ETag on S3 object.""" 193 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) 194 | bucket_name, s3_path = split_s3_path(url) 195 | s3_object = s3_resource.Object(bucket_name, s3_path) 196 | return s3_object.e_tag 197 | 198 | 199 | @s3_request 200 | def s3_get(url, temp_file, proxies=None): 201 | """Pull a file directly from S3.""" 202 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) 203 | bucket_name, s3_path = split_s3_path(url) 204 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 205 | 206 | 207 | def http_get(url, temp_file, proxies=None): 208 | req = requests.get(url, stream=True, proxies=proxies) 209 | content_length = req.headers.get('Content-Length') 210 | total = int(content_length) if content_length is not None else None 211 | progress = tqdm(unit="B", total=total) 212 | for chunk in req.iter_content(chunk_size=1024): 213 | if chunk: # filter out keep-alive new chunks 214 | progress.update(len(chunk)) 215 | temp_file.write(chunk) 216 | progress.close() 217 | 218 | 219 | def get_from_cache(url, cache_dir=None, force_download=False, proxies=None): 220 | """ 221 | Given a URL, look for the corresponding dataset in the local cache. 222 | If it's not there, download it. Then return the path to the cached file. 223 | """ 224 | if cache_dir is None: 225 | cache_dir = PYTORCH_TRANSFORMERS_CACHE 226 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 227 | cache_dir = str(cache_dir) 228 | if sys.version_info[0] == 2 and not isinstance(cache_dir, str): 229 | cache_dir = str(cache_dir) 230 | 231 | if not os.path.exists(cache_dir): 232 | os.makedirs(cache_dir) 233 | 234 | # Get eTag to add to filename, if it exists. 235 | if url.startswith("s3://"): 236 | etag = s3_etag(url, proxies=proxies) 237 | else: 238 | try: 239 | response = requests.head(url, allow_redirects=True, proxies=proxies) 240 | if response.status_code != 200: 241 | etag = None 242 | else: 243 | etag = response.headers.get("ETag") 244 | except EnvironmentError: 245 | etag = None 246 | 247 | if sys.version_info[0] == 2 and etag is not None: 248 | etag = etag.decode('utf-8') 249 | filename = url_to_filename(url, etag) 250 | 251 | # get cache path to put the file 252 | cache_path = os.path.join(cache_dir, filename) 253 | 254 | # If we don't have a connection (etag is None) and can't identify the file 255 | # try to get the last downloaded one 256 | if not os.path.exists(cache_path) and etag is None: 257 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 258 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 259 | if matching_files: 260 | cache_path = os.path.join(cache_dir, matching_files[-1]) 261 | 262 | if not os.path.exists(cache_path) or force_download: 263 | # Download to temporary file, then copy to cache dir once finished. 264 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
265 | with tempfile.NamedTemporaryFile() as temp_file: 266 | logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) 267 | 268 | # GET file object 269 | if url.startswith("s3://"): 270 | s3_get(url, temp_file, proxies=proxies) 271 | else: 272 | http_get(url, temp_file, proxies=proxies) 273 | 274 | # we are copying the file before closing it, so flush to avoid truncation 275 | temp_file.flush() 276 | # shutil.copyfileobj() starts at the current position, so go to the start 277 | temp_file.seek(0) 278 | 279 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 280 | with open(cache_path, 'wb') as cache_file: 281 | shutil.copyfileobj(temp_file, cache_file) 282 | 283 | logger.info("creating metadata file for %s", cache_path) 284 | meta = {'url': url, 'etag': etag} 285 | meta_path = cache_path + '.json' 286 | with open(meta_path, 'w') as meta_file: 287 | output_string = json.dumps(meta) 288 | if sys.version_info[0] == 2 and isinstance(output_string, str): 289 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 290 | meta_file.write(output_string) 291 | 292 | logger.info("removing temp file %s", temp_file.name) 293 | 294 | return cache_path 295 | -------------------------------------------------------------------------------- /albert_zh/modeling_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import copy 22 | import json 23 | import logging 24 | import os 25 | from io import open 26 | 27 | import six 28 | import torch 29 | from torch import nn 30 | from torch.nn import CrossEntropyLoss 31 | from torch.nn import functional as F 32 | 33 | from .configuration_utils import PretrainedConfig 34 | from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | try: 40 | from torch.nn import Identity 41 | except ImportError: 42 | # Older PyTorch compatibility 43 | class Identity(nn.Module): 44 | r"""A placeholder identity operator that is argument-insensitive. 45 | """ 46 | def __init__(self, *args, **kwargs): 47 | super(Identity, self).__init__() 48 | 49 | def forward(self, input): 50 | return input 51 | 52 | class PreTrainedModel(nn.Module): 53 | r""" Base class for all models. 54 | 55 | :class:`~pytorch_transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models 56 | as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads. 
57 | 58 | Class attributes (overridden by derived classes): 59 | - ``config_class``: a class derived from :class:`~pytorch_transformers.PretrainedConfig` to use as configuration class for this model architecture. 60 | - ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values. 61 | - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments: 62 | 63 | - ``model``: an instance of the relevant subclass of :class:`~pytorch_transformers.PreTrainedModel`, 64 | - ``config``: an instance of the relevant subclass of :class:`~pytorch_transformers.PretrainedConfig`, 65 | - ``path``: a path (string) to the TensorFlow checkpoint. 66 | 67 | - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. 68 | """ 69 | config_class = None 70 | pretrained_model_archive_map = {} 71 | load_tf_weights = lambda model, config, path: None 72 | base_model_prefix = "" 73 | 74 | def __init__(self, config, *inputs, **kwargs): 75 | super(PreTrainedModel, self).__init__() 76 | if not isinstance(config, PretrainedConfig): 77 | raise ValueError( 78 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " 79 | "To create a model from a pretrained model use " 80 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 81 | self.__class__.__name__, self.__class__.__name__ 82 | )) 83 | # Save config in model 84 | self.config = config 85 | 86 | def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): 87 | """ Build a resized Embedding Module from a provided token Embedding Module. 88 | Increasing the size will add newly initialized vectors at the end 89 | Reducing the size will remove vectors from the end 90 | 91 | Args: 92 | new_num_tokens: (`optional`) int 93 | New number of tokens in the embedding matrix. 94 | Increasing the size will add newly initialized vectors at the end 95 | Reducing the size will remove vectors from the end 96 | If not provided or None: return the provided token Embedding Module. 
97 | Return: ``torch.nn.Embeddings`` 98 | Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None 99 | """ 100 | if new_num_tokens is None: 101 | return old_embeddings 102 | 103 | old_num_tokens, old_embedding_dim = old_embeddings.weight.size() 104 | if old_num_tokens == new_num_tokens: 105 | return old_embeddings 106 | 107 | # Build new embeddings 108 | new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) 109 | new_embeddings.to(old_embeddings.weight.device) 110 | 111 | # initialize all new embeddings (in particular added tokens) 112 | self._init_weights(new_embeddings) 113 | 114 | # Copy word embeddings from the previous weights 115 | num_tokens_to_copy = min(old_num_tokens, new_num_tokens) 116 | new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :] 117 | 118 | return new_embeddings 119 | 120 | def _tie_or_clone_weights(self, first_module, second_module): 121 | """ Tie or clone module weights depending of weither we are using TorchScript or not 122 | """ 123 | 124 | if self.config.torchscript: 125 | first_module.weight = nn.Parameter(second_module.weight.clone()) 126 | else: 127 | first_module.weight = second_module.weight 128 | 129 | 130 | if hasattr(first_module, 'bias') and first_module.bias is not None: 131 | first_module.bias.data = torch.nn.functional.pad( 132 | first_module.bias.data, 133 | (0, first_module.weight.shape[0] - first_module.bias.shape[0]), 134 | 'constant', 135 | 0 136 | ) 137 | 138 | def resize_token_embeddings(self, new_num_tokens=None): 139 | """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. 140 | Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. 141 | 142 | Arguments: 143 | 144 | new_num_tokens: (`optional`) int: 145 | New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. 146 | If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model. 147 | 148 | Return: ``torch.nn.Embeddings`` 149 | Pointer to the input tokens Embeddings Module of the model 150 | """ 151 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed 152 | model_embeds = base_model._resize_token_embeddings(new_num_tokens) 153 | if new_num_tokens is None: 154 | return model_embeds 155 | 156 | # Update base model and current model config 157 | self.config.vocab_size = new_num_tokens 158 | base_model.vocab_size = new_num_tokens 159 | 160 | # Tie weights again if needed 161 | if hasattr(self, 'tie_weights'): 162 | self.tie_weights() 163 | 164 | return model_embeds 165 | 166 | def init_weights(self): 167 | """ Initialize and prunes weights if needed. """ 168 | # Initialize weights 169 | self.apply(self._init_weights) 170 | 171 | # Prune heads if needed 172 | if self.config.pruned_heads: 173 | self.prune_heads(self.config.pruned_heads) 174 | 175 | def prune_heads(self, heads_to_prune): 176 | """ Prunes heads of the base model. 177 | 178 | Arguments: 179 | 180 | heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`). 181 | E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. 
182 | """ 183 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed 184 | 185 | # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads 186 | for layer, heads in heads_to_prune.items(): 187 | union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) 188 | self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON 189 | 190 | base_model._prune_heads(heads_to_prune) 191 | 192 | def save_pretrained(self, save_directory): 193 | """ Save a model and its configuration file to a directory, so that it 194 | can be re-loaded using the `:func:`~pytorch_transformers.PreTrainedModel.from_pretrained`` class method. 195 | """ 196 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" 197 | 198 | # Only save the model it-self if we are using distributed training 199 | model_to_save = self.module if hasattr(self, 'module') else self 200 | 201 | # Save configuration file 202 | model_to_save.config.save_pretrained(save_directory) 203 | 204 | # If we save using the predefined names, we can load using `from_pretrained` 205 | output_model_file = os.path.join(save_directory, WEIGHTS_NAME) 206 | 207 | torch.save(model_to_save.state_dict(), output_model_file) 208 | 209 | @classmethod 210 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 211 | r"""Instantiate a pretrained pytorch model from a pre-trained model configuration. 212 | 213 | The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated) 214 | To train the model, you should first set it back in training mode with ``model.train()`` 215 | 216 | The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model. 217 | It is up to you to train those weights with a downstream fine-tuning task. 218 | 219 | The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded. 220 | 221 | Parameters: 222 | pretrained_model_name_or_path: either: 223 | 224 | - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. 225 | - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. 226 | - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. 227 | 228 | model_args: (`optional`) Sequence of positional arguments: 229 | All remaning positional arguments will be passed to the underlying model's ``__init__`` method 230 | 231 | config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`: 232 | Configuration for the model to use instead of an automatically loaded configuation. 
Configuration can be automatically loaded when: 233 | 234 | - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or 235 | - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. 236 | - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. 237 | 238 | state_dict: (`optional`) dict: 239 | an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file. 240 | This option can be used if you want to create a model from a pretrained configuration but load your own weights. 241 | In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option. 242 | 243 | cache_dir: (`optional`) string: 244 | Path to a directory in which a downloaded pre-trained model 245 | configuration should be cached if the standard cache should not be used. 246 | 247 | force_download: (`optional`) boolean, default False: 248 | Force to (re-)download the model weights and configuration files and override the cached versions if they exists. 249 | 250 | proxies: (`optional`) dict, default None: 251 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 252 | The proxies are used on each request. 253 | 254 | output_loading_info: (`optional`) boolean: 255 | Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. 256 | 257 | kwargs: (`optional`) Remaining dictionary of keyword arguments: 258 | Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: 259 | 260 | - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) 261 | - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. 262 | 263 | Examples:: 264 | 265 | model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. 266 | model = BertModel.from_pretrained('./test/saved_model/') # E.g. 
model was saved using `save_pretrained('./test/saved_model/')` 267 | model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading 268 | assert model.config.output_attention == True 269 | # Loading from a TF checkpoint file instead of a PyTorch model (slower) 270 | config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') 271 | model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) 272 | 273 | """ 274 | config = kwargs.pop('config', None) 275 | state_dict = kwargs.pop('state_dict', None) 276 | cache_dir = kwargs.pop('cache_dir', None) 277 | from_tf = kwargs.pop('from_tf', False) 278 | force_download = kwargs.pop('force_download', False) 279 | proxies = kwargs.pop('proxies', None) 280 | output_loading_info = kwargs.pop('output_loading_info', False) 281 | 282 | # Load config 283 | if config is None: 284 | config, model_kwargs = cls.config_class.from_pretrained( 285 | pretrained_model_name_or_path, *model_args, 286 | cache_dir=cache_dir, return_unused_kwargs=True, 287 | force_download=force_download, 288 | **kwargs 289 | ) 290 | else: 291 | model_kwargs = kwargs 292 | 293 | # Load model 294 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map: 295 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] 296 | elif os.path.isdir(pretrained_model_name_or_path): 297 | if from_tf: 298 | # Directly load from a TensorFlow checkpoint 299 | archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") 300 | else: 301 | archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) 302 | else: 303 | if from_tf: 304 | # Directly load from a TensorFlow checkpoint 305 | archive_file = pretrained_model_name_or_path + ".index" 306 | else: 307 | archive_file = pretrained_model_name_or_path 308 | # redirect to the cache, if necessary 309 | try: 310 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 311 | except EnvironmentError as e: 312 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map: 313 | logger.error( 314 | "Couldn't reach server at '{}' to download pretrained weights.".format( 315 | archive_file)) 316 | else: 317 | logger.error( 318 | "Model name '{}' was not found in model name list ({}). " 319 | "We assumed '{}' was a path or url but couldn't find any file " 320 | "associated to this path or url.".format( 321 | pretrained_model_name_or_path, 322 | ', '.join(cls.pretrained_model_archive_map.keys()), 323 | archive_file)) 324 | raise e 325 | if resolved_archive_file == archive_file: 326 | logger.info("loading weights file {}".format(archive_file)) 327 | else: 328 | logger.info("loading weights file {} from cache at {}".format( 329 | archive_file, resolved_archive_file)) 330 | 331 | # Instantiate model. 
332 | model = cls(config, *model_args, **model_kwargs) 333 | 334 | if state_dict is None and not from_tf: 335 | state_dict = torch.load(resolved_archive_file, map_location='cpu') 336 | if from_tf: 337 | # Directly load from a TensorFlow checkpoint 338 | return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' 339 | 340 | # Convert old format to new format if needed from a PyTorch state_dict 341 | old_keys = [] 342 | new_keys = [] 343 | for key in state_dict.keys(): 344 | new_key = None 345 | if 'gamma' in key: 346 | new_key = key.replace('gamma', 'weight') 347 | if 'beta' in key: 348 | new_key = key.replace('beta', 'bias') 349 | if new_key: 350 | old_keys.append(key) 351 | new_keys.append(new_key) 352 | for old_key, new_key in zip(old_keys, new_keys): 353 | state_dict[new_key] = state_dict.pop(old_key) 354 | 355 | # Load from a PyTorch state_dict 356 | missing_keys = [] 357 | unexpected_keys = [] 358 | error_msgs = [] 359 | # copy state_dict so _load_from_state_dict can modify it 360 | metadata = getattr(state_dict, '_metadata', None) 361 | state_dict = state_dict.copy() 362 | if metadata is not None: 363 | state_dict._metadata = metadata 364 | 365 | def load(module, prefix=''): 366 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 367 | module._load_from_state_dict( 368 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 369 | for name, child in module._modules.items(): 370 | if child is not None: 371 | load(child, prefix + name + '.') 372 | 373 | # Make sure we are able to load base models as well as derived models (with heads) 374 | start_prefix = '' 375 | model_to_load = model 376 | if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): 377 | start_prefix = cls.base_model_prefix + '.' 378 | if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): 379 | model_to_load = getattr(model, cls.base_model_prefix) 380 | 381 | load(model_to_load, prefix=start_prefix) 382 | if len(missing_keys) > 0: 383 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 384 | model.__class__.__name__, missing_keys)) 385 | if len(unexpected_keys) > 0: 386 | logger.info("Weights from pretrained model not used in {}: {}".format( 387 | model.__class__.__name__, unexpected_keys)) 388 | if len(error_msgs) > 0: 389 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 390 | model.__class__.__name__, "\n\t".join(error_msgs))) 391 | 392 | if hasattr(model, 'tie_weights'): 393 | model.tie_weights() # make sure word embedding weights are still tied 394 | 395 | # Set model in evaluation mode to desactivate DropOut modules by default 396 | model.eval() 397 | 398 | if output_loading_info: 399 | loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} 400 | return model, loading_info 401 | 402 | return model 403 | 404 | 405 | class Conv1D(nn.Module): 406 | def __init__(self, nf, nx): 407 | """ Conv1D layer as defined by Radford et al. 
for OpenAI GPT (and also used in GPT-2) 408 | Basically works like a Linear layer but the weights are transposed 409 | """ 410 | super(Conv1D, self).__init__() 411 | self.nf = nf 412 | w = torch.empty(nx, nf) 413 | nn.init.normal_(w, std=0.02) 414 | self.weight = nn.Parameter(w) 415 | self.bias = nn.Parameter(torch.zeros(nf)) 416 | 417 | def forward(self, x): 418 | size_out = x.size()[:-1] + (self.nf,) 419 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 420 | x = x.view(*size_out) 421 | return x 422 | 423 | 424 | class PoolerStartLogits(nn.Module): 425 | """ Compute SQuAD start_logits from sequence hidden states. """ 426 | def __init__(self, config): 427 | super(PoolerStartLogits, self).__init__() 428 | self.dense = nn.Linear(config.hidden_size, 1) 429 | 430 | def forward(self, hidden_states, p_mask=None): 431 | """ Args: 432 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)` 433 | invalid position mask such as query and special symbols (PAD, SEP, CLS) 434 | 1.0 means token should be masked. 435 | """ 436 | x = self.dense(hidden_states).squeeze(-1) 437 | 438 | if p_mask is not None: 439 | x = x * (1 - p_mask) - 1e30 * p_mask 440 | 441 | return x 442 | 443 | 444 | class PoolerEndLogits(nn.Module): 445 | """ Compute SQuAD end_logits from sequence hidden states and start token hidden state. 446 | """ 447 | def __init__(self, config): 448 | super(PoolerEndLogits, self).__init__() 449 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) 450 | self.activation = nn.Tanh() 451 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 452 | self.dense_1 = nn.Linear(config.hidden_size, 1) 453 | 454 | def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None): 455 | """ Args: 456 | One of ``start_states``, ``start_positions`` should be not None. 457 | If both are set, ``start_positions`` overrides ``start_states``. 458 | 459 | **start_states**: ``torch.LongTensor`` of shape identical to hidden_states 460 | hidden states of the first tokens for the labeled span. 461 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 462 | position of the first token for the labeled span: 463 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` 464 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS) 465 | 1.0 means token should be masked. 466 | """ 467 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" 468 | if start_positions is not None: 469 | slen, hsz = hidden_states.shape[-2:] 470 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 471 | start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) 472 | start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) 473 | 474 | x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) 475 | x = self.activation(x) 476 | x = self.LayerNorm(x) 477 | x = self.dense_1(x).squeeze(-1) 478 | 479 | if p_mask is not None: 480 | x = x * (1 - p_mask) - 1e30 * p_mask 481 | 482 | return x 483 | 484 | 485 | class PoolerAnswerClass(nn.Module): 486 | """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. 
""" 487 | def __init__(self, config): 488 | super(PoolerAnswerClass, self).__init__() 489 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) 490 | self.activation = nn.Tanh() 491 | self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) 492 | 493 | def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None): 494 | """ 495 | Args: 496 | One of ``start_states``, ``start_positions`` should be not None. 497 | If both are set, ``start_positions`` overrides ``start_states``. 498 | 499 | **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``. 500 | hidden states of the first tokens for the labeled span. 501 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 502 | position of the first token for the labeled span. 503 | **cls_index**: torch.LongTensor of shape ``(batch_size,)`` 504 | position of the CLS token. If None, take the last token. 505 | 506 | note(Original repo): 507 | no dependency on end_feature so that we can obtain one single `cls_logits` 508 | for each sample 509 | """ 510 | hsz = hidden_states.shape[-1] 511 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" 512 | if start_positions is not None: 513 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 514 | start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) 515 | 516 | if cls_index is not None: 517 | cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 518 | cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) 519 | else: 520 | cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) 521 | 522 | x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) 523 | x = self.activation(x) 524 | x = self.dense_1(x).squeeze(-1) 525 | 526 | return x 527 | 528 | 529 | class SQuADHead(nn.Module): 530 | r""" A SQuAD head inspired by XLNet. 531 | 532 | Parameters: 533 | config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model. 534 | 535 | Inputs: 536 | **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)`` 537 | hidden states of sequence tokens 538 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 539 | position of the first token for the labeled span. 540 | **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 541 | position of the last token for the labeled span. 542 | **cls_index**: torch.LongTensor of shape ``(batch_size,)`` 543 | position of the CLS token. If None, take the last token. 544 | **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)`` 545 | Whether the question has a possible answer in the paragraph or not. 546 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` 547 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS) 548 | 1.0 means token should be masked. 549 | 550 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 551 | **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: 552 | Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. 
553 | **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 554 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)`` 555 | Log probabilities for the top config.start_n_top start token possibilities (beam-search). 556 | **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 557 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)`` 558 | Indices for the top config.start_n_top start token possibilities (beam-search). 559 | **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 560 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` 561 | Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). 562 | **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 563 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` 564 | Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). 565 | **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 566 | ``torch.FloatTensor`` of shape ``(batch_size,)`` 567 | Log probabilities for the ``is_impossible`` label of the answers. 568 | """ 569 | def __init__(self, config): 570 | super(SQuADHead, self).__init__() 571 | self.start_n_top = config.start_n_top 572 | self.end_n_top = config.end_n_top 573 | 574 | self.start_logits = PoolerStartLogits(config) 575 | self.end_logits = PoolerEndLogits(config) 576 | self.answer_class = PoolerAnswerClass(config) 577 | 578 | def forward(self, hidden_states, start_positions=None, end_positions=None, 579 | cls_index=None, is_impossible=None, p_mask=None): 580 | outputs = () 581 | 582 | start_logits = self.start_logits(hidden_states, p_mask=p_mask) 583 | 584 | if start_positions is not None and end_positions is not None: 585 | # If we are on multi-GPU, let's remove the dimension added by batch splitting 586 | for x in (start_positions, end_positions, cls_index, is_impossible): 587 | if x is not None and x.dim() > 1: 588 | x.squeeze_(-1) 589 | 590 | # during training, compute the end logits based on the ground truth of the start position 591 | end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) 592 | 593 | loss_fct = CrossEntropyLoss() 594 | start_loss = loss_fct(start_logits, start_positions) 595 | end_loss = loss_fct(end_logits, end_positions) 596 | total_loss = (start_loss + end_loss) / 2 597 | 598 | if cls_index is not None and is_impossible is not None: 599 | # Predict answerability from the representation of CLS and START 600 | cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) 601 | loss_fct_cls = nn.BCEWithLogitsLoss() 602 | cls_loss = loss_fct_cls(cls_logits, is_impossible) 603 | 604 | # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss 605 | total_loss += cls_loss * 0.5 606 | 607 | outputs = (total_loss,) + outputs 608 | 609 | else: 610 | # during inference, compute the end logits based on beam search 611 | bsz, slen, hsz = hidden_states.size() 612 | start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) 613 | 614 | start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, 
start_n_top) 615 | start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) 616 | start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) 617 | start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) 618 | 619 | hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz) 620 | p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None 621 | end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) 622 | end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) 623 | 624 | end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top) 625 | end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) 626 | end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) 627 | 628 | start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) 629 | cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) 630 | 631 | outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs 632 | 633 | # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits 634 | # or (if labels are provided) (total_loss,) 635 | return outputs 636 | 637 | 638 | class SequenceSummary(nn.Module): 639 | r""" Compute a single vector summary of a sequence hidden states according to various possibilities: 640 | Args of the config class: 641 | summary_type: 642 | - 'last' => [default] take the last token hidden state (like XLNet) 643 | - 'first' => take the first token hidden state (like Bert) 644 | - 'mean' => take the mean of all tokens hidden states 645 | - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2) 646 | - 'attn' => Not implemented now, use multi-head attention 647 | summary_use_proj: Add a projection after the vector extraction 648 | summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. 649 | summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default 650 | summary_first_dropout: Add a dropout before the projection and activation 651 | summary_last_dropout: Add a dropout after the projection and activation 652 | """ 653 | def __init__(self, config): 654 | super(SequenceSummary, self).__init__() 655 | 656 | self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last' 657 | if self.summary_type == 'attn': 658 | # We should use a standard multi-head attention module with absolute positional embedding for that. 659 | # Cf. 
https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 660 | # We can probably just use the multi-head attention module of PyTorch >=1.1.0 661 | raise NotImplementedError 662 | 663 | self.summary = Identity() 664 | if hasattr(config, 'summary_use_proj') and config.summary_use_proj: 665 | if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0: 666 | num_classes = config.num_labels 667 | else: 668 | num_classes = config.hidden_size 669 | self.summary = nn.Linear(config.hidden_size, num_classes) 670 | 671 | self.activation = Identity() 672 | if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh': 673 | self.activation = nn.Tanh() 674 | 675 | self.first_dropout = Identity() 676 | if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0: 677 | self.first_dropout = nn.Dropout(config.summary_first_dropout) 678 | 679 | self.last_dropout = Identity() 680 | if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0: 681 | self.last_dropout = nn.Dropout(config.summary_last_dropout) 682 | 683 | def forward(self, hidden_states, cls_index=None): 684 | """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer. 685 | cls_index: [optional] position of the classification token if summary_type == 'cls_index', 686 | shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states. 687 | if summary_type == 'cls_index' and cls_index is None: 688 | we take the last token of the sequence as classification token 689 | """ 690 | if self.summary_type == 'last': 691 | output = hidden_states[:, -1] 692 | elif self.summary_type == 'first': 693 | output = hidden_states[:, 0] 694 | elif self.summary_type == 'mean': 695 | output = hidden_states.mean(dim=1) 696 | elif self.summary_type == 'cls_index': 697 | if cls_index is None: 698 | cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long) 699 | else: 700 | cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) 701 | cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),)) 702 | # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states 703 | output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) 704 | elif self.summary_type == 'attn': 705 | raise NotImplementedError 706 | 707 | output = self.first_dropout(output) 708 | output = self.summary(output) 709 | output = self.activation(output) 710 | output = self.last_dropout(output) 711 | 712 | return output 713 | 714 | 715 | def prune_linear_layer(layer, index, dim=0): 716 | """ Prune a linear layer (a model parameters) to keep only entries in index. 717 | Return the pruned layer as a new layer with requires_grad=True. 718 | Used to remove heads. 
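    Example (editor's illustrative sketch of the function defined here)::

        layer = nn.Linear(768, 768)
        keep = torch.tensor([0, 1, 2, 3])           # output rows to keep
        pruned = prune_linear_layer(layer, keep)    # dim=0 -> nn.Linear(in=768, out=4)
        assert pruned.weight.shape == (4, 768)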
719 | """ 720 | index = index.to(layer.weight.device) 721 | W = layer.weight.index_select(dim, index).clone().detach() 722 | if layer.bias is not None: 723 | if dim == 1: 724 | b = layer.bias.clone().detach() 725 | else: 726 | b = layer.bias[index].clone().detach() 727 | new_size = list(layer.weight.size()) 728 | new_size[dim] = len(index) 729 | new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) 730 | new_layer.weight.requires_grad = False 731 | new_layer.weight.copy_(W.contiguous()) 732 | new_layer.weight.requires_grad = True 733 | if layer.bias is not None: 734 | new_layer.bias.requires_grad = False 735 | new_layer.bias.copy_(b.contiguous()) 736 | new_layer.bias.requires_grad = True 737 | return new_layer 738 | 739 | 740 | def prune_conv1d_layer(layer, index, dim=1): 741 | """ Prune a Conv1D layer (a model parameters) to keep only entries in index. 742 | A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed. 743 | Return the pruned layer as a new layer with requires_grad=True. 744 | Used to remove heads. 745 | """ 746 | index = index.to(layer.weight.device) 747 | W = layer.weight.index_select(dim, index).clone().detach() 748 | if dim == 0: 749 | b = layer.bias.clone().detach() 750 | else: 751 | b = layer.bias[index].clone().detach() 752 | new_size = list(layer.weight.size()) 753 | new_size[dim] = len(index) 754 | new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) 755 | new_layer.weight.requires_grad = False 756 | new_layer.weight.copy_(W.contiguous()) 757 | new_layer.weight.requires_grad = True 758 | new_layer.bias.requires_grad = False 759 | new_layer.bias.copy_(b.contiguous()) 760 | new_layer.bias.requires_grad = True 761 | return new_layer 762 | 763 | 764 | def prune_layer(layer, index, dim=None): 765 | """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index. 766 | Return the pruned layer as a new layer with requires_grad=True. 767 | Used to remove heads. 768 | """ 769 | if isinstance(layer, nn.Linear): 770 | return prune_linear_layer(layer, index, dim=0 if dim is None else dim) 771 | elif isinstance(layer, Conv1D): 772 | return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) 773 | else: 774 | raise ValueError("Can't prune layer of class {}".format(layer.__class__)) 775 | -------------------------------------------------------------------------------- /albert_zh/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | class ConstantLRSchedule(LambdaLR): 27 | """ Constant learning rate schedule. 28 | """ 29 | def __init__(self, optimizer, last_epoch=-1): 30 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 31 | 32 | 33 | class WarmupConstantSchedule(LambdaLR): 34 | """ Linear warmup and then constant. 35 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 36 | Keeps learning rate schedule equal to 1. after warmup_steps. 37 | """ 38 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 39 | self.warmup_steps = warmup_steps 40 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 41 | 42 | def lr_lambda(self, step): 43 | if step < self.warmup_steps: 44 | return float(step) / float(max(1.0, self.warmup_steps)) 45 | return 1. 46 | 47 | 48 | class WarmupLinearSchedule(LambdaLR): 49 | """ Linear warmup and then linear decay. 50 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 51 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 52 | """ 53 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 54 | self.warmup_steps = warmup_steps 55 | self.t_total = t_total 56 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 57 | 58 | def lr_lambda(self, step): 59 | if step < self.warmup_steps: 60 | return float(step) / float(max(1, self.warmup_steps)) 61 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 62 | 63 | 64 | class WarmupCosineSchedule(LambdaLR): 65 | """ Linear warmup and then cosine decay. 66 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 67 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 68 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 69 | """ 70 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 71 | self.warmup_steps = warmup_steps 72 | self.t_total = t_total 73 | self.cycles = cycles 74 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 75 | 76 | def lr_lambda(self, step): 77 | if step < self.warmup_steps: 78 | return float(step) / float(max(1.0, self.warmup_steps)) 79 | # progress after warmup 80 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 81 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 82 | 83 | 84 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 85 | """ Linear warmup and then cosine cycles with hard restarts. 86 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 87 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 88 | learning rate (with hard restarts). 
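    Example (editor's illustrative sketch; every schedule in this module is driven the
    same way, and ``AdamW`` is defined further below; ``model`` and ``compute_loss``
    are placeholders)::

        optimizer = AdamW(model.parameters(), lr=5e-5)
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=1000)
        for step in range(1000):
            loss = compute_loss()        # placeholder for forward pass + loss
            loss.backward()
            optimizer.step()
            scheduler.step()             # advance the LambdaLR schedule once per step
            optimizer.zero_grad()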
89 | """ 90 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 91 | self.warmup_steps = warmup_steps 92 | self.t_total = t_total 93 | self.cycles = cycles 94 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 95 | 96 | def lr_lambda(self, step): 97 | if step < self.warmup_steps: 98 | return float(step) / float(max(1, self.warmup_steps)) 99 | # progress after warmup 100 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 101 | if progress >= 1.0: 102 | return 0.0 103 | return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0)))) 104 | 105 | 106 | 107 | class AdamW(Optimizer): 108 | """ Implements Adam algorithm with weight decay fix. 109 | 110 | Parameters: 111 | lr (float): learning rate. Default 1e-3. 112 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 113 | eps (float): Adams epsilon. Default: 1e-6 114 | weight_decay (float): Weight decay. Default: 0.0 115 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 116 | """ 117 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 118 | if lr < 0.0: 119 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 120 | if not 0.0 <= betas[0] < 1.0: 121 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 122 | if not 0.0 <= betas[1] < 1.0: 123 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 124 | if not 0.0 <= eps: 125 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 126 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 127 | correct_bias=correct_bias) 128 | super(AdamW, self).__init__(params, defaults) 129 | 130 | def step(self, closure=None): 131 | """Performs a single optimization step. 132 | 133 | Arguments: 134 | closure (callable, optional): A closure that reevaluates the model 135 | and returns the loss. 
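        Example (editor's sketch of the weight-decay grouping commonly used with this
        optimizer; ``model`` is a placeholder)::

            no_decay = ['bias', 'LayerNorm.weight']
            grouped_parameters = [
                {'params': [p for n, p in model.named_parameters()
                            if not any(nd in n for nd in no_decay)],
                 'weight_decay': 0.01},
                {'params': [p for n, p in model.named_parameters()
                            if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0},
            ]
            optimizer = AdamW(grouped_parameters, lr=2e-5, eps=1e-6)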
136 | """ 137 | loss = None 138 | if closure is not None: 139 | loss = closure() 140 | 141 | for group in self.param_groups: 142 | for p in group['params']: 143 | if p.grad is None: 144 | continue 145 | grad = p.grad.data 146 | if grad.is_sparse: 147 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 148 | 149 | state = self.state[p] 150 | 151 | # State initialization 152 | if len(state) == 0: 153 | state['step'] = 0 154 | # Exponential moving average of gradient values 155 | state['exp_avg'] = torch.zeros_like(p.data) 156 | # Exponential moving average of squared gradient values 157 | state['exp_avg_sq'] = torch.zeros_like(p.data) 158 | 159 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 160 | beta1, beta2 = group['betas'] 161 | 162 | state['step'] += 1 163 | 164 | # Decay the first and second moment running average coefficient 165 | # In-place operations to update the averages at the same time 166 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 167 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 168 | denom = exp_avg_sq.sqrt().add_(group['eps']) 169 | 170 | step_size = group['lr'] 171 | if group['correct_bias']: # No bias correction for Bert 172 | bias_correction1 = 1.0 - beta1 ** state['step'] 173 | bias_correction2 = 1.0 - beta2 ** state['step'] 174 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 175 | 176 | p.data.addcdiv_(-step_size, exp_avg, denom) 177 | 178 | # Just adding the square of the weights to the loss function is *not* 179 | # the correct way of using L2 regularization/weight decay with Adam, 180 | # since that will interact with the m and v parameters in strange ways. 181 | # 182 | # Instead we want to decay the weights in a manner that doesn't interact 183 | # with the m/v parameters. This is equivalent to adding the square 184 | # of the weights to the loss with plain (non-momentum) SGD. 185 | # Add weight decay at the end (fixed version) 186 | if group['weight_decay'] > 0.0: 187 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 188 | 189 | return loss 190 | -------------------------------------------------------------------------------- /albert_zh/tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .tokenization_utils import PreTrainedTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 37 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 38 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 39 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 40 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 41 | 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 42 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 43 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 44 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 45 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 46 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 47 | 'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt", 48 | 'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt", 49 | } 50 | } 51 | 52 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 53 | 'bert-base-uncased': 512, 54 | 'bert-large-uncased': 512, 55 | 'bert-base-cased': 512, 56 | 'bert-large-cased': 512, 57 | 'bert-base-multilingual-uncased': 512, 58 | 'bert-base-multilingual-cased': 512, 59 | 'bert-base-chinese': 512, 60 | 'bert-base-german-cased': 512, 61 | 'bert-large-uncased-whole-word-masking': 512, 62 | 'bert-large-cased-whole-word-masking': 512, 63 | 'bert-large-uncased-whole-word-masking-finetuned-squad': 512, 64 | 'bert-large-cased-whole-word-masking-finetuned-squad': 512, 65 | 'bert-base-cased-finetuned-mrpc': 512, 66 | 'bert-base-german-dbmdz-cased': 512, 67 | 'bert-base-german-dbmdz-uncased': 512, 68 | } 69 | 70 | PRETRAINED_INIT_CONFIGURATION = { 71 | 'bert-base-uncased': {'do_lower_case': True}, 72 | 'bert-large-uncased': {'do_lower_case': True}, 73 | 'bert-base-cased': {'do_lower_case': False}, 74 | 'bert-large-cased': {'do_lower_case': False}, 75 | 'bert-base-multilingual-uncased': {'do_lower_case': True}, 76 | 'bert-base-multilingual-cased': {'do_lower_case': 
False}, 77 | 'bert-base-chinese': {'do_lower_case': False}, 78 | 'bert-base-german-cased': {'do_lower_case': False}, 79 | 'bert-large-uncased-whole-word-masking': {'do_lower_case': True}, 80 | 'bert-large-cased-whole-word-masking': {'do_lower_case': False}, 81 | 'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True}, 82 | 'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False}, 83 | 'bert-base-cased-finetuned-mrpc': {'do_lower_case': False}, 84 | 'bert-base-german-dbmdz-cased': {'do_lower_case': False}, 85 | 'bert-base-german-dbmdz-uncased': {'do_lower_case': True}, 86 | } 87 | 88 | 89 | def load_vocab(vocab_file): 90 | """Loads a vocabulary file into a dictionary.""" 91 | vocab = collections.OrderedDict() 92 | with open(vocab_file, "r", encoding="utf-8") as reader: 93 | tokens = reader.readlines() 94 | for index, token in enumerate(tokens): 95 | token = token.rstrip('\n') 96 | vocab[token] = index 97 | return vocab 98 | 99 | 100 | def whitespace_tokenize(text): 101 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 102 | text = text.strip() 103 | if not text: 104 | return [] 105 | tokens = text.split() 106 | return tokens 107 | 108 | 109 | class BertTokenizer(PreTrainedTokenizer): 110 | r""" 111 | Constructs a BertTokenizer. 112 | :class:`~transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece 113 | 114 | Args: 115 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 116 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 117 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 118 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 119 | minimum of this value (if specified) and the underlying BERT model's sequence length. 120 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 121 | do_wordpiece_only=False 122 | """ 123 | 124 | vocab_files_names = VOCAB_FILES_NAMES 125 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 126 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 127 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 128 | 129 | def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, 130 | unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", 131 | mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs): 132 | """Constructs a BertTokenizer. 133 | 134 | Args: 135 | **vocab_file**: Path to a one-wordpiece-per-line vocabulary file 136 | **do_lower_case**: (`optional`) boolean (default True) 137 | Whether to lower case the input 138 | Only has an effect when do_basic_tokenize=True 139 | **do_basic_tokenize**: (`optional`) boolean (default True) 140 | Whether to do basic tokenization before wordpiece. 141 | **never_split**: (`optional`) list of string 142 | List of tokens which will never be split during tokenization. 143 | Only has an effect when do_basic_tokenize=True 144 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 145 | Whether to tokenize Chinese characters. 
146 | This should likely be deactivated for Japanese: 147 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 148 | """ 149 | super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, 150 | pad_token=pad_token, cls_token=cls_token, 151 | mask_token=mask_token, **kwargs) 152 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 153 | self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens 154 | 155 | if not os.path.isfile(vocab_file): 156 | raise ValueError( 157 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 158 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 159 | self.vocab = load_vocab(vocab_file) 160 | self.ids_to_tokens = collections.OrderedDict( 161 | [(ids, tok) for tok, ids in self.vocab.items()]) 162 | self.do_basic_tokenize = do_basic_tokenize 163 | if do_basic_tokenize: 164 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 165 | never_split=never_split, 166 | tokenize_chinese_chars=tokenize_chinese_chars) 167 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 168 | 169 | @property 170 | def vocab_size(self): 171 | return len(self.vocab) 172 | 173 | def _tokenize(self, text): 174 | split_tokens = [] 175 | if self.do_basic_tokenize: 176 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 177 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 178 | split_tokens.append(sub_token) 179 | else: 180 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 181 | return split_tokens 182 | 183 | def _convert_token_to_id(self, token): 184 | """ Converts a token (str/unicode) in an id using the vocab. """ 185 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 186 | 187 | def _convert_id_to_token(self, index): 188 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 189 | return self.ids_to_tokens.get(index, self.unk_token) 190 | 191 | def convert_tokens_to_string(self, tokens): 192 | """ Converts a sequence of tokens (string) in a single string. """ 193 | out_string = ' '.join(tokens).replace(' ##', '').strip() 194 | return out_string 195 | 196 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 197 | """ 198 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 199 | by concatenating and adding special tokens. 200 | A BERT sequence has the following format: 201 | single sequence: [CLS] X [SEP] 202 | pair of sequences: [CLS] A [SEP] B [SEP] 203 | """ 204 | if token_ids_1 is None: 205 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 206 | cls = [self.cls_token_id] 207 | sep = [self.sep_token_id] 208 | return cls + token_ids_0 + sep + token_ids_1 + sep 209 | 210 | def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): 211 | """ 212 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 213 | special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. 
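        Example (editor's illustrative sketch; as the implementation below shows, ``1``
        marks the positions where ``[CLS]``/``[SEP]`` will sit)::

            tokenizer.get_special_tokens_mask([10, 11, 12])
            # -> [1, 0, 0, 0, 1], matching [CLS] t0 t1 t2 [SEP]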
214 | 215 | Args: 216 | token_ids_0: list of ids (must not contain special tokens) 217 | token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids 218 | for sequence pairs 219 | already_has_special_tokens: (default False) Set to True if the token list is already formated with 220 | special tokens for the model 221 | 222 | Returns: 223 | A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token. 224 | """ 225 | 226 | if already_has_special_tokens: 227 | if token_ids_1 is not None: 228 | raise ValueError("You should not supply a second sequence if the provided sequence of " 229 | "ids is already formated with special tokens for the model.") 230 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 231 | 232 | if token_ids_1 is not None: 233 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 234 | return [1] + ([0] * len(token_ids_0)) + [1] 235 | 236 | def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): 237 | """ 238 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 239 | A BERT sequence pair mask has the following format: 240 | 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 241 | | first sequence | second sequence 242 | 243 | if token_ids_1 is None, only returns the first portion of the mask (0's). 244 | """ 245 | sep = [self.sep_token_id] 246 | cls = [self.cls_token_id] 247 | if token_ids_1 is None: 248 | return len(cls + token_ids_0 + sep) * [0] 249 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 250 | 251 | def save_vocabulary(self, vocab_path): 252 | """Save the tokenizer vocabulary to a directory or file.""" 253 | index = 0 254 | if os.path.isdir(vocab_path): 255 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) 256 | else: 257 | vocab_file = vocab_path 258 | with open(vocab_file, "w", encoding="utf-8") as writer: 259 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 260 | if index != token_index: 261 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 262 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 263 | index = token_index 264 | writer.write(token + u'\n') 265 | index += 1 266 | return (vocab_file,) 267 | 268 | 269 | class BasicTokenizer(object): 270 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 271 | 272 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): 273 | """ Constructs a BasicTokenizer. 274 | 275 | Args: 276 | **do_lower_case**: Whether to lower case the input. 277 | **never_split**: (`optional`) list of str 278 | Kept for backward compatibility purposes. 279 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 280 | List of token not to split. 281 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 282 | Whether to tokenize Chinese characters. 283 | This should likely be deactivated for Japanese: 284 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 285 | """ 286 | if never_split is None: 287 | never_split = [] 288 | self.do_lower_case = do_lower_case 289 | self.never_split = never_split 290 | self.tokenize_chinese_chars = tokenize_chinese_chars 291 | 292 | def tokenize(self, text, never_split=None): 293 | """ Basic Tokenization of a piece of text. 
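        Example (editor's illustrative sketch)::

            BasicTokenizer(do_lower_case=True).tokenize(u"Héllo, 世界!")
            # -> ['hello', ',', '世', '界', '!']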
294 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. 295 | 296 | Args: 297 | **never_split**: (`optional`) list of str 298 | Kept for backward compatibility purposes. 299 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 300 | List of token not to split. 301 | """ 302 | never_split = self.never_split + (never_split if never_split is not None else []) 303 | text = self._clean_text(text) 304 | # This was added on November 1st, 2018 for the multilingual and Chinese 305 | # models. This is also applied to the English models now, but it doesn't 306 | # matter since the English models were not trained on any Chinese data 307 | # and generally don't have any Chinese data in them (there are Chinese 308 | # characters in the vocabulary because Wikipedia does have some Chinese 309 | # words in the English Wikipedia.). 310 | if self.tokenize_chinese_chars: 311 | text = self._tokenize_chinese_chars(text) 312 | orig_tokens = whitespace_tokenize(text) 313 | split_tokens = [] 314 | for token in orig_tokens: 315 | if self.do_lower_case and token not in never_split: 316 | token = token.lower() 317 | token = self._run_strip_accents(token) 318 | split_tokens.extend(self._run_split_on_punc(token)) 319 | 320 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 321 | return output_tokens 322 | 323 | def _run_strip_accents(self, text): 324 | """Strips accents from a piece of text.""" 325 | text = unicodedata.normalize("NFD", text) 326 | output = [] 327 | for char in text: 328 | cat = unicodedata.category(char) 329 | if cat == "Mn": 330 | continue 331 | output.append(char) 332 | return "".join(output) 333 | 334 | def _run_split_on_punc(self, text, never_split=None): 335 | """Splits punctuation on a piece of text.""" 336 | if never_split is not None and text in never_split: 337 | return [text] 338 | chars = list(text) 339 | i = 0 340 | start_new_word = True 341 | output = [] 342 | while i < len(chars): 343 | char = chars[i] 344 | if _is_punctuation(char): 345 | output.append([char]) 346 | start_new_word = True 347 | else: 348 | if start_new_word: 349 | output.append([]) 350 | start_new_word = False 351 | output[-1].append(char) 352 | i += 1 353 | 354 | return ["".join(x) for x in output] 355 | 356 | def _tokenize_chinese_chars(self, text): 357 | """Adds whitespace around any CJK character.""" 358 | output = [] 359 | for char in text: 360 | cp = ord(char) 361 | if self._is_chinese_char(cp): 362 | output.append(" ") 363 | output.append(char) 364 | output.append(" ") 365 | else: 366 | output.append(char) 367 | return "".join(output) 368 | 369 | def _is_chinese_char(self, cp): 370 | """Checks whether CP is the codepoint of a CJK character.""" 371 | # This defines a "chinese character" as anything in the CJK Unicode block: 372 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 373 | # 374 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 375 | # despite its name. The modern Korean Hangul alphabet is a different block, 376 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 377 | # space-separated words, so they are not treated specially and handled 378 | # like the all of the other languages. 
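        # Editor's note (illustrative): ord(u'中') == 0x4E2D falls inside the first range
        # below, while the Hiragana character u'の' (0x306E) matches none of them, so
        # Japanese kana are left unspaced by _tokenize_chinese_chars.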
379 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 380 | (cp >= 0x3400 and cp <= 0x4DBF) or # 381 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 382 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 383 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 384 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 385 | (cp >= 0xF900 and cp <= 0xFAFF) or # 386 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 387 | return True 388 | 389 | return False 390 | 391 | def _clean_text(self, text): 392 | """Performs invalid character removal and whitespace cleanup on text.""" 393 | output = [] 394 | for char in text: 395 | cp = ord(char) 396 | if cp == 0 or cp == 0xfffd or _is_control(char): 397 | continue 398 | if _is_whitespace(char): 399 | output.append(" ") 400 | else: 401 | output.append(char) 402 | return "".join(output) 403 | 404 | 405 | class WordpieceTokenizer(object): 406 | """Runs WordPiece tokenization.""" 407 | 408 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 409 | self.vocab = vocab 410 | self.unk_token = unk_token 411 | self.max_input_chars_per_word = max_input_chars_per_word 412 | 413 | def tokenize(self, text): 414 | """Tokenizes a piece of text into its word pieces. 415 | 416 | This uses a greedy longest-match-first algorithm to perform tokenization 417 | using the given vocabulary. 418 | 419 | For example: 420 | input = "unaffable" 421 | output = ["un", "##aff", "##able"] 422 | 423 | Args: 424 | text: A single token or whitespace separated tokens. This should have 425 | already been passed through `BasicTokenizer`. 426 | 427 | Returns: 428 | A list of wordpiece tokens. 429 | """ 430 | 431 | output_tokens = [] 432 | for token in whitespace_tokenize(text): 433 | chars = list(token) 434 | if len(chars) > self.max_input_chars_per_word: 435 | output_tokens.append(self.unk_token) 436 | continue 437 | 438 | is_bad = False 439 | start = 0 440 | sub_tokens = [] 441 | while start < len(chars): 442 | end = len(chars) 443 | cur_substr = None 444 | while start < end: 445 | substr = "".join(chars[start:end]) 446 | if start > 0: 447 | substr = "##" + substr 448 | if substr in self.vocab: 449 | cur_substr = substr 450 | break 451 | end -= 1 452 | if cur_substr is None: 453 | is_bad = True 454 | break 455 | sub_tokens.append(cur_substr) 456 | start = end 457 | 458 | if is_bad: 459 | output_tokens.append(self.unk_token) 460 | else: 461 | output_tokens.extend(sub_tokens) 462 | return output_tokens 463 | 464 | 465 | def _is_whitespace(char): 466 | """Checks whether `chars` is a whitespace character.""" 467 | # \t, \n, and \r are technically contorl characters but we treat them 468 | # as whitespace since they are generally considered as such. 469 | if char == " " or char == "\t" or char == "\n" or char == "\r": 470 | return True 471 | cat = unicodedata.category(char) 472 | if cat == "Zs": 473 | return True 474 | return False 475 | 476 | 477 | def _is_control(char): 478 | """Checks whether `chars` is a control character.""" 479 | # These are technically control characters but we count them as whitespace 480 | # characters. 481 | if char == "\t" or char == "\n" or char == "\r": 482 | return False 483 | cat = unicodedata.category(char) 484 | if cat.startswith("C"): 485 | return True 486 | return False 487 | 488 | 489 | def _is_punctuation(char): 490 | """Checks whether `chars` is a punctuation character.""" 491 | cp = ord(char) 492 | # We treat all non-letter/number ASCII as punctuation. 
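        # Editor's note: the four ranges below are exactly the printable ASCII characters
        # that are neither letters, digits nor space: 33-47 is !"#$%&'()*+,-./ ,
        # 58-64 is :;<=>?@ , 91-96 is [\]^_` , and 123-126 is {|}~ .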
493 | # Characters such as "^", "$", and "`" are not in the Unicode 494 | # Punctuation class but we treat them as punctuation anyways, for 495 | # consistency. 496 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 497 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 498 | return True 499 | cat = unicodedata.category(char) 500 | if cat.startswith("P"): 501 | return True 502 | return False 503 | -------------------------------------------------------------------------------- /albert_zh/tokenization_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import logging 20 | import os 21 | import json 22 | import six 23 | import copy 24 | from io import open 25 | 26 | from .file_utils import cached_path 27 | 28 | import torch 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json' 33 | ADDED_TOKENS_FILE = 'added_tokens.json' 34 | TOKENIZER_CONFIG_FILE = 'tokenizer_config.json' 35 | 36 | class PreTrainedTokenizer(object): 37 | """ Base class for all tokenizers. 38 | Handle all the shared methods for tokenization and special tokens as well as methods dowloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary. 39 | 40 | This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...). 41 | 42 | Class attributes (overridden by derived classes): 43 | 44 | - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string). 45 | - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file. 46 | - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size. 
47 | - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionnary of specific arguments to pass to the ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method. 48 | 49 | Parameters: 50 | 51 | - ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token`` and ``self.bos_token_id`` 52 | 53 | - ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token`` and ``self.eos_token_id`` 54 | 55 | - ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token`` and ``self.unk_token_id`` 56 | 57 | - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token`` and ``self.sep_token_id`` 58 | 59 | - ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token`` and ``self.pad_token_id`` 60 | 61 | - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token`` and ``self.cls_token_id`` 62 | 63 | - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id`` 64 | 65 | - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensure they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids`` 66 | """ 67 | vocab_files_names = {} 68 | pretrained_vocab_files_map = {} 69 | pretrained_init_configuration = {} 70 | max_model_input_sizes = {} 71 | 72 | SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token", 73 | "pad_token", "cls_token", "mask_token", 74 | "additional_special_tokens"] 75 | 76 | @property 77 | def bos_token(self): 78 | """ Beginning of sentence token (string). Log an error if used while not having been set. """ 79 | if self._bos_token is None: 80 | logger.error("Using bos_token, but it is not set yet.") 81 | return self._bos_token 82 | 83 | @property 84 | def eos_token(self): 85 | """ End of sentence token (string). Log an error if used while not having been set. """ 86 | if self._eos_token is None: 87 | logger.error("Using eos_token, but it is not set yet.") 88 | return self._eos_token 89 | 90 | @property 91 | def unk_token(self): 92 | """ Unknown token (string). Log an error if used while not having been set. """ 93 | if self._unk_token is None: 94 | logger.error("Using unk_token, but it is not set yet.") 95 | return self._unk_token 96 | 97 | @property 98 | def sep_token(self): 99 | """ Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set. """ 100 | if self._sep_token is None: 101 | logger.error("Using sep_token, but it is not set yet.") 102 | return self._sep_token 103 | 104 | @property 105 | def pad_token(self): 106 | """ Padding token (string). Log an error if used while not having been set. 
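        Example (editor's illustrative sketch; the :class:`BertTokenizer` in this package
        passes ``pad_token="[PAD]"`` to ``__init__`` but leaves ``bos_token`` unset)::

            tokenizer.pad_token   # '[PAD]'
            tokenizer.bos_token   # logs "Using bos_token, but it is not set yet." and returns None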
""" 107 | if self._pad_token is None: 108 | logger.error("Using pad_token, but it is not set yet.") 109 | return self._pad_token 110 | 111 | @property 112 | def cls_token(self): 113 | """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ 114 | if self._cls_token is None: 115 | logger.error("Using cls_token, but it is not set yet.") 116 | return self._cls_token 117 | 118 | @property 119 | def mask_token(self): 120 | """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ 121 | if self._mask_token is None: 122 | logger.error("Using mask_token, but it is not set yet.") 123 | return self._mask_token 124 | 125 | @property 126 | def additional_special_tokens(self): 127 | """ All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set. """ 128 | if self._additional_special_tokens is None: 129 | logger.error("Using additional_special_tokens, but it is not set yet.") 130 | return self._additional_special_tokens 131 | 132 | @bos_token.setter 133 | def bos_token(self, value): 134 | self._bos_token = value 135 | 136 | @eos_token.setter 137 | def eos_token(self, value): 138 | self._eos_token = value 139 | 140 | @unk_token.setter 141 | def unk_token(self, value): 142 | self._unk_token = value 143 | 144 | @sep_token.setter 145 | def sep_token(self, value): 146 | self._sep_token = value 147 | 148 | @pad_token.setter 149 | def pad_token(self, value): 150 | self._pad_token = value 151 | 152 | @cls_token.setter 153 | def cls_token(self, value): 154 | self._cls_token = value 155 | 156 | @mask_token.setter 157 | def mask_token(self, value): 158 | self._mask_token = value 159 | 160 | @additional_special_tokens.setter 161 | def additional_special_tokens(self, value): 162 | self._additional_special_tokens = value 163 | 164 | @property 165 | def bos_token_id(self): 166 | """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """ 167 | return self.convert_tokens_to_ids(self.bos_token) 168 | 169 | @property 170 | def eos_token_id(self): 171 | """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """ 172 | return self.convert_tokens_to_ids(self.eos_token) 173 | 174 | @property 175 | def unk_token_id(self): 176 | """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """ 177 | return self.convert_tokens_to_ids(self.unk_token) 178 | 179 | @property 180 | def sep_token_id(self): 181 | """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """ 182 | return self.convert_tokens_to_ids(self.sep_token) 183 | 184 | @property 185 | def pad_token_id(self): 186 | """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """ 187 | return self.convert_tokens_to_ids(self.pad_token) 188 | 189 | @property 190 | def cls_token_id(self): 191 | """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. 
""" 192 | return self.convert_tokens_to_ids(self.cls_token) 193 | 194 | @property 195 | def mask_token_id(self): 196 | """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ 197 | return self.convert_tokens_to_ids(self.mask_token) 198 | 199 | @property 200 | def additional_special_tokens_ids(self): 201 | """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """ 202 | return self.convert_tokens_to_ids(self.additional_special_tokens) 203 | 204 | def __init__(self, max_len=None, **kwargs): 205 | self._bos_token = None 206 | self._eos_token = None 207 | self._unk_token = None 208 | self._sep_token = None 209 | self._pad_token = None 210 | self._cls_token = None 211 | self._mask_token = None 212 | self._additional_special_tokens = [] 213 | 214 | self.max_len = max_len if max_len is not None else int(1e12) 215 | 216 | # Added tokens 217 | self.added_tokens_encoder = {} 218 | self.added_tokens_decoder = {} 219 | 220 | # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) 221 | self.init_inputs = () 222 | self.init_kwargs = {} 223 | 224 | for key, value in kwargs.items(): 225 | if key in self.SPECIAL_TOKENS_ATTRIBUTES: 226 | if key == 'additional_special_tokens': 227 | assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value) 228 | else: 229 | assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode)) 230 | setattr(self, key, value) 231 | 232 | 233 | @classmethod 234 | def from_pretrained(cls, *inputs, **kwargs): 235 | r""" 236 | Instantiate a :class:`~transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer. 237 | 238 | Args: 239 | pretrained_model_name_or_path: either: 240 | 241 | - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``. 242 | - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``. 243 | - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``. 244 | 245 | cache_dir: (`optional`) string: 246 | Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used. 247 | 248 | force_download: (`optional`) boolean, default False: 249 | Force to (re-)download the vocabulary files and override the cached versions if they exists. 250 | 251 | proxies: (`optional`) dict, default None: 252 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 253 | The proxies are used on each request. 254 | 255 | inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method. 256 | 257 | kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. 
See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details. 258 | 259 | Examples:: 260 | 261 | # We can't instantiate directly the base class `PreTrainedTokenizer` so let's show our examples on a derived class: BertTokenizer 262 | 263 | # Download vocabulary from S3 and cache. 264 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 265 | 266 | # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`) 267 | tokenizer = BertTokenizer.from_pretrained('./test/saved_model/') 268 | 269 | # If the tokenizer uses a single vocabulary file, you can point directly to this file 270 | tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt') 271 | 272 | # You can link tokens to special vocabulary when instantiating 273 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='') 274 | # You should be sure '' is in the vocabulary when doing that. 275 | # Otherwise use tokenizer.add_special_tokens({'unk_token': ''}) instead) 276 | assert tokenizer.unk_token == '' 277 | 278 | """ 279 | return cls._from_pretrained(*inputs, **kwargs) 280 | 281 | 282 | @classmethod 283 | def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs): 284 | cache_dir = kwargs.pop('cache_dir', None) 285 | force_download = kwargs.pop('force_download', False) 286 | proxies = kwargs.pop('proxies', None) 287 | 288 | s3_models = list(cls.max_model_input_sizes.keys()) 289 | vocab_files = {} 290 | init_configuration = {} 291 | if pretrained_model_name_or_path in s3_models: 292 | # Get the vocabulary from AWS S3 bucket 293 | for file_id, map_list in cls.pretrained_vocab_files_map.items(): 294 | vocab_files[file_id] = map_list[pretrained_model_name_or_path] 295 | if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration: 296 | init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path] 297 | else: 298 | # Get the vocabulary from local files 299 | logger.info( 300 | "Model name '{}' not found in model shortcut name list ({}). " 301 | "Assuming '{}' is a path or url to a directory containing tokenizer files.".format( 302 | pretrained_model_name_or_path, ', '.join(s3_models), 303 | pretrained_model_name_or_path)) 304 | 305 | # Look for the tokenizer main vocabulary files 306 | for file_id, file_name in cls.vocab_files_names.items(): 307 | if os.path.isdir(pretrained_model_name_or_path): 308 | # If a directory is provided we look for the standard filenames 309 | full_file_name = os.path.join(pretrained_model_name_or_path, file_name) 310 | else: 311 | # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file) 312 | full_file_name = pretrained_model_name_or_path 313 | if not os.path.exists(full_file_name): 314 | logger.info("Didn't find file {}. 
We won't load it.".format(full_file_name)) 315 | full_file_name = None 316 | vocab_files[file_id] = full_file_name 317 | 318 | # Look for the additional tokens files 319 | additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE, 320 | 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE, 321 | 'tokenizer_config_file': TOKENIZER_CONFIG_FILE, 322 | } 323 | 324 | # If a path to a file was provided, get the parent directory 325 | saved_directory = pretrained_model_name_or_path 326 | if os.path.exists(saved_directory) and not os.path.isdir(saved_directory): 327 | saved_directory = os.path.dirname(saved_directory) 328 | 329 | for file_id, file_name in additional_files_names.items(): 330 | full_file_name = os.path.join(saved_directory, file_name) 331 | if not os.path.exists(full_file_name): 332 | logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) 333 | full_file_name = None 334 | vocab_files[file_id] = full_file_name 335 | 336 | if all(full_file_name is None for full_file_name in vocab_files.values()): 337 | raise EnvironmentError( 338 | "Model name '{}' was not found in tokenizers model name list ({}). " 339 | "We assumed '{}' was a path or url to a directory containing vocabulary files " 340 | "named {} but couldn't find such vocabulary files at this path or url.".format( 341 | pretrained_model_name_or_path, ', '.join(s3_models), 342 | pretrained_model_name_or_path, 343 | list(cls.vocab_files_names.values()))) 344 | 345 | # Get files from url, cache, or disk depending on the case 346 | try: 347 | resolved_vocab_files = {} 348 | for file_id, file_path in vocab_files.items(): 349 | if file_path is None: 350 | resolved_vocab_files[file_id] = None 351 | else: 352 | resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 353 | except EnvironmentError: 354 | if pretrained_model_name_or_path in s3_models: 355 | msg = "Couldn't reach server at '{}' to download vocabulary files." 356 | else: 357 | msg = "Model name '{}' was not found in tokenizers model name list ({}). " \ 358 | "We assumed '{}' was a path or url to a directory containing vocabulary files " \ 359 | "named {}, but couldn't find such vocabulary files at this path or url.".format( 360 | pretrained_model_name_or_path, ', '.join(s3_models), 361 | pretrained_model_name_or_path, 362 | list(cls.vocab_files_names.values())) 363 | 364 | raise EnvironmentError(msg) 365 | 366 | for file_id, file_path in vocab_files.items(): 367 | if file_path == resolved_vocab_files[file_id]: 368 | logger.info("loading file {}".format(file_path)) 369 | else: 370 | logger.info("loading file {} from cache at {}".format( 371 | file_path, resolved_vocab_files[file_id])) 372 | 373 | # Prepare tokenizer initialization kwargs 374 | # Did we saved some inputs and kwargs to reload ? 
375 | tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None) 376 | if tokenizer_config_file is not None: 377 | init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8")) 378 | saved_init_inputs = init_kwargs.pop('init_inputs', ()) 379 | if not init_inputs: 380 | init_inputs = saved_init_inputs 381 | else: 382 | init_kwargs = init_configuration 383 | 384 | # Update with newly provided kwargs 385 | init_kwargs.update(kwargs) 386 | 387 | # Set max length if needed 388 | if pretrained_model_name_or_path in cls.max_model_input_sizes: 389 | # if we're using a pretrained model, ensure the tokenizer 390 | # wont index sequences longer than the number of positional embeddings 391 | max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] 392 | if max_len is not None and isinstance(max_len, (int, float)): 393 | init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len) 394 | 395 | # Merge resolved_vocab_files arguments in init_kwargs. 396 | added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None) 397 | special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None) 398 | for args_name, file_path in resolved_vocab_files.items(): 399 | if args_name not in init_kwargs: 400 | init_kwargs[args_name] = file_path 401 | if special_tokens_map_file is not None: 402 | special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8")) 403 | for key, value in special_tokens_map.items(): 404 | if key not in init_kwargs: 405 | init_kwargs[key] = value 406 | 407 | # Instantiate tokenizer. 408 | tokenizer = cls(*init_inputs, **init_kwargs) 409 | 410 | # Save inputs and kwargs for saving and re-loading with ``save_pretrained`` 411 | tokenizer.init_inputs = init_inputs 412 | tokenizer.init_kwargs = init_kwargs 413 | 414 | # Add supplementary tokens. 415 | if added_tokens_file is not None: 416 | added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8")) 417 | added_tok_decoder = {v:k for k, v in added_tok_encoder.items()} 418 | tokenizer.added_tokens_encoder.update(added_tok_encoder) 419 | tokenizer.added_tokens_decoder.update(added_tok_decoder) 420 | 421 | return tokenizer 422 | 423 | 424 | def save_pretrained(self, save_directory): 425 | """ Save the tokenizer vocabulary files together with: 426 | - added tokens, 427 | - special-tokens-to-class-attributes-mapping, 428 | - tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert). 429 | 430 | This won't save modifications other than (added tokens and special token mapping) you may have 431 | applied to the tokenizer after the instantiation (e.g. modifying tokenizer.do_lower_case after creation). 432 | 433 | This method make sure the full tokenizer can then be re-loaded using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method. 
434 | """ 435 | if not os.path.isdir(save_directory): 436 | logger.error("Saving directory ({}) should be a directory".format(save_directory)) 437 | return 438 | 439 | special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE) 440 | added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE) 441 | tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE) 442 | 443 | tokenizer_config = copy.deepcopy(self.init_kwargs) 444 | tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs) 445 | for file_id in self.vocab_files_names.keys(): 446 | tokenizer_config.pop(file_id, None) 447 | 448 | with open(tokenizer_config_file, 'w', encoding='utf-8') as f: 449 | f.write(json.dumps(tokenizer_config, ensure_ascii=False)) 450 | 451 | with open(special_tokens_map_file, 'w', encoding='utf-8') as f: 452 | f.write(json.dumps(self.special_tokens_map, ensure_ascii=False)) 453 | 454 | with open(added_tokens_file, 'w', encoding='utf-8') as f: 455 | if self.added_tokens_encoder: 456 | out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False) 457 | else: 458 | out_str = u"{}" 459 | f.write(out_str) 460 | 461 | vocab_files = self.save_vocabulary(save_directory) 462 | 463 | return vocab_files + (special_tokens_map_file, added_tokens_file) 464 | 465 | 466 | def save_vocabulary(self, save_directory): 467 | """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens 468 | and special token mappings. 469 | 470 | Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method. 471 | """ 472 | raise NotImplementedError 473 | 474 | 475 | def vocab_size(self): 476 | """ Size of the base vocabulary (without the added tokens) """ 477 | raise NotImplementedError 478 | 479 | 480 | def __len__(self): 481 | """ Size of the full vocabulary with the added tokens """ 482 | return self.vocab_size + len(self.added_tokens_encoder) 483 | 484 | 485 | def add_tokens(self, new_tokens): 486 | """ 487 | Add a list of new tokens to the tokenizer class. If the new tokens are not in the 488 | vocabulary, they are added to it with indices starting from length of the current vocabulary. 489 | 490 | Args: 491 | new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). 492 | 493 | Returns: 494 | Number of tokens added to the vocabulary. 495 | 496 | Examples:: 497 | 498 | # Let's see how to increase the vocabulary of Bert model and tokenizer 499 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 500 | model = BertModel.from_pretrained('bert-base-uncased') 501 | 502 | num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) 503 | print('We have added', num_added_toks, 'tokens') 504 | model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. 
505 | """ 506 | if not new_tokens: 507 | return 0 508 | 509 | to_add_tokens = [] 510 | for token in new_tokens: 511 | assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode)) 512 | if token != self.unk_token and \ 513 | self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \ 514 | token not in to_add_tokens: 515 | to_add_tokens.append(token) 516 | logger.info("Adding %s to the vocabulary", token) 517 | 518 | added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens)) 519 | added_tok_decoder = {v:k for k, v in added_tok_encoder.items()} 520 | self.added_tokens_encoder.update(added_tok_encoder) 521 | self.added_tokens_decoder.update(added_tok_decoder) 522 | 523 | return len(to_add_tokens) 524 | 525 | def num_added_tokens(self, pair=False): 526 | """ 527 | Returns the number of added tokens when encoding a sequence with special tokens. 528 | 529 | Note: 530 | This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this 531 | inside your training loop. 532 | 533 | Args: 534 | pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the 535 | number of added tokens in the case of a single sequence if set to False. 536 | 537 | Returns: 538 | Number of tokens added to sequences 539 | """ 540 | token_ids_0 = [] 541 | token_ids_1 = [] 542 | return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) 543 | 544 | def add_special_tokens(self, special_tokens_dict): 545 | """ 546 | Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them 547 | to class attributes. If special tokens are NOT in the vocabulary, they are added 548 | to it (indexed starting from the last index of the current vocabulary). 549 | 550 | Using `add_special_tokens` will ensure your special tokens can be used in several ways: 551 | 552 | - special tokens are carefully handled by the tokenizer (they are never split) 553 | - you can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This makes it easy to develop model-agnostic training and fine-tuning scripts. 554 | 555 | When possible, special tokens are already registered for provided pretrained models (ex: BertTokenizer cls_token is already registered to be '[CLS]' and XLM's one is also registered to be '') 556 | 557 | Args: 558 | special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: 559 | [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, 560 | ``additional_special_tokens``]. 561 | 562 | Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). 563 | 564 | Returns: 565 | Number of tokens added to the vocabulary. 566 | 567 | Examples:: 568 | 569 | # Let's see how to add a new classification token to GPT-2 570 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 571 | model = GPT2Model.from_pretrained('gpt2') 572 | 573 | special_tokens_dict = {'cls_token': ''} 574 | 575 | num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) 576 | print('We have added', num_added_toks, 'tokens') 577 | model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. 
578 | 579 | assert tokenizer.cls_token == '' 580 | """ 581 | if not special_tokens_dict: 582 | return 0 583 | 584 | added_tokens = 0 585 | for key, value in special_tokens_dict.items(): 586 | assert key in self.SPECIAL_TOKENS_ATTRIBUTES 587 | if key == 'additional_special_tokens': 588 | assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value) 589 | added_tokens += self.add_tokens(value) 590 | else: 591 | assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode)) 592 | added_tokens += self.add_tokens([value]) 593 | logger.info("Assigning %s to the %s key of the tokenizer", value, key) 594 | setattr(self, key, value) 595 | 596 | return added_tokens 597 | 598 | def tokenize(self, text, **kwargs): 599 | """ Converts a string in a sequence of tokens (string), using the tokenizer. 600 | Split in words for word-based vocabulary or sub-words for sub-word-based 601 | vocabularies (BPE/SentencePieces/WordPieces). 602 | 603 | Take care of added tokens. 604 | """ 605 | def split_on_token(tok, text): 606 | result = [] 607 | split_text = text.split(tok) 608 | for i, sub_text in enumerate(split_text): 609 | sub_text = sub_text.strip() 610 | if i == 0 and not sub_text: 611 | result += [tok] 612 | elif i == len(split_text) - 1: 613 | if sub_text: 614 | result += [sub_text] 615 | else: 616 | pass 617 | else: 618 | if sub_text: 619 | result += [sub_text] 620 | result += [tok] 621 | return result 622 | 623 | def split_on_tokens(tok_list, text): 624 | if not text: 625 | return [] 626 | if not tok_list: 627 | return self._tokenize(text, **kwargs) 628 | 629 | tokenized_text = [] 630 | text_list = [text] 631 | for tok in tok_list: 632 | tokenized_text = [] 633 | for sub_text in text_list: 634 | if sub_text not in self.added_tokens_encoder \ 635 | and sub_text not in self.all_special_tokens: 636 | tokenized_text += split_on_token(tok, sub_text) 637 | else: 638 | tokenized_text += [sub_text] 639 | text_list = tokenized_text 640 | 641 | return sum((self._tokenize(token, **kwargs) if token not \ 642 | in self.added_tokens_encoder and token not in self.all_special_tokens \ 643 | else [token] for token in tokenized_text), []) 644 | 645 | added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens 646 | tokenized_text = split_on_tokens(added_tokens, text) 647 | return tokenized_text 648 | 649 | def _tokenize(self, text, **kwargs): 650 | """ Converts a string in a sequence of tokens (string), using the tokenizer. 651 | Split in words for word-based vocabulary or sub-words for sub-word-based 652 | vocabularies (BPE/SentencePieces/WordPieces). 653 | 654 | Do NOT take care of added tokens. 655 | """ 656 | raise NotImplementedError 657 | 658 | def convert_tokens_to_ids(self, tokens): 659 | """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id 660 | (resp. a sequence of ids), using the vocabulary. 661 | """ 662 | if tokens is None: 663 | return None 664 | 665 | if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)): 666 | return self._convert_token_to_id_with_added_voc(tokens) 667 | 668 | ids = [] 669 | for token in tokens: 670 | ids.append(self._convert_token_to_id_with_added_voc(token)) 671 | if len(ids) > self.max_len: 672 | logger.warning("Token indices sequence length is longer than the specified maximum sequence length " 673 | "for this model ({} > {}). 
Running this sequence through the model will result in " 674 | "indexing errors".format(len(ids), self.max_len)) 675 | return ids 676 | 677 | def _convert_token_to_id_with_added_voc(self, token): 678 | if token is None: 679 | return None 680 | 681 | if token in self.added_tokens_encoder: 682 | return self.added_tokens_encoder[token] 683 | return self._convert_token_to_id(token) 684 | 685 | def _convert_token_to_id(self, token): 686 | raise NotImplementedError 687 | 688 | def encode(self, 689 | text, 690 | text_pair=None, 691 | add_special_tokens=False, 692 | max_length=None, 693 | stride=0, 694 | truncation_strategy='longest_first', 695 | return_tensors=None, 696 | **kwargs): 697 | """ 698 | Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. 699 | 700 | Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``. 701 | 702 | Args: 703 | text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using 704 | the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` 705 | method) 706 | text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized 707 | string using the `tokenize` method) or a list of integers (tokenized string ids using the 708 | `convert_tokens_to_ids` method) 709 | add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative 710 | to their model. 711 | max_length: if set to a number, will limit the total sequence returned so that it has a maximum length. 712 | If there are overflowing tokens, those will be added to the returned dictionary 713 | stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens 714 | from the main sequence returned. The value of this argument defines the number of additional tokens. 715 | truncation_strategy: string selected in the following options: 716 | - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length 717 | starting from the longest one at each token (when there is a pair of input sequences) 718 | - 'only_first': Only truncate the first sequence 719 | - 'only_second': Only truncate the second sequence 720 | - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) 721 | return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant 722 | or PyTorch torch.Tensor instead of a list of python integers. 723 | **kwargs: passed to the `self.tokenize()` method 724 | """ 725 | encoded_inputs = self.encode_plus(text, 726 | text_pair=text_pair, 727 | max_length=max_length, 728 | add_special_tokens=add_special_tokens, 729 | stride=stride, 730 | truncation_strategy=truncation_strategy, 731 | return_tensors=return_tensors, 732 | **kwargs) 733 | 734 | return encoded_inputs["input_ids"] 735 | 736 | def encode_plus(self, 737 | text, 738 | text_pair=None, 739 | add_special_tokens=False, 740 | max_length=None, 741 | stride=0, 742 | truncation_strategy='longest_first', 743 | return_tensors=None, 744 | **kwargs): 745 | """ 746 | Returns a dictionary containing the encoded sequence or sequence pair and additional informations: 747 | the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. 748 | 749 | Args: 750 | text: The first sequence to be encoded. 
This can be a string, a list of strings (tokenized string using 751 | the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` 752 | method) 753 | text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized 754 | string using the `tokenize` method) or a list of integers (tokenized string ids using the 755 | `convert_tokens_to_ids` method) 756 | add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative 757 | to their model. 758 | max_length: if set to a number, will limit the total sequence returned so that it has a maximum length. 759 | If there are overflowing tokens, those will be added to the returned dictionary 760 | stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens 761 | from the main sequence returned. The value of this argument defines the number of additional tokens. 762 | truncation_strategy: string selected in the following options: 763 | - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length 764 | starting from the longest one at each token (when there is a pair of input sequences) 765 | - 'only_first': Only truncate the first sequence 766 | - 'only_second': Only truncate the second sequence 767 | - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) 768 | return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant 769 | or PyTorch torch.Tensor instead of a list of python integers. 770 | **kwargs: passed to the `self.tokenize()` method 771 | """ 772 | 773 | def get_input_ids(text): 774 | if isinstance(text, six.string_types): 775 | return self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) 776 | elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], six.string_types): 777 | return self.convert_tokens_to_ids(text) 778 | elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): 779 | return text 780 | else: 781 | raise ValueError("Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.") 782 | 783 | first_ids = get_input_ids(text) 784 | second_ids = get_input_ids(text_pair) if text_pair is not None else None 785 | 786 | return self.prepare_for_model(first_ids, 787 | pair_ids=second_ids, 788 | max_length=max_length, 789 | add_special_tokens=add_special_tokens, 790 | stride=stride, 791 | truncation_strategy=truncation_strategy, 792 | return_tensors=return_tensors) 793 | 794 | def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0, 795 | truncation_strategy='longest_first', return_tensors=None): 796 | """ 797 | Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. 798 | It adds special tokens, truncates 799 | sequences if overflowing while taking into account the special tokens and manages a window stride for 800 | overflowing tokens 801 | 802 | Args: 803 | ids: list of tokenized input ids. Can be obtained from a string by chaining the 804 | `tokenize` and `convert_tokens_to_ids` methods. 805 | pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the 806 | `tokenize` and `convert_tokens_to_ids` methods. 807 | max_length: maximum length of the returned list. Will truncate by taking into account the special tokens. 
808 | add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative 809 | to their model. 810 | stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential 811 | list of inputs. 812 | truncation_strategy: string selected in the following options: 813 | - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length 814 | starting from the longest one at each token (when there is a pair of input sequences) 815 | - 'only_first': Only truncate the first sequence 816 | - 'only_second': Only truncate the second sequence 817 | - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) 818 | return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant 819 | or PyTorch torch.Tensor instead of a list of python integers. 820 | 821 | Return: 822 | A Dictionary of shape:: 823 | 824 | { 825 | input_ids: list[int], 826 | overflowing_tokens: list[int] if a ``max_length`` is specified, else None 827 | special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` 828 | } 829 | 830 | With the fields: 831 | ``input_ids``: list of tokens to be fed to a model 832 | 833 | ``overflowing_tokens``: list of overflowing tokens if a max length is specified. 834 | 835 | ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added 836 | tokens and 1 specifying sequence tokens. 837 | """ 838 | pair = bool(pair_ids is not None) 839 | len_ids = len(ids) 840 | len_pair_ids = len(pair_ids) if pair else 0 841 | 842 | encoded_inputs = {} 843 | total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0) 844 | if max_length and total_len > max_length: 845 | ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids, 846 | num_tokens_to_remove=total_len-max_length, 847 | truncation_strategy=truncation_strategy, 848 | stride=stride) 849 | encoded_inputs["overflowing_tokens"] = overflowing_tokens 850 | encoded_inputs["num_truncated_tokens"] = total_len - max_length 851 | 852 | if add_special_tokens: 853 | sequence = self.build_inputs_with_special_tokens(ids, pair_ids) 854 | token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) 855 | encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) 856 | else: 857 | sequence = ids + pair_ids if pair else ids 858 | token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else []) 859 | 860 | if return_tensors == 'tf' and is_tf_available(): 861 | sequence = tf.constant([sequence]) 862 | token_type_ids = tf.constant([token_type_ids]) 863 | elif return_tensors == 'pt' and is_torch_available(): 864 | sequence = torch.tensor([sequence]) 865 | token_type_ids = torch.tensor([token_type_ids]) 866 | elif return_tensors is not None: 867 | logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors)) 868 | 869 | encoded_inputs["input_ids"] = sequence 870 | encoded_inputs["token_type_ids"] = token_type_ids 871 | 872 | if max_length and len(encoded_inputs["input_ids"]) > max_length: 873 | encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length] 874 | encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length] 875 | encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length] 876 | 877 | return 
encoded_inputs 878 | 879 | def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0): 880 | """Truncates a sequence pair in place to the maximum length. 881 | truncation_strategy: string selected in the following options: 882 | - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length 883 | starting from the longest one at each token (when there is a pair of input sequences). 884 | Overflowing tokens only contains overflow from the first sequence. 885 | - 'only_first': Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove. 886 | - 'only_second': Only truncate the second sequence 887 | - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) 888 | """ 889 | if num_tokens_to_remove <= 0: 890 | return ids, pair_ids, [] 891 | 892 | if truncation_strategy == 'longest_first': 893 | overflowing_tokens = [] 894 | for _ in range(num_tokens_to_remove): 895 | if pair_ids is None or len(ids) > len(pair_ids): 896 | overflowing_tokens = [ids[-1]] + overflowing_tokens 897 | ids = ids[:-1] 898 | else: 899 | pair_ids = pair_ids[:-1] 900 | window_len = min(len(ids), stride) 901 | if window_len > 0: 902 | overflowing_tokens = ids[-window_len:] + overflowing_tokens 903 | elif truncation_strategy == 'only_first': 904 | assert len(ids) > num_tokens_to_remove 905 | window_len = min(len(ids), stride + num_tokens_to_remove) 906 | overflowing_tokens = ids[-window_len:] 907 | ids = ids[:-num_tokens_to_remove] 908 | elif truncation_strategy == 'only_second': 909 | assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove 910 | window_len = min(len(pair_ids), stride + num_tokens_to_remove) 911 | overflowing_tokens = pair_ids[-window_len:] 912 | pair_ids = pair_ids[:-num_tokens_to_remove] 913 | elif truncation_strategy == 'do_not_truncate': 914 | raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.") 915 | else: 916 | raise ValueError("Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']") 917 | return (ids, pair_ids, overflowing_tokens) 918 | 919 | def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): 920 | logger.warning("This tokenizer does not make use of special tokens.") 921 | if token_ids_1 is None: 922 | return len(token_ids_0) * [0] 923 | return [0] * len(token_ids_0) + [1] * len(token_ids_1) 924 | 925 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 926 | """ 927 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 928 | by concatenating and adding special tokens. 929 | A RoBERTa sequence has the following format: 930 | single sequence: X 931 | pair of sequences: A B 932 | """ 933 | logger.warning("This tokenizer does not make use of special tokens. Input is returned with no modification.") 934 | if token_ids_1 is None: 935 | return token_ids_0 936 | return token_ids_0 + token_ids_1 937 | 938 | def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): 939 | """ 940 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 941 | special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. 
942 | 943 | Args: 944 | token_ids_0: list of ids (must not contain special tokens) 945 | token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids 946 | for sequence pairs 947 | already_has_special_tokens: (default False) Set to True if the token list is already formated with 948 | special tokens for the model 949 | 950 | Returns: 951 | A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token. 952 | """ 953 | return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) 954 | 955 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 956 | """ Converts a single index or a sequence of indices (integers) in a token " 957 | (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens. 958 | 959 | Args: 960 | skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False 961 | """ 962 | if isinstance(ids, int): 963 | if ids in self.added_tokens_decoder: 964 | return self.added_tokens_decoder[ids] 965 | else: 966 | return self._convert_id_to_token(ids) 967 | tokens = [] 968 | for index in ids: 969 | if skip_special_tokens and index in self.all_special_ids: 970 | continue 971 | if index in self.added_tokens_decoder: 972 | tokens.append(self.added_tokens_decoder[index]) 973 | else: 974 | tokens.append(self._convert_id_to_token(index)) 975 | return tokens 976 | 977 | def _convert_id_to_token(self, index): 978 | raise NotImplementedError 979 | 980 | def convert_tokens_to_string(self, tokens): 981 | """ Converts a sequence of tokens (string) in a single string. 982 | The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids)) 983 | but we often want to remove sub-word tokenization artifacts at the same time. 984 | """ 985 | return ' '.join(self.convert_ids_to_tokens(tokens)) 986 | 987 | def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): 988 | """ 989 | Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary 990 | with options to remove special tokens and clean up tokenization spaces. 991 | Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``. 992 | 993 | Args: 994 | token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods. 995 | skip_special_tokens: if set to True, will replace special tokens. 996 | clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces. 997 | """ 998 | filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) 999 | 1000 | # To avoid mixing byte-level and unicode for byte-level BPT 1001 | # we need to build string separatly for added tokens and byte-level tokens 1002 | # cf. 
https://github.com/huggingface/transformers/issues/1133 1003 | sub_texts = [] 1004 | current_sub_text = [] 1005 | for token in filtered_tokens: 1006 | if skip_special_tokens and token in self.all_special_ids: 1007 | continue 1008 | if token in self.added_tokens_encoder: 1009 | if current_sub_text: 1010 | sub_texts.append(self.convert_tokens_to_string(current_sub_text)) 1011 | current_sub_text = [] 1012 | sub_texts.append(" " + token) 1013 | else: 1014 | current_sub_text.append(token) 1015 | if current_sub_text: 1016 | sub_texts.append(self.convert_tokens_to_string(current_sub_text)) 1017 | text = ''.join(sub_texts) 1018 | 1019 | if clean_up_tokenization_spaces: 1020 | clean_text = self.clean_up_tokenization(text) 1021 | return clean_text 1022 | else: 1023 | return text 1024 | 1025 | @property 1026 | def special_tokens_map(self): 1027 | """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their 1028 | values ('', ''...) 1029 | """ 1030 | set_attr = {} 1031 | for attr in self.SPECIAL_TOKENS_ATTRIBUTES: 1032 | attr_value = getattr(self, "_" + attr) 1033 | if attr_value: 1034 | set_attr[attr] = attr_value 1035 | return set_attr 1036 | 1037 | @property 1038 | def all_special_tokens(self): 1039 | """ List all the special tokens ('', ''...) mapped to class attributes 1040 | (cls_token, unk_token...). 1041 | """ 1042 | all_toks = [] 1043 | set_attr = self.special_tokens_map 1044 | for attr_value in set_attr.values(): 1045 | all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value]) 1046 | all_toks = list(set(all_toks)) 1047 | return all_toks 1048 | 1049 | @property 1050 | def all_special_ids(self): 1051 | """ List the vocabulary indices of the special tokens ('', ''...) mapped to 1052 | class attributes (cls_token, unk_token...). 1053 | """ 1054 | all_toks = self.all_special_tokens 1055 | all_ids = list(self._convert_token_to_id(t) for t in all_toks) 1056 | return all_ids 1057 | 1058 | @staticmethod 1059 | def clean_up_tokenization(out_string): 1060 | """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms. 1061 | """ 1062 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' 1063 | ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" 1064 | ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") 1065 | return out_string 1066 | -------------------------------------------------------------------------------- /usage_example.py: -------------------------------------------------------------------------------- 1 | from albert_zh import AlbertConfig,AlbertForSequenceClassification,AlbertTokenizer 2 | import torch 3 | if __name__ == "__main__": 4 | tokenizer = AlbertTokenizer.from_pretrained('albert_tiny/vocab.txt') 5 | model_config = AlbertConfig.from_json_file('./albert_tiny/config.json') 6 | model = AlbertForSequenceClassification.from_pretrained('./albert_tiny',config = model_config) 7 | 8 | intput_str = '周杰倫,臺灣著名華語流行歌曲男歌手、音樂家、唱片製片人。同時是演員、導演,也是電競團隊隊長兼老闆、服飾品牌老闆。以其個人風格和聲歌手樂創作能力著稱,影響華語樂壇。 在2000年,周杰倫發行了他的首張專輯《Jay》,從屬於唱片公司阿爾發音樂。' 9 | input_ids = torch.tensor(tokenizer.encode(intput_str)).unsqueeze(0) # Batch size 1 10 | labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 11 | outputs = model(input_ids, labels=labels) 12 | loss, logits = outputs[:2] 13 | 14 | print(loss,logits) --------------------------------------------------------------------------------
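
The `tokenization_utils.py` listing above documents the generic tokenizer API (`from_pretrained`, `add_tokens`, `encode_plus`, `decode`, ...). As a complement to `usage_example.py`, the following is a minimal sketch that exercises those calls with the `albert_zh` classes shipped in this repo. It is illustrative only: the file name `usage_example_plus.py` and the extra tokens `'[EOP]'`/`'[EOT]'` are made up for the example, and `model.resize_token_embeddings` is assumed to be provided by the bundled `modeling_utils.py` as it is in upstream transformers 2.3.0.

```python
# usage_example_plus.py — hypothetical companion to usage_example.py (sketch only).
from albert_zh import AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer
import torch

if __name__ == "__main__":
    tokenizer = AlbertTokenizer.from_pretrained('albert_tiny/vocab.txt')
    model_config = AlbertConfig.from_json_file('./albert_tiny/config.json')
    model = AlbertForSequenceClassification.from_pretrained('./albert_tiny', config=model_config)

    # add_tokens() only registers tokens that currently map to [UNK]; new ids are
    # appended after the existing vocabulary, so the embedding matrix then has to
    # grow to len(tokenizer).
    num_added = tokenizer.add_tokens(['[EOP]', '[EOT]'])  # illustrative markers
    if num_added > 0:
        model.resize_token_embeddings(len(tokenizer))  # assumed, as in upstream transformers

    # encode_plus() takes an optional sentence pair, inserts [CLS]/[SEP], truncates
    # to max_length and, with return_tensors='pt', returns [1, seq_len] tensors.
    encoded = tokenizer.encode_plus(
        '周杰倫是臺灣著名的華語流行歌手。',
        '他在2000年發行首張專輯《Jay》。',
        add_special_tokens=True,
        max_length=64,
        truncation_strategy='longest_first',
        return_tensors='pt')
    input_ids = encoded['input_ids']

    labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1, same shape as in usage_example.py
    outputs = model(input_ids, labels=labels)
    loss, logits = outputs[:2]
    print(loss, logits)

    # decode() maps ids back to text; skip_special_tokens drops [CLS]/[SEP] again.
    print(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
```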