├── .gitignore
├── README.md
├── bert
│   ├── __init__.py
│   ├── configuration_bert.py
│   ├── configuration_utils.py
│   ├── file_utils.py
│   ├── modeling_bert.py
│   ├── modeling_utils.py
│   ├── optimization.py
│   ├── tokenization_bert.py
│   └── tokenization_utils.py
├── evaluate-v1.0.py
├── modeling.py
├── requirements.txt
├── run.sh
├── run_coqa.py
└── run_coqa_dataset_utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | coqa_utils_test.py
2 | coqa-dev-v1.0.json
3 | coqa-dev-v1.0.json_*
4 | coqa-train-v1.0.json
5 | coqa-train-v1.0.json_*
6 | *.json
7 | vocab.txt
8 | *.log
9 | .vscode/*
10 | *.pyc
11 | *.amax
12 | *.bin
13 | CoQAPreprocess.py
14 | CoQAUtils.py
15 | GeneralUtils.py
16 | nohup.out
17 | parallel.py
18 | run_squad_dataset_utils.py
19 | SDNetTrainer.py
20 | squad_utils_test.py
21 | to_official.py
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bert4CoQA
2 | ## Introduction
3 | This code is an example demonstrating how to apply [Bert](https://arxiv.org/abs/1810.04805) to the [CoQA Challenge](https://stanfordnlp.github.io/coqa/).
4 | 
5 | The code is essentially combined from [Transformer](https://github.com/huggingface/pytorch-pretrained-BERT), [run_squad.py](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_squad.py), [SDNet](https://github.com/microsoft/SDNet) and [SMRCToolkit](https://github.com/sogou/SMRCToolkit).
6 | 
7 | I trained the model with the following configuration:
8 | - --type bert
9 | - --bert_model bert-base-uncased
10 | - --num_train_epochs 2.0
11 | - --do_lower_case
12 | - --max_seq_length 512
13 | - --doc_stride 128
14 | - --max_query_length 64
15 | - --train_batch_size 12
16 | - --learning_rate 3e-5
17 | - --warmup_proportion 0.1
18 | - --max_grad_norm -1
19 | - --weight_decay 0.01
20 | - --fp16
21 | 
22 | on **1x TITAN Xp** in **3 hours**, and it achieves a **78.0 F1-score** on the dev set.
23 | 
24 | That can definitely be improved, and if you find better hyper-parameters, you are welcome to raise an issue :)
25 | 
26 | **Not tested on multi-machine training**
27 | 
28 | ## Requirement
29 | Check *requirements.txt*, or run
30 | > pip install -r requirements.txt
31 | 
32 | ## How to run
33 | Make sure that:
34 | 1. *coqa-train-v1.0.json* and *coqa-dev-v1.0.json* are in the same directory as *run_coqa.py*.
35 | 2. The binary file, config file, and vocab file of your BERT model are in your bert_model directory, named *pytorch_model.bin*, *config.json*, and *vocab.txt*.
36 | 3. You have enough GPU memory; [according to this](https://github.com/google-research/bert#out-of-memory-issues), you can tune *--train_batch_size*, *--gradient_accumulation_steps*, *--max_seq_length* and *--fp16* to save memory. A full example command is shown right below this list.
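For reference, the hyper-parameters listed in the Introduction map onto a single invocation roughly like the one below (a sketch only: *your_bertmodel_dir* and *your_outputdir* are placeholders, and *run_coqa.py* may accept further options; see *run.sh* for the invocation actually used):

```
python run_coqa.py \
    --type bert \
    --bert_model your_bertmodel_dir \
    --output_dir your_outputdir \
    --num_train_epochs 2.0 \
    --do_lower_case \
    --max_seq_length 512 \
    --doc_stride 128 \
    --max_query_length 64 \
    --train_batch_size 12 \
    --learning_rate 3e-5 \
    --warmup_proportion 0.1 \
    --max_grad_norm -1 \
    --weight_decay 0.01 \
    --fp16
```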
37 | 
38 | Then run
39 | > python run_coqa.py --bert_model your_bertmodel_dir --output_dir your_outputdir \[optional arguments\]
40 | 
41 | or edit and run *run.sh*.
42 | 
43 | To calculate the F1-score, use *evaluate-v1.0.py*:
44 | > python evaluate-v1.0.py --data-file coqa-dev-v1.0.json --pred-file your_prediction_file
45 | 
46 | ## BERT Models URL
47 | 
48 | #### pytorch_model.bin
49 | ```python
50 | BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
51 |     "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
52 |     "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
53 |     "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
54 |     "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
55 |     "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
56 |     "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
57 |     "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
58 |     "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
59 |     "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
60 |     "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
61 |     "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
62 |     "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
63 |     "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
64 |     "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
65 |     "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
66 |     "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
67 |     "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
68 |     "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
69 |     "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
70 |     "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
71 |     "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
72 |     "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/pytorch_model.bin",
73 | }
74 | ```
75 | 
76 | #### config.json
77 | 
78 | ```python
79 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
80 |     "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
81 |     "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
82 |     "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
83 |     "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
84 |     "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
85 |     "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
86 |     "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
87 |     "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
88 |     "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
89 |     "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
90 |     "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
91 |     "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
92 |     "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
93 |     "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
94 |     "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
95 |     "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
96 |     "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
97 |     "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
98 |     "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
99 |     "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
100 |     "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
101 |     "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
102 | }
103 | ```
104 | 
105 | #### vocab.txt
106 | 
107 | ```python
108 | PRETRAINED_VOCAB_FILES_MAP = {
109 |     "vocab_file": {
110 |         "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
111 |         "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
112 |         "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
113 |         "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 114 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 115 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 116 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 117 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 118 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 119 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 120 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 121 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 122 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 123 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt", 124 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt", 125 | "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt", 126 | "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt", 127 | "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt", 128 | } 129 | } 130 | ``` 131 | 132 | ## Contact 133 | If you have any questions, please new an issue or contact me, adamluo1995@gmail.com. 134 | -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_bert import BertModel, BertPreTrainedModel 2 | # from .modeling_roberta import RobertaModel 3 | from .tokenization_bert import BertTokenizer 4 | # from .tokenization_roberta import RobertaTokenizer 5 | from .optimization import AdamW, WarmupLinearSchedule -------------------------------------------------------------------------------- /bert/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ BERT model configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 37 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 38 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 39 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 40 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 41 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 42 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 43 | } 44 | 45 | 46 | class BertConfig(PretrainedConfig): 47 | r""" 48 | :class:`~transformers.BertConfig` is the configuration class to store the configuration of a 49 | `BertModel`. 50 | 51 | 52 | Arguments: 53 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 54 | hidden_size: Size of the encoder layers and the pooler layer. 55 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 56 | num_attention_heads: Number of attention heads for each attention layer in 57 | the Transformer encoder. 58 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 59 | layer in the Transformer encoder. 60 | hidden_act: The non-linear activation function (function or string) in the 61 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. 62 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 63 | layers in the embeddings, encoder, and pooler. 64 | attention_probs_dropout_prob: The dropout ratio for the attention 65 | probabilities. 66 | max_position_embeddings: The maximum sequence length that this model might 67 | ever be used with. Typically set this to something large just in case 68 | (e.g., 512 or 1024 or 2048). 
69 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 70 | `BertModel`. 71 | initializer_range: The sttdev of the truncated_normal_initializer for 72 | initializing all weight matrices. 73 | layer_norm_eps: The epsilon used by LayerNorm. 74 | """ 75 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 76 | 77 | def __init__(self, 78 | vocab_size_or_config_json_file=30522, 79 | hidden_size=768, 80 | num_hidden_layers=12, 81 | num_attention_heads=12, 82 | intermediate_size=3072, 83 | hidden_act="gelu", 84 | hidden_dropout_prob=0.1, 85 | attention_probs_dropout_prob=0.1, 86 | max_position_embeddings=512, 87 | type_vocab_size=2, 88 | initializer_range=0.02, 89 | layer_norm_eps=1e-12, 90 | **kwargs): 91 | super(BertConfig, self).__init__(**kwargs) 92 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 93 | and isinstance(vocab_size_or_config_json_file, unicode)): 94 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 95 | json_config = json.loads(reader.read()) 96 | for key, value in json_config.items(): 97 | self.__dict__[key] = value 98 | elif isinstance(vocab_size_or_config_json_file, int): 99 | self.vocab_size = vocab_size_or_config_json_file 100 | self.hidden_size = hidden_size 101 | self.num_hidden_layers = num_hidden_layers 102 | self.num_attention_heads = num_attention_heads 103 | self.hidden_act = hidden_act 104 | self.intermediate_size = intermediate_size 105 | self.hidden_dropout_prob = hidden_dropout_prob 106 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 107 | self.max_position_embeddings = max_position_embeddings 108 | self.type_vocab_size = type_vocab_size 109 | self.initializer_range = initializer_range 110 | self.layer_norm_eps = layer_norm_eps 111 | else: 112 | raise ValueError("First argument must be either a vocabulary size (int)" 113 | " or the path to a pretrained model config file (str)") 114 | -------------------------------------------------------------------------------- /bert/configuration_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Configuration base class and utilities.""" 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import copy 22 | import json 23 | import logging 24 | import os 25 | from io import open 26 | 27 | from .file_utils import cached_path, CONFIG_NAME 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | class PretrainedConfig(object): 32 | r""" Base class for all configuration classes. 33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. 
34 | 35 | Note: 36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights. 37 | It only affects the model's configuration. 38 | 39 | Class attributes (overridden by derived classes): 40 | - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values. 41 | 42 | Parameters: 43 | ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. 44 | ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens) 45 | ``output_attentions``: boolean, default `False`. Should the model returns attentions weights. 46 | ``output_hidden_states``: string, default `False`. Should the model returns all hidden-states. 47 | ``torchscript``: string, default `False`. Is the model used with Torchscript. 48 | """ 49 | pretrained_config_archive_map = {} 50 | 51 | def __init__(self, **kwargs): 52 | self.finetuning_task = kwargs.pop('finetuning_task', None) 53 | self.num_labels = kwargs.pop('num_labels', 2) 54 | self.output_attentions = kwargs.pop('output_attentions', False) 55 | self.output_hidden_states = kwargs.pop('output_hidden_states', False) 56 | self.torchscript = kwargs.pop('torchscript', False) 57 | self.use_bfloat16 = kwargs.pop('use_bfloat16', False) 58 | self.pruned_heads = kwargs.pop('pruned_heads', {}) 59 | 60 | def save_pretrained(self, save_directory): 61 | """ Save a configuration object to the directory `save_directory`, so that it 62 | can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. 63 | """ 64 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" 65 | 66 | # If we save using the predefined names, we can load using `from_pretrained` 67 | output_config_file = os.path.join(save_directory, CONFIG_NAME) 68 | 69 | self.to_json_file(output_config_file) 70 | logger.info("Configuration saved in {}".format(output_config_file)) 71 | 72 | @classmethod 73 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 74 | r""" Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. 75 | 76 | Parameters: 77 | pretrained_model_name_or_path: either: 78 | 79 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. 80 | - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. 81 | - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. 82 | 83 | cache_dir: (`optional`) string: 84 | Path to a directory in which a downloaded pre-trained model 85 | configuration should be cached if the standard cache should not be used. 86 | 87 | kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. 88 | 89 | - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. 90 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. 
91 | 92 | force_download: (`optional`) boolean, default False: 93 | Force to (re-)download the model weights and configuration files and override the cached versions if they exists. 94 | 95 | proxies: (`optional`) dict, default None: 96 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 97 | The proxies are used on each request. 98 | 99 | return_unused_kwargs: (`optional`) bool: 100 | 101 | - If False, then this function returns just the final configuration object. 102 | - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored. 103 | 104 | Examples:: 105 | 106 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a 107 | # derived class: BertConfig 108 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. 109 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')` 110 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') 111 | config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) 112 | assert config.output_attention == True 113 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, 114 | foo=False, return_unused_kwargs=True) 115 | assert config.output_attention == True 116 | assert unused_kwargs == {'foo': False} 117 | 118 | """ 119 | cache_dir = kwargs.pop('cache_dir', None) 120 | force_download = kwargs.pop('force_download', False) 121 | proxies = kwargs.pop('proxies', None) 122 | return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) 123 | 124 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map: 125 | config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] 126 | elif os.path.isdir(pretrained_model_name_or_path): 127 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 128 | else: 129 | config_file = pretrained_model_name_or_path 130 | # redirect to the cache, if necessary 131 | try: 132 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 133 | except EnvironmentError as e: 134 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map: 135 | logger.error( 136 | "Couldn't reach server at '{}' to download pretrained model configuration file.".format( 137 | config_file)) 138 | else: 139 | logger.error( 140 | "Model name '{}' was not found in model name list ({}). 
" 141 | "We assumed '{}' was a path or url but couldn't find any file " 142 | "associated to this path or url.".format( 143 | pretrained_model_name_or_path, 144 | ', '.join(cls.pretrained_config_archive_map.keys()), 145 | config_file)) 146 | raise e 147 | if resolved_config_file == config_file: 148 | logger.info("loading configuration file {}".format(config_file)) 149 | else: 150 | logger.info("loading configuration file {} from cache at {}".format( 151 | config_file, resolved_config_file)) 152 | 153 | # Load config 154 | config = cls.from_json_file(resolved_config_file) 155 | 156 | if hasattr(config, 'pruned_heads'): 157 | config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items()) 158 | 159 | # Update config with kwargs if needed 160 | to_remove = [] 161 | for key, value in kwargs.items(): 162 | if hasattr(config, key): 163 | setattr(config, key, value) 164 | to_remove.append(key) 165 | for key in to_remove: 166 | kwargs.pop(key, None) 167 | 168 | logger.info("Model config %s", config) 169 | if return_unused_kwargs: 170 | return config, kwargs 171 | else: 172 | return config 173 | 174 | @classmethod 175 | def from_dict(cls, json_object): 176 | """Constructs a `Config` from a Python dictionary of parameters.""" 177 | config = cls(vocab_size_or_config_json_file=-1) 178 | for key, value in json_object.items(): 179 | setattr(config, key, value) 180 | return config 181 | 182 | @classmethod 183 | def from_json_file(cls, json_file): 184 | """Constructs a `BertConfig` from a json file of parameters.""" 185 | with open(json_file, "r", encoding='utf-8') as reader: 186 | text = reader.read() 187 | return cls.from_dict(json.loads(text)) 188 | 189 | def __eq__(self, other): 190 | return self.__dict__ == other.__dict__ 191 | 192 | def __repr__(self): 193 | return str(self.to_json_string()) 194 | 195 | def to_dict(self): 196 | """Serializes this instance to a Python dictionary.""" 197 | output = copy.deepcopy(self.__dict__) 198 | return output 199 | 200 | def to_json_string(self): 201 | """Serializes this instance to a JSON string.""" 202 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 203 | 204 | def to_json_file(self, json_file_path): 205 | """ Save this instance to a json file.""" 206 | with open(json_file_path, "w", encoding='utf-8') as writer: 207 | writer.write(self.to_json_string()) 208 | -------------------------------------------------------------------------------- /bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import six 13 | import shutil 14 | import tempfile 15 | import fnmatch 16 | from functools import wraps 17 | from hashlib import sha256 18 | from io import open 19 | 20 | import boto3 21 | from botocore.config import Config 22 | from botocore.exceptions import ClientError 23 | import requests 24 | from tqdm import tqdm 25 | 26 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 27 | 28 | try: 29 | import tensorflow as tf 30 | assert int(tf.__version__[0]) >= 2 31 | _tf_available = True # pylint: disable=invalid-name 32 | logger.info("TensorFlow version {} available.".format(tf.__version__)) 33 | except (ImportError, AssertionError): 34 | _tf_available = False # pylint: disable=invalid-name 35 | 36 | try: 37 | import torch 38 | _torch_available = True # pylint: disable=invalid-name 39 | logger.info("PyTorch version {} available.".format(torch.__version__)) 40 | except ImportError: 41 | _torch_available = False # pylint: disable=invalid-name 42 | 43 | 44 | try: 45 | from torch.hub import _get_torch_home 46 | torch_cache_home = _get_torch_home() 47 | except ImportError: 48 | torch_cache_home = os.path.expanduser( 49 | os.getenv('TORCH_HOME', os.path.join( 50 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 51 | default_cache_path = os.path.join(torch_cache_home, 'transformers') 52 | 53 | try: 54 | from urllib.parse import urlparse 55 | except ImportError: 56 | from urlparse import urlparse 57 | 58 | try: 59 | from pathlib import Path 60 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 61 | os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))) 62 | except (AttributeError, ImportError): 63 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE', 64 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 65 | default_cache_path)) 66 | 67 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 68 | TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 69 | 70 | WEIGHTS_NAME = "pytorch_model.bin" 71 | TF2_WEIGHTS_NAME = 'tf_model.h5' 72 | TF_WEIGHTS_NAME = 'model.ckpt' 73 | CONFIG_NAME = "config.json" 74 | 75 | def is_torch_available(): 76 | return _torch_available 77 | 78 | def is_tf_available(): 79 | return _tf_available 80 | 81 | if not six.PY2: 82 | def add_start_docstrings(*docstr): 83 | def docstring_decorator(fn): 84 | fn.__doc__ = ''.join(docstr) + fn.__doc__ 85 | return fn 86 | return docstring_decorator 87 | 88 | def add_end_docstrings(*docstr): 89 | def docstring_decorator(fn): 90 | fn.__doc__ = fn.__doc__ + ''.join(docstr) 91 | return fn 92 | return docstring_decorator 93 | else: 94 | # Not possible to update class docstrings on python2 95 | def add_start_docstrings(*docstr): 96 | def docstring_decorator(fn): 97 | return fn 98 | return docstring_decorator 99 | 100 | def add_end_docstrings(*docstr): 101 | def docstring_decorator(fn): 102 | return fn 103 | return docstring_decorator 104 | 105 | def url_to_filename(url, etag=None): 106 | """ 107 | Convert `url` into a hashed filename in a repeatable way. 108 | If `etag` is specified, append its hash to the url's, delimited 109 | by a period. 
110 | If the url ends with .h5 (Keras HDF5 weights) ands '.h5' to the name 111 | so that TF 2.0 can identify it as a HDF5 file 112 | (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) 113 | """ 114 | url_bytes = url.encode('utf-8') 115 | url_hash = sha256(url_bytes) 116 | filename = url_hash.hexdigest() 117 | 118 | if etag: 119 | etag_bytes = etag.encode('utf-8') 120 | etag_hash = sha256(etag_bytes) 121 | filename += '.' + etag_hash.hexdigest() 122 | 123 | if url.endswith('.h5'): 124 | filename += '.h5' 125 | 126 | return filename 127 | 128 | 129 | def filename_to_url(filename, cache_dir=None): 130 | """ 131 | Return the url and etag (which may be ``None``) stored for `filename`. 132 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 133 | """ 134 | if cache_dir is None: 135 | cache_dir = TRANSFORMERS_CACHE 136 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 137 | cache_dir = str(cache_dir) 138 | 139 | cache_path = os.path.join(cache_dir, filename) 140 | if not os.path.exists(cache_path): 141 | raise EnvironmentError("file {} not found".format(cache_path)) 142 | 143 | meta_path = cache_path + '.json' 144 | if not os.path.exists(meta_path): 145 | raise EnvironmentError("file {} not found".format(meta_path)) 146 | 147 | with open(meta_path, encoding="utf-8") as meta_file: 148 | metadata = json.load(meta_file) 149 | url = metadata['url'] 150 | etag = metadata['etag'] 151 | 152 | return url, etag 153 | 154 | 155 | def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None): 156 | """ 157 | Given something that might be a URL (or might be a local path), 158 | determine which. If it's a URL, download the file and cache it, and 159 | return the path to the cached file. If it's already a local path, 160 | make sure the file exists and then return the path. 161 | Args: 162 | cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). 163 | force_download: if True, re-dowload the file even if it's already cached in the cache dir. 164 | """ 165 | if cache_dir is None: 166 | cache_dir = TRANSFORMERS_CACHE 167 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 168 | url_or_filename = str(url_or_filename) 169 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 170 | cache_dir = str(cache_dir) 171 | 172 | parsed = urlparse(url_or_filename) 173 | 174 | if parsed.scheme in ('http', 'https', 's3'): 175 | # URL, so get it from the cache (downloading if necessary) 176 | return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 177 | elif os.path.exists(url_or_filename): 178 | # File, and it exists. 179 | return url_or_filename 180 | elif parsed.scheme == '': 181 | # File, but it doesn't exist. 182 | raise EnvironmentError("file {} not found".format(url_or_filename)) 183 | else: 184 | # Something unknown 185 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 186 | 187 | 188 | def split_s3_path(url): 189 | """Split a full s3 path into the bucket name and path.""" 190 | parsed = urlparse(url) 191 | if not parsed.netloc or not parsed.path: 192 | raise ValueError("bad s3 path {}".format(url)) 193 | bucket_name = parsed.netloc 194 | s3_path = parsed.path 195 | # Remove '/' at beginning of path. 
196 | if s3_path.startswith("/"): 197 | s3_path = s3_path[1:] 198 | return bucket_name, s3_path 199 | 200 | 201 | def s3_request(func): 202 | """ 203 | Wrapper function for s3 requests in order to create more helpful error 204 | messages. 205 | """ 206 | 207 | @wraps(func) 208 | def wrapper(url, *args, **kwargs): 209 | try: 210 | return func(url, *args, **kwargs) 211 | except ClientError as exc: 212 | if int(exc.response["Error"]["Code"]) == 404: 213 | raise EnvironmentError("file {} not found".format(url)) 214 | else: 215 | raise 216 | 217 | return wrapper 218 | 219 | 220 | @s3_request 221 | def s3_etag(url, proxies=None): 222 | """Check ETag on S3 object.""" 223 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) 224 | bucket_name, s3_path = split_s3_path(url) 225 | s3_object = s3_resource.Object(bucket_name, s3_path) 226 | return s3_object.e_tag 227 | 228 | 229 | @s3_request 230 | def s3_get(url, temp_file, proxies=None): 231 | """Pull a file directly from S3.""" 232 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) 233 | bucket_name, s3_path = split_s3_path(url) 234 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 235 | 236 | 237 | def http_get(url, temp_file, proxies=None): 238 | req = requests.get(url, stream=True, proxies=proxies) 239 | content_length = req.headers.get('Content-Length') 240 | total = int(content_length) if content_length is not None else None 241 | progress = tqdm(unit="B", total=total) 242 | for chunk in req.iter_content(chunk_size=1024): 243 | if chunk: # filter out keep-alive new chunks 244 | progress.update(len(chunk)) 245 | temp_file.write(chunk) 246 | progress.close() 247 | 248 | 249 | def get_from_cache(url, cache_dir=None, force_download=False, proxies=None): 250 | """ 251 | Given a URL, look for the corresponding dataset in the local cache. 252 | If it's not there, download it. Then return the path to the cached file. 253 | """ 254 | if cache_dir is None: 255 | cache_dir = TRANSFORMERS_CACHE 256 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 257 | cache_dir = str(cache_dir) 258 | if sys.version_info[0] == 2 and not isinstance(cache_dir, str): 259 | cache_dir = str(cache_dir) 260 | 261 | if not os.path.exists(cache_dir): 262 | os.makedirs(cache_dir) 263 | 264 | # Get eTag to add to filename, if it exists. 265 | if url.startswith("s3://"): 266 | etag = s3_etag(url, proxies=proxies) 267 | else: 268 | try: 269 | response = requests.head(url, allow_redirects=True, proxies=proxies) 270 | if response.status_code != 200: 271 | etag = None 272 | else: 273 | etag = response.headers.get("ETag") 274 | except EnvironmentError: 275 | etag = None 276 | 277 | if sys.version_info[0] == 2 and etag is not None: 278 | etag = etag.decode('utf-8') 279 | filename = url_to_filename(url, etag) 280 | 281 | # get cache path to put the file 282 | cache_path = os.path.join(cache_dir, filename) 283 | 284 | # If we don't have a connection (etag is None) and can't identify the file 285 | # try to get the last downloaded one 286 | if not os.path.exists(cache_path) and etag is None: 287 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 288 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 289 | if matching_files: 290 | cache_path = os.path.join(cache_dir, matching_files[-1]) 291 | 292 | if not os.path.exists(cache_path) or force_download: 293 | # Download to temporary file, then copy to cache dir once finished. 
294 | # Otherwise you get corrupt cache entries if the download gets interrupted. 295 | with tempfile.NamedTemporaryFile() as temp_file: 296 | logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) 297 | 298 | # GET file object 299 | if url.startswith("s3://"): 300 | s3_get(url, temp_file, proxies=proxies) 301 | else: 302 | http_get(url, temp_file, proxies=proxies) 303 | 304 | # we are copying the file before closing it, so flush to avoid truncation 305 | temp_file.flush() 306 | # shutil.copyfileobj() starts at the current position, so go to the start 307 | temp_file.seek(0) 308 | 309 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 310 | with open(cache_path, 'wb') as cache_file: 311 | shutil.copyfileobj(temp_file, cache_file) 312 | 313 | logger.info("creating metadata file for %s", cache_path) 314 | meta = {'url': url, 'etag': etag} 315 | meta_path = cache_path + '.json' 316 | with open(meta_path, 'w') as meta_file: 317 | output_string = json.dumps(meta) 318 | if sys.version_info[0] == 2 and isinstance(output_string, str): 319 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 320 | meta_file.write(output_string) 321 | 322 | logger.info("removing temp file %s", temp_file.name) 323 | 324 | return cache_path 325 | -------------------------------------------------------------------------------- /bert/modeling_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import copy 22 | import json 23 | import logging 24 | import os 25 | from io import open 26 | 27 | import six 28 | import torch 29 | from torch import nn 30 | from torch.nn import CrossEntropyLoss 31 | from torch.nn import functional as F 32 | 33 | from .configuration_utils import PretrainedConfig 34 | from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | try: 40 | from torch.nn import Identity 41 | except ImportError: 42 | # Older PyTorch compatibility 43 | class Identity(nn.Module): 44 | r"""A placeholder identity operator that is argument-insensitive. 45 | """ 46 | def __init__(self, *args, **kwargs): 47 | super(Identity, self).__init__() 48 | 49 | def forward(self, input): 50 | return input 51 | 52 | class PreTrainedModel(nn.Module): 53 | r""" Base class for all models. 
54 | 55 | :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models 56 | as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads. 57 | 58 | Class attributes (overridden by derived classes): 59 | - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. 60 | - ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values. 61 | - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments: 62 | 63 | - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`, 64 | - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`, 65 | - ``path``: a path (string) to the TensorFlow checkpoint. 66 | 67 | - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. 68 | """ 69 | config_class = None 70 | pretrained_model_archive_map = {} 71 | load_tf_weights = lambda model, config, path: None 72 | base_model_prefix = "" 73 | 74 | def __init__(self, config, *inputs, **kwargs): 75 | super(PreTrainedModel, self).__init__() 76 | if not isinstance(config, PretrainedConfig): 77 | raise ValueError( 78 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " 79 | "To create a model from a pretrained model use " 80 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 81 | self.__class__.__name__, self.__class__.__name__ 82 | )) 83 | # Save config in model 84 | self.config = config 85 | 86 | def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): 87 | """ Build a resized Embedding Module from a provided token Embedding Module. 88 | Increasing the size will add newly initialized vectors at the end 89 | Reducing the size will remove vectors from the end 90 | 91 | Args: 92 | new_num_tokens: (`optional`) int 93 | New number of tokens in the embedding matrix. 94 | Increasing the size will add newly initialized vectors at the end 95 | Reducing the size will remove vectors from the end 96 | If not provided or None: return the provided token Embedding Module. 
97 | Return: ``torch.nn.Embeddings`` 98 | Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None 99 | """ 100 | if new_num_tokens is None: 101 | return old_embeddings 102 | 103 | old_num_tokens, old_embedding_dim = old_embeddings.weight.size() 104 | if old_num_tokens == new_num_tokens: 105 | return old_embeddings 106 | 107 | # Build new embeddings 108 | new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) 109 | new_embeddings.to(old_embeddings.weight.device) 110 | 111 | # initialize all new embeddings (in particular added tokens) 112 | self._init_weights(new_embeddings) 113 | 114 | # Copy word embeddings from the previous weights 115 | num_tokens_to_copy = min(old_num_tokens, new_num_tokens) 116 | new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :] 117 | 118 | return new_embeddings 119 | 120 | def _tie_or_clone_weights(self, first_module, second_module): 121 | """ Tie or clone module weights depending of weither we are using TorchScript or not 122 | """ 123 | if self.config.torchscript: 124 | first_module.weight = nn.Parameter(second_module.weight.clone()) 125 | else: 126 | first_module.weight = second_module.weight 127 | 128 | if hasattr(first_module, 'bias') and first_module.bias is not None: 129 | first_module.bias.data = torch.nn.functional.pad( 130 | first_module.bias.data, 131 | (0, first_module.weight.shape[0] - first_module.bias.shape[0]), 132 | 'constant', 133 | 0 134 | ) 135 | 136 | def resize_token_embeddings(self, new_num_tokens=None): 137 | """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. 138 | Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. 139 | 140 | Arguments: 141 | 142 | new_num_tokens: (`optional`) int: 143 | New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. 144 | If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model. 145 | 146 | Return: ``torch.nn.Embeddings`` 147 | Pointer to the input tokens Embeddings Module of the model 148 | """ 149 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed 150 | model_embeds = base_model._resize_token_embeddings(new_num_tokens) 151 | if new_num_tokens is None: 152 | return model_embeds 153 | 154 | # Update base model and current model config 155 | self.config.vocab_size = new_num_tokens 156 | base_model.vocab_size = new_num_tokens 157 | 158 | # Tie weights again if needed 159 | if hasattr(self, 'tie_weights'): 160 | self.tie_weights() 161 | 162 | return model_embeds 163 | 164 | def init_weights(self): 165 | """ Initialize and prunes weights if needed. """ 166 | # Initialize weights 167 | self.apply(self._init_weights) 168 | 169 | # Prune heads if needed 170 | if self.config.pruned_heads: 171 | self.prune_heads(self.config.pruned_heads) 172 | 173 | def prune_heads(self, heads_to_prune): 174 | """ Prunes heads of the base model. 175 | 176 | Arguments: 177 | 178 | heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`). 179 | E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. 
180 | """ 181 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed 182 | 183 | # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads 184 | for layer, heads in heads_to_prune.items(): 185 | union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) 186 | self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON 187 | 188 | base_model._prune_heads(heads_to_prune) 189 | 190 | def save_pretrained(self, save_directory): 191 | """ Save a model and its configuration file to a directory, so that it 192 | can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method. 193 | """ 194 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" 195 | 196 | # Only save the model it-self if we are using distributed training 197 | model_to_save = self.module if hasattr(self, 'module') else self 198 | 199 | # Save configuration file 200 | model_to_save.config.save_pretrained(save_directory) 201 | 202 | # If we save using the predefined names, we can load using `from_pretrained` 203 | output_model_file = os.path.join(save_directory, WEIGHTS_NAME) 204 | torch.save(model_to_save.state_dict(), output_model_file) 205 | logger.info("Model weights saved in {}".format(output_model_file)) 206 | 207 | @classmethod 208 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 209 | r"""Instantiate a pretrained pytorch model from a pre-trained model configuration. 210 | 211 | The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated) 212 | To train the model, you should first set it back in training mode with ``model.train()`` 213 | 214 | The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model. 215 | It is up to you to train those weights with a downstream fine-tuning task. 216 | 217 | The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded. 218 | 219 | Parameters: 220 | pretrained_model_name_or_path: either: 221 | 222 | - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. 223 | - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. 224 | - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. 225 | - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``) 226 | 227 | model_args: (`optional`) Sequence of positional arguments: 228 | All remaning positional arguments will be passed to the underlying model's ``__init__`` method 229 | 230 | config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: 231 | Configuration for the model to use instead of an automatically loaded configuation. 
Configuration can be automatically loaded when: 232 | 233 | - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or 234 | - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. 235 | - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. 236 | 237 | state_dict: (`optional`) dict: 238 | an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file. 239 | This option can be used if you want to create a model from a pretrained configuration but load your own weights. 240 | In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. 241 | 242 | cache_dir: (`optional`) string: 243 | Path to a directory in which a downloaded pre-trained model 244 | configuration should be cached if the standard cache should not be used. 245 | 246 | force_download: (`optional`) boolean, default False: 247 | Force to (re-)download the model weights and configuration files and override the cached versions if they exists. 248 | 249 | proxies: (`optional`) dict, default None: 250 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 251 | The proxies are used on each request. 252 | 253 | output_loading_info: (`optional`) boolean: 254 | Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. 255 | 256 | kwargs: (`optional`) Remaining dictionary of keyword arguments: 257 | Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: 258 | 259 | - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) 260 | - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. 261 | 262 | Examples:: 263 | 264 | model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. 265 | model = BertModel.from_pretrained('./test/saved_model/') # E.g. 
model was saved using `save_pretrained('./test/saved_model/')` 266 | model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading 267 | assert model.config.output_attention == True 268 | # Loading from a TF checkpoint file instead of a PyTorch model (slower) 269 | config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') 270 | model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) 271 | 272 | """ 273 | config = kwargs.pop('config', None) 274 | state_dict = kwargs.pop('state_dict', None) 275 | cache_dir = kwargs.pop('cache_dir', None) 276 | from_tf = kwargs.pop('from_tf', False) 277 | force_download = kwargs.pop('force_download', False) 278 | proxies = kwargs.pop('proxies', None) 279 | output_loading_info = kwargs.pop('output_loading_info', False) 280 | 281 | # Load config 282 | if config is None: 283 | config, model_kwargs = cls.config_class.from_pretrained( 284 | pretrained_model_name_or_path, *model_args, 285 | cache_dir=cache_dir, return_unused_kwargs=True, 286 | force_download=force_download, 287 | **kwargs 288 | ) 289 | else: 290 | model_kwargs = kwargs 291 | 292 | # Load model 293 | if pretrained_model_name_or_path is not None: 294 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map: 295 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] 296 | elif os.path.isdir(pretrained_model_name_or_path): 297 | if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")): 298 | # Load from a TF 1.0 checkpoint 299 | archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") 300 | elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): 301 | # Load from a TF 2.0 checkpoint 302 | archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) 303 | elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): 304 | # Load from a PyTorch checkpoint 305 | archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) 306 | else: 307 | raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format( 308 | [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], 309 | pretrained_model_name_or_path)) 310 | elif os.path.isfile(pretrained_model_name_or_path): 311 | archive_file = pretrained_model_name_or_path 312 | else: 313 | assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path) 314 | archive_file = pretrained_model_name_or_path + ".index" 315 | 316 | # redirect to the cache, if necessary 317 | try: 318 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 319 | except EnvironmentError as e: 320 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map: 321 | logger.error( 322 | "Couldn't reach server at '{}' to download pretrained weights.".format( 323 | archive_file)) 324 | else: 325 | logger.error( 326 | "Model name '{}' was not found in model name list ({}). 
" 327 | "We assumed '{}' was a path or url but couldn't find any file " 328 | "associated to this path or url.".format( 329 | pretrained_model_name_or_path, 330 | ', '.join(cls.pretrained_model_archive_map.keys()), 331 | archive_file)) 332 | raise e 333 | if resolved_archive_file == archive_file: 334 | logger.info("loading weights file {}".format(archive_file)) 335 | else: 336 | logger.info("loading weights file {} from cache at {}".format( 337 | archive_file, resolved_archive_file)) 338 | else: 339 | resolved_archive_file = None 340 | 341 | # Instantiate model. 342 | model = cls(config, *model_args, **model_kwargs) 343 | 344 | if state_dict is None and not from_tf: 345 | state_dict = torch.load(resolved_archive_file, map_location='cpu') 346 | 347 | missing_keys = [] 348 | unexpected_keys = [] 349 | error_msgs = [] 350 | 351 | if from_tf: 352 | if resolved_archive_file.endswith('.index'): 353 | # Load from a TensorFlow 1.X checkpoint - provided by original authors 354 | model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' 355 | else: 356 | # Load from our TensorFlow 2.0 checkpoints 357 | try: 358 | from transformers import load_tf2_checkpoint_in_pytorch_model 359 | model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True) 360 | except ImportError as e: 361 | logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " 362 | "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") 363 | raise e 364 | else: 365 | # Convert old format to new format if needed from a PyTorch state_dict 366 | old_keys = [] 367 | new_keys = [] 368 | for key in state_dict.keys(): 369 | new_key = None 370 | if 'gamma' in key: 371 | new_key = key.replace('gamma', 'weight') 372 | if 'beta' in key: 373 | new_key = key.replace('beta', 'bias') 374 | if new_key: 375 | old_keys.append(key) 376 | new_keys.append(new_key) 377 | for old_key, new_key in zip(old_keys, new_keys): 378 | state_dict[new_key] = state_dict.pop(old_key) 379 | 380 | # copy state_dict so _load_from_state_dict can modify it 381 | metadata = getattr(state_dict, '_metadata', None) 382 | state_dict = state_dict.copy() 383 | if metadata is not None: 384 | state_dict._metadata = metadata 385 | 386 | def load(module, prefix=''): 387 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 388 | module._load_from_state_dict( 389 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 390 | for name, child in module._modules.items(): 391 | if child is not None: 392 | load(child, prefix + name + '.') 393 | 394 | # Make sure we are able to load base models as well as derived models (with heads) 395 | start_prefix = '' 396 | model_to_load = model 397 | if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): 398 | start_prefix = cls.base_model_prefix + '.' 
399 | if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): 400 | model_to_load = getattr(model, cls.base_model_prefix) 401 | 402 | load(model_to_load, prefix=start_prefix) 403 | if len(missing_keys) > 0: 404 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 405 | model.__class__.__name__, missing_keys)) 406 | if len(unexpected_keys) > 0: 407 | logger.info("Weights from pretrained model not used in {}: {}".format( 408 | model.__class__.__name__, unexpected_keys)) 409 | if len(error_msgs) > 0: 410 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 411 | model.__class__.__name__, "\n\t".join(error_msgs))) 412 | 413 | if hasattr(model, 'tie_weights'): 414 | model.tie_weights() # make sure word embedding weights are still tied 415 | 416 | # Set model in evaluation mode to desactivate DropOut modules by default 417 | model.eval() 418 | 419 | if output_loading_info: 420 | loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} 421 | return model, loading_info 422 | 423 | return model 424 | 425 | 426 | class Conv1D(nn.Module): 427 | def __init__(self, nf, nx): 428 | """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2) 429 | Basically works like a Linear layer but the weights are transposed 430 | """ 431 | super(Conv1D, self).__init__() 432 | self.nf = nf 433 | w = torch.empty(nx, nf) 434 | nn.init.normal_(w, std=0.02) 435 | self.weight = nn.Parameter(w) 436 | self.bias = nn.Parameter(torch.zeros(nf)) 437 | 438 | def forward(self, x): 439 | size_out = x.size()[:-1] + (self.nf,) 440 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 441 | x = x.view(*size_out) 442 | return x 443 | 444 | 445 | class PoolerStartLogits(nn.Module): 446 | """ Compute SQuAD start_logits from sequence hidden states. """ 447 | def __init__(self, config): 448 | super(PoolerStartLogits, self).__init__() 449 | self.dense = nn.Linear(config.hidden_size, 1) 450 | 451 | def forward(self, hidden_states, p_mask=None): 452 | """ Args: 453 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)` 454 | invalid position mask such as query and special symbols (PAD, SEP, CLS) 455 | 1.0 means token should be masked. 456 | """ 457 | x = self.dense(hidden_states).squeeze(-1) 458 | 459 | if p_mask is not None: 460 | if next(self.parameters()).dtype == torch.float16: 461 | x = x * (1 - p_mask) - 65500 * p_mask 462 | else: 463 | x = x * (1 - p_mask) - 1e30 * p_mask 464 | 465 | return x 466 | 467 | 468 | class PoolerEndLogits(nn.Module): 469 | """ Compute SQuAD end_logits from sequence hidden states and start token hidden state. 470 | """ 471 | def __init__(self, config): 472 | super(PoolerEndLogits, self).__init__() 473 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) 474 | self.activation = nn.Tanh() 475 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 476 | self.dense_1 = nn.Linear(config.hidden_size, 1) 477 | 478 | def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None): 479 | """ Args: 480 | One of ``start_states``, ``start_positions`` should be not None. 481 | If both are set, ``start_positions`` overrides ``start_states``. 482 | 483 | **start_states**: ``torch.LongTensor`` of shape identical to hidden_states 484 | hidden states of the first tokens for the labeled span. 
485 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 486 | position of the first token for the labeled span: 487 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` 488 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS) 489 | 1.0 means token should be masked. 490 | """ 491 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" 492 | if start_positions is not None: 493 | slen, hsz = hidden_states.shape[-2:] 494 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 495 | start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) 496 | start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) 497 | 498 | x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) 499 | x = self.activation(x) 500 | x = self.LayerNorm(x) 501 | x = self.dense_1(x).squeeze(-1) 502 | 503 | if p_mask is not None: 504 | x = x * (1 - p_mask) - 1e30 * p_mask 505 | 506 | return x 507 | 508 | 509 | class PoolerAnswerClass(nn.Module): 510 | """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """ 511 | def __init__(self, config): 512 | super(PoolerAnswerClass, self).__init__() 513 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) 514 | self.activation = nn.Tanh() 515 | self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) 516 | 517 | def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None): 518 | """ 519 | Args: 520 | One of ``start_states``, ``start_positions`` should be not None. 521 | If both are set, ``start_positions`` overrides ``start_states``. 522 | 523 | **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``. 524 | hidden states of the first tokens for the labeled span. 525 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 526 | position of the first token for the labeled span. 527 | **cls_index**: torch.LongTensor of shape ``(batch_size,)`` 528 | position of the CLS token. If None, take the last token. 529 | 530 | note(Original repo): 531 | no dependency on end_feature so that we can obtain one single `cls_logits` 532 | for each sample 533 | """ 534 | hsz = hidden_states.shape[-1] 535 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" 536 | if start_positions is not None: 537 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 538 | start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) 539 | 540 | if cls_index is not None: 541 | cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 542 | cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) 543 | else: 544 | cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) 545 | 546 | x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) 547 | x = self.activation(x) 548 | x = self.dense_1(x).squeeze(-1) 549 | 550 | return x 551 | 552 | 553 | class SQuADHead(nn.Module): 554 | r""" A SQuAD head inspired by XLNet. 555 | 556 | Parameters: 557 | config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model. 
558 | 559 | Inputs: 560 | **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)`` 561 | hidden states of sequence tokens 562 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 563 | position of the first token for the labeled span. 564 | **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 565 | position of the last token for the labeled span. 566 | **cls_index**: torch.LongTensor of shape ``(batch_size,)`` 567 | position of the CLS token. If None, take the last token. 568 | **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)`` 569 | Whether the question has a possible answer in the paragraph or not. 570 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` 571 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS) 572 | 1.0 means token should be masked. 573 | 574 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 575 | **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: 576 | Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. 577 | **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 578 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)`` 579 | Log probabilities for the top config.start_n_top start token possibilities (beam-search). 580 | **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 581 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)`` 582 | Indices for the top config.start_n_top start token possibilities (beam-search). 583 | **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 584 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` 585 | Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). 586 | **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 587 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` 588 | Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). 589 | **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 590 | ``torch.FloatTensor`` of shape ``(batch_size,)`` 591 | Log probabilities for the ``is_impossible`` label of the answers. 
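    Example (illustrative usage sketch; assumes ``config`` defines ``hidden_size``, ``layer_norm_eps``,
    ``start_n_top`` and ``end_n_top``, and that ``hidden_states`` comes from the underlying encoder)::

        head = SQuADHead(config)

        # training: supplying the gold span (and optionally the answerability label) returns the loss first
        outputs = head(hidden_states, start_positions=start_positions, end_positions=end_positions,
                       cls_index=cls_index, is_impossible=is_impossible, p_mask=p_mask)
        total_loss = outputs[0]

        # inference: without labels, beam-search candidates and the answerability logits are returned
        (start_top_log_probs, start_top_index,
         end_top_log_probs, end_top_index, cls_logits) = head(hidden_states, p_mask=p_mask)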
592 | """ 593 | def __init__(self, config): 594 | super(SQuADHead, self).__init__() 595 | self.start_n_top = config.start_n_top 596 | self.end_n_top = config.end_n_top 597 | 598 | self.start_logits = PoolerStartLogits(config) 599 | self.end_logits = PoolerEndLogits(config) 600 | self.answer_class = PoolerAnswerClass(config) 601 | 602 | def forward(self, hidden_states, start_positions=None, end_positions=None, 603 | cls_index=None, is_impossible=None, p_mask=None): 604 | outputs = () 605 | 606 | start_logits = self.start_logits(hidden_states, p_mask=p_mask) 607 | 608 | if start_positions is not None and end_positions is not None: 609 | # If we are on multi-GPU, let's remove the dimension added by batch splitting 610 | for x in (start_positions, end_positions, cls_index, is_impossible): 611 | if x is not None and x.dim() > 1: 612 | x.squeeze_(-1) 613 | 614 | # during training, compute the end logits based on the ground truth of the start position 615 | end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) 616 | 617 | loss_fct = CrossEntropyLoss() 618 | start_loss = loss_fct(start_logits, start_positions) 619 | end_loss = loss_fct(end_logits, end_positions) 620 | total_loss = (start_loss + end_loss) / 2 621 | 622 | if cls_index is not None and is_impossible is not None: 623 | # Predict answerability from the representation of CLS and START 624 | cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) 625 | loss_fct_cls = nn.BCEWithLogitsLoss() 626 | cls_loss = loss_fct_cls(cls_logits, is_impossible) 627 | 628 | # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss 629 | total_loss += cls_loss * 0.5 630 | 631 | outputs = (total_loss,) + outputs 632 | 633 | else: 634 | # during inference, compute the end logits based on beam search 635 | bsz, slen, hsz = hidden_states.size() 636 | start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) 637 | 638 | start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top) 639 | start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) 640 | start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) 641 | start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) 642 | 643 | hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz) 644 | p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None 645 | end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) 646 | end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) 647 | 648 | end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top) 649 | end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) 650 | end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) 651 | 652 | start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) 653 | cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) 654 | 655 | outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs 656 | 657 | # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, 
cls_logits 658 | # or (if labels are provided) (total_loss,) 659 | return outputs 660 | 661 | 662 | class SequenceSummary(nn.Module): 663 | r""" Compute a single vector summary of a sequence's hidden states according to various possibilities: 664 | Args of the config class: 665 | summary_type: 666 | - 'last' => [default] take the last token hidden state (like XLNet) 667 | - 'first' => take the first token hidden state (like Bert) 668 | - 'mean' => take the mean of all tokens hidden states 669 | - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2) 670 | - 'attn' => Not implemented now, use multi-head attention 671 | summary_use_proj: Add a projection after the vector extraction 672 | summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. 673 | summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation (default). 674 | summary_first_dropout: Add a dropout before the projection and activation 675 | summary_last_dropout: Add a dropout after the projection and activation 676 | """ 677 | def __init__(self, config): 678 | super(SequenceSummary, self).__init__() 679 | 680 | self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last' 681 | if self.summary_type == 'attn': 682 | # We should use a standard multi-head attention module with absolute positional embedding for that. 683 | # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 684 | # We can probably just use the multi-head attention module of PyTorch >=1.1.0 685 | raise NotImplementedError 686 | 687 | self.summary = Identity() 688 | if hasattr(config, 'summary_use_proj') and config.summary_use_proj: 689 | if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0: 690 | num_classes = config.num_labels 691 | else: 692 | num_classes = config.hidden_size 693 | self.summary = nn.Linear(config.hidden_size, num_classes) 694 | 695 | self.activation = Identity() 696 | if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh': 697 | self.activation = nn.Tanh() 698 | 699 | self.first_dropout = Identity() 700 | if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0: 701 | self.first_dropout = nn.Dropout(config.summary_first_dropout) 702 | 703 | self.last_dropout = Identity() 704 | if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0: 705 | self.last_dropout = nn.Dropout(config.summary_last_dropout) 706 | 707 | def forward(self, hidden_states, cls_index=None): 708 | """ hidden_states: float Tensor of shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer. 709 | cls_index: [optional] position of the classification token if summary_type == 'cls_index', 710 | shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
711 | if summary_type == 'cls_index' and cls_index is None: 712 | we take the last token of the sequence as classification token 713 | """ 714 | if self.summary_type == 'last': 715 | output = hidden_states[:, -1] 716 | elif self.summary_type == 'first': 717 | output = hidden_states[:, 0] 718 | elif self.summary_type == 'mean': 719 | output = hidden_states.mean(dim=1) 720 | elif self.summary_type == 'cls_index': 721 | if cls_index is None: 722 | cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long) 723 | else: 724 | cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) 725 | cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),)) 726 | # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states 727 | output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) 728 | elif self.summary_type == 'attn': 729 | raise NotImplementedError 730 | 731 | output = self.first_dropout(output) 732 | output = self.summary(output) 733 | output = self.activation(output) 734 | output = self.last_dropout(output) 735 | 736 | return output 737 | 738 | 739 | def prune_linear_layer(layer, index, dim=0): 740 | """ Prune a linear layer (a model parameters) to keep only entries in index. 741 | Return the pruned layer as a new layer with requires_grad=True. 742 | Used to remove heads. 743 | """ 744 | index = index.to(layer.weight.device) 745 | W = layer.weight.index_select(dim, index).clone().detach() 746 | if layer.bias is not None: 747 | if dim == 1: 748 | b = layer.bias.clone().detach() 749 | else: 750 | b = layer.bias[index].clone().detach() 751 | new_size = list(layer.weight.size()) 752 | new_size[dim] = len(index) 753 | new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) 754 | new_layer.weight.requires_grad = False 755 | new_layer.weight.copy_(W.contiguous()) 756 | new_layer.weight.requires_grad = True 757 | if layer.bias is not None: 758 | new_layer.bias.requires_grad = False 759 | new_layer.bias.copy_(b.contiguous()) 760 | new_layer.bias.requires_grad = True 761 | return new_layer 762 | 763 | 764 | def prune_conv1d_layer(layer, index, dim=1): 765 | """ Prune a Conv1D layer (a model parameters) to keep only entries in index. 766 | A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed. 767 | Return the pruned layer as a new layer with requires_grad=True. 768 | Used to remove heads. 769 | """ 770 | index = index.to(layer.weight.device) 771 | W = layer.weight.index_select(dim, index).clone().detach() 772 | if dim == 0: 773 | b = layer.bias.clone().detach() 774 | else: 775 | b = layer.bias[index].clone().detach() 776 | new_size = list(layer.weight.size()) 777 | new_size[dim] = len(index) 778 | new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) 779 | new_layer.weight.requires_grad = False 780 | new_layer.weight.copy_(W.contiguous()) 781 | new_layer.weight.requires_grad = True 782 | new_layer.bias.requires_grad = False 783 | new_layer.bias.copy_(b.contiguous()) 784 | new_layer.bias.requires_grad = True 785 | return new_layer 786 | 787 | 788 | def prune_layer(layer, index, dim=None): 789 | """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index. 790 | Return the pruned layer as a new layer with requires_grad=True. 791 | Used to remove heads. 
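    Example (illustrative sketch; the layer size and kept indices below are arbitrary)::

        layer = nn.Linear(768, 768)
        index = torch.tensor([0, 1, 2, 3])  # entries to keep along the pruning dimension
        pruned = prune_layer(layer, index)  # nn.Linear pruned on dim=0 -> weight of shape (4, 768)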
792 | """ 793 | if isinstance(layer, nn.Linear): 794 | return prune_linear_layer(layer, index, dim=0 if dim is None else dim) 795 | elif isinstance(layer, Conv1D): 796 | return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) 797 | else: 798 | raise ValueError("Can't prune layer of class {}".format(layer.__class__)) 799 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | class ConstantLRSchedule(LambdaLR): 27 | """ Constant learning rate schedule. 28 | """ 29 | def __init__(self, optimizer, last_epoch=-1): 30 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 31 | 32 | 33 | class WarmupConstantSchedule(LambdaLR): 34 | """ Linear warmup and then constant. 35 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 36 | Keeps learning rate schedule equal to 1. after warmup_steps. 37 | """ 38 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 39 | self.warmup_steps = warmup_steps 40 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 41 | 42 | def lr_lambda(self, step): 43 | if step < self.warmup_steps: 44 | return float(step) / float(max(1.0, self.warmup_steps)) 45 | return 1. 46 | 47 | 48 | class WarmupLinearSchedule(LambdaLR): 49 | """ Linear warmup and then linear decay. 50 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 51 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 52 | """ 53 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 54 | self.warmup_steps = warmup_steps 55 | self.t_total = t_total 56 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 57 | 58 | def lr_lambda(self, step): 59 | if step < self.warmup_steps: 60 | return float(step) / float(max(1, self.warmup_steps)) 61 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 62 | 63 | 64 | class WarmupCosineSchedule(LambdaLR): 65 | """ Linear warmup and then cosine decay. 66 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 67 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 68 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 
69 | """ 70 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 71 | self.warmup_steps = warmup_steps 72 | self.t_total = t_total 73 | self.cycles = cycles 74 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 75 | 76 | def lr_lambda(self, step): 77 | if step < self.warmup_steps: 78 | return float(step) / float(max(1.0, self.warmup_steps)) 79 | # progress after warmup 80 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 81 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 82 | 83 | 84 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 85 | """ Linear warmup and then cosine cycles with hard restarts. 86 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 87 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 88 | learning rate (with hard restarts). 89 | """ 90 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 91 | self.warmup_steps = warmup_steps 92 | self.t_total = t_total 93 | self.cycles = cycles 94 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 95 | 96 | def lr_lambda(self, step): 97 | if step < self.warmup_steps: 98 | return float(step) / float(max(1, self.warmup_steps)) 99 | # progress after warmup 100 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 101 | if progress >= 1.0: 102 | return 0.0 103 | return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0)))) 104 | 105 | 106 | 107 | class AdamW(Optimizer): 108 | """ Implements Adam algorithm with weight decay fix. 109 | 110 | Parameters: 111 | lr (float): learning rate. Default 1e-3. 112 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 113 | eps (float): Adams epsilon. Default: 1e-6 114 | weight_decay (float): Weight decay. Default: 0.0 115 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 116 | """ 117 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 118 | if lr < 0.0: 119 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 120 | if not 0.0 <= betas[0] < 1.0: 121 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 122 | if not 0.0 <= betas[1] < 1.0: 123 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 124 | if not 0.0 <= eps: 125 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 126 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 127 | correct_bias=correct_bias) 128 | super(AdamW, self).__init__(params, defaults) 129 | 130 | def step(self, closure=None): 131 | """Performs a single optimization step. 132 | 133 | Arguments: 134 | closure (callable, optional): A closure that reevaluates the model 135 | and returns the loss. 
136 | """ 137 | loss = None 138 | if closure is not None: 139 | loss = closure() 140 | 141 | for group in self.param_groups: 142 | for p in group['params']: 143 | if p.grad is None: 144 | continue 145 | grad = p.grad.data 146 | if grad.is_sparse: 147 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 148 | 149 | state = self.state[p] 150 | 151 | # State initialization 152 | if len(state) == 0: 153 | state['step'] = 0 154 | # Exponential moving average of gradient values 155 | state['exp_avg'] = torch.zeros_like(p.data) 156 | # Exponential moving average of squared gradient values 157 | state['exp_avg_sq'] = torch.zeros_like(p.data) 158 | 159 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 160 | beta1, beta2 = group['betas'] 161 | 162 | state['step'] += 1 163 | 164 | # Decay the first and second moment running average coefficient 165 | # In-place operations to update the averages at the same time 166 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 167 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 168 | denom = exp_avg_sq.sqrt().add_(group['eps']) 169 | 170 | step_size = group['lr'] 171 | if group['correct_bias']: # No bias correction for Bert 172 | bias_correction1 = 1.0 - beta1 ** state['step'] 173 | bias_correction2 = 1.0 - beta2 ** state['step'] 174 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 175 | 176 | p.data.addcdiv_(-step_size, exp_avg, denom) 177 | 178 | # Just adding the square of the weights to the loss function is *not* 179 | # the correct way of using L2 regularization/weight decay with Adam, 180 | # since that will interact with the m and v parameters in strange ways. 181 | # 182 | # Instead we want to decay the weights in a manner that doesn't interact 183 | # with the m/v parameters. This is equivalent to adding the square 184 | # of the weights to the loss with plain (non-momentum) SGD. 185 | # Add weight decay at the end (fixed version) 186 | if group['weight_decay'] > 0.0: 187 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 188 | 189 | return loss 190 | -------------------------------------------------------------------------------- /bert/tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .tokenization_utils import PreTrainedTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 37 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 38 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 39 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 40 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 41 | 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 42 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 43 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 44 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 45 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 46 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 47 | } 48 | } 49 | 50 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 51 | 'bert-base-uncased': 512, 52 | 'bert-large-uncased': 512, 53 | 'bert-base-cased': 512, 54 | 'bert-large-cased': 512, 55 | 'bert-base-multilingual-uncased': 512, 56 | 'bert-base-multilingual-cased': 512, 57 | 'bert-base-chinese': 512, 58 | 'bert-base-german-cased': 512, 59 | 'bert-large-uncased-whole-word-masking': 512, 60 | 'bert-large-cased-whole-word-masking': 512, 61 | 'bert-large-uncased-whole-word-masking-finetuned-squad': 512, 62 | 'bert-large-cased-whole-word-masking-finetuned-squad': 512, 63 | 'bert-base-cased-finetuned-mrpc': 512, 64 | } 65 | 66 | PRETRAINED_INIT_CONFIGURATION = { 67 | 'bert-base-uncased': {'do_lower_case': True}, 68 | 'bert-large-uncased': {'do_lower_case': True}, 69 | 'bert-base-cased': {'do_lower_case': False}, 70 | 'bert-large-cased': {'do_lower_case': False}, 71 | 'bert-base-multilingual-uncased': {'do_lower_case': True}, 72 | 'bert-base-multilingual-cased': {'do_lower_case': False}, 73 | 'bert-base-chinese': {'do_lower_case': False}, 74 | 'bert-base-german-cased': {'do_lower_case': False}, 75 | 'bert-large-uncased-whole-word-masking': {'do_lower_case': True}, 76 | 'bert-large-cased-whole-word-masking': {'do_lower_case': False}, 77 | 'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True}, 78 | 
'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False}, 79 | 'bert-base-cased-finetuned-mrpc': {'do_lower_case': False}, 80 | } 81 | 82 | 83 | def load_vocab(vocab_file): 84 | """Loads a vocabulary file into a dictionary.""" 85 | vocab = collections.OrderedDict() 86 | with open(vocab_file, "r", encoding="utf-8") as reader: 87 | tokens = reader.readlines() 88 | for index, token in enumerate(tokens): 89 | token = token.rstrip('\n') 90 | vocab[token] = index 91 | return vocab 92 | 93 | 94 | def whitespace_tokenize(text): 95 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 96 | text = text.strip() 97 | if not text: 98 | return [] 99 | tokens = text.split() 100 | return tokens 101 | 102 | 103 | class BertTokenizer(PreTrainedTokenizer): 104 | r""" 105 | Constructs a BertTokenizer. 106 | :class:`~transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece 107 | 108 | Args: 109 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 110 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 111 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 112 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 113 | minimum of this value (if specified) and the underlying BERT model's sequence length. 114 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 115 | do_wordpiece_only=False 116 | """ 117 | 118 | vocab_files_names = VOCAB_FILES_NAMES 119 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 120 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 121 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 122 | 123 | def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, 124 | unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", 125 | mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs): 126 | """Constructs a BertTokenizer. 127 | 128 | Args: 129 | **vocab_file**: Path to a one-wordpiece-per-line vocabulary file 130 | **do_lower_case**: (`optional`) boolean (default True) 131 | Whether to lower case the input 132 | Only has an effect when do_basic_tokenize=True 133 | **do_basic_tokenize**: (`optional`) boolean (default True) 134 | Whether to do basic tokenization before wordpiece. 135 | **never_split**: (`optional`) list of string 136 | List of tokens which will never be split during tokenization. 137 | Only has an effect when do_basic_tokenize=True 138 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 139 | Whether to tokenize Chinese characters. 140 | This should likely be deactivated for Japanese: 141 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 142 | """ 143 | super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, 144 | pad_token=pad_token, cls_token=cls_token, 145 | mask_token=mask_token, **kwargs) 146 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 147 | self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens 148 | 149 | if not os.path.isfile(vocab_file): 150 | raise ValueError( 151 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 152 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 153 | self.vocab = load_vocab(vocab_file) 154 | self.ids_to_tokens = collections.OrderedDict( 155 | [(ids, tok) for tok, ids in self.vocab.items()]) 156 | self.do_basic_tokenize = do_basic_tokenize 157 | if do_basic_tokenize: 158 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 159 | never_split=never_split, 160 | tokenize_chinese_chars=tokenize_chinese_chars) 161 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 162 | 163 | @property 164 | def vocab_size(self): 165 | return len(self.vocab) 166 | 167 | def _tokenize(self, text): 168 | split_tokens = [] 169 | if self.do_basic_tokenize: 170 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 171 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 172 | split_tokens.append(sub_token) 173 | else: 174 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 175 | return split_tokens 176 | 177 | def _convert_token_to_id(self, token): 178 | """ Converts a token (str/unicode) in an id using the vocab. """ 179 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 180 | 181 | def _convert_id_to_token(self, index): 182 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 183 | return self.ids_to_tokens.get(index, self.unk_token) 184 | 185 | def convert_tokens_to_string(self, tokens): 186 | """ Converts a sequence of tokens (string) in a single string. """ 187 | out_string = ' '.join(tokens).replace(' ##', '').strip() 188 | return out_string 189 | 190 | def add_special_tokens_single_sequence(self, token_ids): 191 | """ 192 | Adds special tokens to the a sequence for sequence classification tasks. 193 | A BERT sequence has the following format: [CLS] X [SEP] 194 | """ 195 | return [self.cls_token_id] + token_ids + [self.sep_token_id] 196 | 197 | def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1): 198 | """ 199 | Adds special tokens to a sequence pair for sequence classification tasks. 200 | A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP] 201 | """ 202 | sep = [self.sep_token_id] 203 | cls = [self.cls_token_id] 204 | 205 | return cls + token_ids_0 + sep + token_ids_1 + sep 206 | 207 | def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1): 208 | """ 209 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 210 | A BERT sequence pair mask has the following format: 211 | 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 212 | | first sequence | second sequence 213 | """ 214 | sep = [self.sep_token_id] 215 | cls = [self.cls_token_id] 216 | 217 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 218 | 219 | def save_vocabulary(self, vocab_path): 220 | """Save the tokenizer vocabulary to a directory or file.""" 221 | index = 0 222 | if os.path.isdir(vocab_path): 223 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) 224 | else: 225 | vocab_file = vocab_path 226 | with open(vocab_file, "w", encoding="utf-8") as writer: 227 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 228 | if index != token_index: 229 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 
230 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 231 | index = token_index 232 | writer.write(token + u'\n') 233 | index += 1 234 | return (vocab_file,) 235 | 236 | 237 | class BasicTokenizer(object): 238 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 239 | 240 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): 241 | """ Constructs a BasicTokenizer. 242 | 243 | Args: 244 | **do_lower_case**: Whether to lower case the input. 245 | **never_split**: (`optional`) list of str 246 | Kept for backward compatibility purposes. 247 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 248 | List of token not to split. 249 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 250 | Whether to tokenize Chinese characters. 251 | This should likely be deactivated for Japanese: 252 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 253 | """ 254 | if never_split is None: 255 | never_split = [] 256 | self.do_lower_case = do_lower_case 257 | self.never_split = never_split 258 | self.tokenize_chinese_chars = tokenize_chinese_chars 259 | 260 | def tokenize(self, text, never_split=None): 261 | """ Basic Tokenization of a piece of text. 262 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. 263 | 264 | Args: 265 | **never_split**: (`optional`) list of str 266 | Kept for backward compatibility purposes. 267 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 268 | List of token not to split. 269 | """ 270 | never_split = self.never_split + (never_split if never_split is not None else []) 271 | text = self._clean_text(text) 272 | # This was added on November 1st, 2018 for the multilingual and Chinese 273 | # models. This is also applied to the English models now, but it doesn't 274 | # matter since the English models were not trained on any Chinese data 275 | # and generally don't have any Chinese data in them (there are Chinese 276 | # characters in the vocabulary because Wikipedia does have some Chinese 277 | # words in the English Wikipedia.). 
278 | if self.tokenize_chinese_chars: 279 | text = self._tokenize_chinese_chars(text) 280 | orig_tokens = whitespace_tokenize(text) 281 | split_tokens = [] 282 | for token in orig_tokens: 283 | if self.do_lower_case and token not in never_split: 284 | token = token.lower() 285 | token = self._run_strip_accents(token) 286 | split_tokens.extend(self._run_split_on_punc(token)) 287 | 288 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 289 | return output_tokens 290 | 291 | def _run_strip_accents(self, text): 292 | """Strips accents from a piece of text.""" 293 | text = unicodedata.normalize("NFD", text) 294 | output = [] 295 | for char in text: 296 | cat = unicodedata.category(char) 297 | if cat == "Mn": 298 | continue 299 | output.append(char) 300 | return "".join(output) 301 | 302 | def _run_split_on_punc(self, text, never_split=None): 303 | """Splits punctuation on a piece of text.""" 304 | if never_split is not None and text in never_split: 305 | return [text] 306 | chars = list(text) 307 | i = 0 308 | start_new_word = True 309 | output = [] 310 | while i < len(chars): 311 | char = chars[i] 312 | if _is_punctuation(char): 313 | output.append([char]) 314 | start_new_word = True 315 | else: 316 | if start_new_word: 317 | output.append([]) 318 | start_new_word = False 319 | output[-1].append(char) 320 | i += 1 321 | 322 | return ["".join(x) for x in output] 323 | 324 | def _tokenize_chinese_chars(self, text): 325 | """Adds whitespace around any CJK character.""" 326 | output = [] 327 | for char in text: 328 | cp = ord(char) 329 | if self._is_chinese_char(cp): 330 | output.append(" ") 331 | output.append(char) 332 | output.append(" ") 333 | else: 334 | output.append(char) 335 | return "".join(output) 336 | 337 | def _is_chinese_char(self, cp): 338 | """Checks whether CP is the codepoint of a CJK character.""" 339 | # This defines a "chinese character" as anything in the CJK Unicode block: 340 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 341 | # 342 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 343 | # despite its name. The modern Korean Hangul alphabet is a different block, 344 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 345 | # space-separated words, so they are not treated specially and handled 346 | # like the all of the other languages. 347 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 348 | (cp >= 0x3400 and cp <= 0x4DBF) or # 349 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 350 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 351 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 352 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 353 | (cp >= 0xF900 and cp <= 0xFAFF) or # 354 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 355 | return True 356 | 357 | return False 358 | 359 | def _clean_text(self, text): 360 | """Performs invalid character removal and whitespace cleanup on text.""" 361 | output = [] 362 | for char in text: 363 | cp = ord(char) 364 | if cp == 0 or cp == 0xfffd or _is_control(char): 365 | continue 366 | if _is_whitespace(char): 367 | output.append(" ") 368 | else: 369 | output.append(char) 370 | return "".join(output) 371 | 372 | 373 | class WordpieceTokenizer(object): 374 | """Runs WordPiece tokenization.""" 375 | 376 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 377 | self.vocab = vocab 378 | self.unk_token = unk_token 379 | self.max_input_chars_per_word = max_input_chars_per_word 380 | 381 | def tokenize(self, text): 382 | """Tokenizes a piece of text into its word pieces. 
383 | 384 | This uses a greedy longest-match-first algorithm to perform tokenization 385 | using the given vocabulary. 386 | 387 | For example: 388 | input = "unaffable" 389 | output = ["un", "##aff", "##able"] 390 | 391 | Args: 392 | text: A single token or whitespace separated tokens. This should have 393 | already been passed through `BasicTokenizer`. 394 | 395 | Returns: 396 | A list of wordpiece tokens. 397 | """ 398 | 399 | output_tokens = [] 400 | for token in whitespace_tokenize(text): 401 | chars = list(token) 402 | if len(chars) > self.max_input_chars_per_word: 403 | output_tokens.append(self.unk_token) 404 | continue 405 | 406 | is_bad = False 407 | start = 0 408 | sub_tokens = [] 409 | while start < len(chars): 410 | end = len(chars) 411 | cur_substr = None 412 | while start < end: 413 | substr = "".join(chars[start:end]) 414 | if start > 0: 415 | substr = "##" + substr 416 | if substr in self.vocab: 417 | cur_substr = substr 418 | break 419 | end -= 1 420 | if cur_substr is None: 421 | is_bad = True 422 | break 423 | sub_tokens.append(cur_substr) 424 | start = end 425 | 426 | if is_bad: 427 | output_tokens.append(self.unk_token) 428 | else: 429 | output_tokens.extend(sub_tokens) 430 | return output_tokens 431 | 432 | 433 | def _is_whitespace(char): 434 | """Checks whether `chars` is a whitespace character.""" 435 | # \t, \n, and \r are technically contorl characters but we treat them 436 | # as whitespace since they are generally considered as such. 437 | if char == " " or char == "\t" or char == "\n" or char == "\r": 438 | return True 439 | cat = unicodedata.category(char) 440 | if cat == "Zs": 441 | return True 442 | return False 443 | 444 | 445 | def _is_control(char): 446 | """Checks whether `chars` is a control character.""" 447 | # These are technically control characters but we count them as whitespace 448 | # characters. 449 | if char == "\t" or char == "\n" or char == "\r": 450 | return False 451 | cat = unicodedata.category(char) 452 | if cat.startswith("C"): 453 | return True 454 | return False 455 | 456 | 457 | def _is_punctuation(char): 458 | """Checks whether `chars` is a punctuation character.""" 459 | cp = ord(char) 460 | # We treat all non-letter/number ASCII as punctuation. 461 | # Characters such as "^", "$", and "`" are not in the Unicode 462 | # Punctuation class but we treat them as punctuation anyways, for 463 | # consistency. 464 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 465 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 466 | return True 467 | cat = unicodedata.category(char) 468 | if cat.startswith("P"): 469 | return True 470 | return False 471 | -------------------------------------------------------------------------------- /evaluate-v1.0.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for CoQA. 2 | 3 | The code is based partially on SQuAD 2.0 evaluation script. 
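
Example usage (the file names here are placeholders; see parse_args below for all options):

    python evaluate-v1.0.py --data-file <gold-data.json> --pred-file <predictions.json>
    python evaluate-v1.0.py --data-file <gold-data.json> --human    # report human performance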
4 | """ 5 | import argparse 6 | import json 7 | import re 8 | import string 9 | import sys 10 | 11 | from collections import Counter, OrderedDict 12 | 13 | OPTS = None 14 | 15 | out_domain = ["reddit", "science"] 16 | in_domain = ["mctest", "gutenberg", "race", "cnn", "wikipedia"] 17 | domain_mappings = {"mctest":"children_stories", "gutenberg":"literature", "race":"mid-high_school", "cnn":"news", "wikipedia":"wikipedia", "science":"science", "reddit":"reddit"} 18 | 19 | 20 | class CoQAEvaluator(): 21 | 22 | def __init__(self, gold_file): 23 | self.gold_data, self.id_to_source = CoQAEvaluator.gold_answers_to_dict(gold_file) 24 | 25 | @staticmethod 26 | def gold_answers_to_dict(gold_file): 27 | dataset = json.load(open(gold_file)) 28 | gold_dict = {} 29 | id_to_source = {} 30 | for story in dataset['data']: 31 | source = story['source'] 32 | story_id = story['id'] 33 | id_to_source[story_id] = source 34 | questions = story['questions'] 35 | multiple_answers = [story['answers']] 36 | multiple_answers += story['additional_answers'].values() 37 | for i, qa in enumerate(questions): 38 | qid = qa['turn_id'] 39 | if i + 1 != qid: 40 | sys.stderr.write("Turn id should match index {}: {}\n".format(i + 1, qa)) 41 | gold_answers = [] 42 | for answers in multiple_answers: 43 | answer = answers[i] 44 | if qid != answer['turn_id']: 45 | sys.stderr.write("Question turn id does match answer: {} {}\n".format(qa, answer)) 46 | gold_answers.append(answer['input_text']) 47 | key = (story_id, qid) 48 | if key in gold_dict: 49 | sys.stderr.write("Gold file has duplicate stories: {}".format(source)) 50 | gold_dict[key] = gold_answers 51 | return gold_dict, id_to_source 52 | 53 | @staticmethod 54 | def preds_to_dict(pred_file): 55 | preds = json.load(open(pred_file)) 56 | pred_dict = {} 57 | for pred in preds: 58 | pred_dict[(pred['id'], pred['turn_id'])] = pred['answer'] 59 | return pred_dict 60 | 61 | @staticmethod 62 | def normalize_answer(s): 63 | """Lower text and remove punctuation, storys and extra whitespace.""" 64 | 65 | def remove_articles(text): 66 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 67 | return re.sub(regex, ' ', text) 68 | 69 | def white_space_fix(text): 70 | return ' '.join(text.split()) 71 | 72 | def remove_punc(text): 73 | exclude = set(string.punctuation) 74 | return ''.join(ch for ch in text if ch not in exclude) 75 | 76 | def lower(text): 77 | return text.lower() 78 | 79 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 80 | 81 | @staticmethod 82 | def get_tokens(s): 83 | if not s: return [] 84 | return CoQAEvaluator.normalize_answer(s).split() 85 | 86 | @staticmethod 87 | def compute_exact(a_gold, a_pred): 88 | return int(CoQAEvaluator.normalize_answer(a_gold) == CoQAEvaluator.normalize_answer(a_pred)) 89 | 90 | @staticmethod 91 | def compute_f1(a_gold, a_pred): 92 | gold_toks = CoQAEvaluator.get_tokens(a_gold) 93 | pred_toks = CoQAEvaluator.get_tokens(a_pred) 94 | common = Counter(gold_toks) & Counter(pred_toks) 95 | num_same = sum(common.values()) 96 | if len(gold_toks) == 0 or len(pred_toks) == 0: 97 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 98 | return int(gold_toks == pred_toks) 99 | if num_same == 0: 100 | return 0 101 | precision = 1.0 * num_same / len(pred_toks) 102 | recall = 1.0 * num_same / len(gold_toks) 103 | f1 = (2 * precision * recall) / (precision + recall) 104 | return f1 105 | 106 | @staticmethod 107 | def _compute_turn_score(a_gold_list, a_pred): 108 | f1_sum = 0.0 109 | em_sum = 0.0 110 | if len(a_gold_list) > 1: 111 
| for i in range(len(a_gold_list)): 112 | # exclude the current answer 113 | gold_answers = a_gold_list[0:i] + a_gold_list[i + 1:] 114 | em_sum += max(CoQAEvaluator.compute_exact(a, a_pred) for a in gold_answers) 115 | f1_sum += max(CoQAEvaluator.compute_f1(a, a_pred) for a in gold_answers) 116 | else: 117 | em_sum += max(CoQAEvaluator.compute_exact(a, a_pred) for a in a_gold_list) 118 | f1_sum += max(CoQAEvaluator.compute_f1(a, a_pred) for a in a_gold_list) 119 | 120 | return {'em': em_sum / max(1, len(a_gold_list)), 'f1': f1_sum / max(1, len(a_gold_list))} 121 | 122 | def compute_turn_score(self, story_id, turn_id, a_pred): 123 | ''' This is the function what you are probably looking for. a_pred is the answer string your model predicted. ''' 124 | key = (story_id, turn_id) 125 | a_gold_list = self.gold_data[key] 126 | return CoQAEvaluator._compute_turn_score(a_gold_list, a_pred) 127 | 128 | def get_raw_scores(self, pred_data): 129 | ''''Returns a dict with score with each turn prediction''' 130 | exact_scores = {} 131 | f1_scores = {} 132 | for story_id, turn_id in self.gold_data: 133 | key = (story_id, turn_id) 134 | if key not in pred_data: 135 | sys.stderr.write('Missing prediction for {} and turn_id: {}\n'.format(story_id, turn_id)) 136 | continue 137 | a_pred = pred_data[key] 138 | scores = self.compute_turn_score(story_id, turn_id, a_pred) 139 | # Take max over all gold answers 140 | exact_scores[key] = scores['em'] 141 | f1_scores[key] = scores['f1'] 142 | return exact_scores, f1_scores 143 | 144 | def get_raw_scores_human(self): 145 | ''''Returns a dict with score for each turn''' 146 | exact_scores = {} 147 | f1_scores = {} 148 | for story_id, turn_id in self.gold_data: 149 | key = (story_id, turn_id) 150 | f1_sum = 0.0 151 | em_sum = 0.0 152 | if len(self.gold_data[key]) > 1: 153 | for i in range(len(self.gold_data[key])): 154 | # exclude the current answer 155 | gold_answers = self.gold_data[key][0:i] + self.gold_data[key][i + 1:] 156 | em_sum += max(CoQAEvaluator.compute_exact(a, self.gold_data[key][i]) for a in gold_answers) 157 | f1_sum += max(CoQAEvaluator.compute_f1(a, self.gold_data[key][i]) for a in gold_answers) 158 | else: 159 | exit("Gold answers should be multiple: {}={}".format(key, self.gold_data[key])) 160 | exact_scores[key] = em_sum / len(self.gold_data[key]) 161 | f1_scores[key] = f1_sum / len(self.gold_data[key]) 162 | return exact_scores, f1_scores 163 | 164 | def human_performance(self): 165 | exact_scores, f1_scores = self.get_raw_scores_human() 166 | return self.get_domain_scores(exact_scores, f1_scores) 167 | 168 | def model_performance(self, pred_data): 169 | exact_scores, f1_scores = self.get_raw_scores(pred_data) 170 | return self.get_domain_scores(exact_scores, f1_scores) 171 | 172 | def get_domain_scores(self, exact_scores, f1_scores): 173 | sources = {} 174 | for source in in_domain + out_domain: 175 | sources[source] = Counter() 176 | 177 | for story_id, turn_id in self.gold_data: 178 | key = (story_id, turn_id) 179 | source = self.id_to_source[story_id] 180 | sources[source]['em_total'] += exact_scores.get(key, 0) 181 | sources[source]['f1_total'] += f1_scores.get(key, 0) 182 | sources[source]['turn_count'] += 1 183 | 184 | scores = OrderedDict() 185 | in_domain_em_total = 0.0 186 | in_domain_f1_total = 0.0 187 | in_domain_turn_count = 0 188 | 189 | out_domain_em_total = 0.0 190 | out_domain_f1_total = 0.0 191 | out_domain_turn_count = 0 192 | 193 | for source in in_domain + out_domain: 194 | domain = domain_mappings[source] 195 | scores[domain] 
= {} 196 | scores[domain]['em'] = round(sources[source]['em_total'] / max(1, sources[source]['turn_count']) * 100, 1) 197 | scores[domain]['f1'] = round(sources[source]['f1_total'] / max(1, sources[source]['turn_count']) * 100, 1) 198 | scores[domain]['turns'] = sources[source]['turn_count'] 199 | if source in in_domain: 200 | in_domain_em_total += sources[source]['em_total'] 201 | in_domain_f1_total += sources[source]['f1_total'] 202 | in_domain_turn_count += sources[source]['turn_count'] 203 | elif source in out_domain: 204 | out_domain_em_total += sources[source]['em_total'] 205 | out_domain_f1_total += sources[source]['f1_total'] 206 | out_domain_turn_count += sources[source]['turn_count'] 207 | 208 | scores["in_domain"] = {'em': round(in_domain_em_total / max(1, in_domain_turn_count) * 100, 1), 209 | 'f1': round(in_domain_f1_total / max(1, in_domain_turn_count) * 100, 1), 210 | 'turns': in_domain_turn_count} 211 | scores["out_domain"] = {'em': round(out_domain_em_total / max(1, out_domain_turn_count) * 100, 1), 212 | 'f1': round(out_domain_f1_total / max(1, out_domain_turn_count) * 100, 1), 213 | 'turns': out_domain_turn_count} 214 | 215 | em_total = in_domain_em_total + out_domain_em_total 216 | f1_total = in_domain_f1_total + out_domain_f1_total 217 | turn_count = in_domain_turn_count + out_domain_turn_count 218 | scores["overall"] = {'em': round(em_total / max(1, turn_count) * 100, 1), 219 | 'f1': round(f1_total / max(1, turn_count) * 100, 1), 220 | 'turns': turn_count} 221 | 222 | return scores 223 | 224 | def parse_args(): 225 | parser = argparse.ArgumentParser('Official evaluation script for CoQA.') 226 | parser.add_argument('--data-file', dest="data_file", help='Input data JSON file.') 227 | parser.add_argument('--pred-file', dest="pred_file", help='Model predictions.') 228 | parser.add_argument('--out-file', '-o', metavar='eval.json', 229 | help='Write accuracy metrics to file (default is stdout).') 230 | parser.add_argument('--verbose', '-v', action='store_true') 231 | parser.add_argument('--human', dest="human", action='store_true') 232 | if len(sys.argv) == 1: 233 | parser.print_help() 234 | sys.exit(1) 235 | return parser.parse_args() 236 | 237 | def main(): 238 | evaluator = CoQAEvaluator(OPTS.data_file) 239 | 240 | if OPTS.human: 241 | print(json.dumps(evaluator.human_performance(), indent=2)) 242 | 243 | if OPTS.pred_file: 244 | with open(OPTS.pred_file) as f: 245 | pred_data = CoQAEvaluator.preds_to_dict(OPTS.pred_file) 246 | print(json.dumps(evaluator.model_performance(pred_data), indent=2)) 247 | 248 | if __name__ == '__main__': 249 | OPTS = parse_args() 250 | main() 251 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | from bert import BertModel, BertPreTrainedModel 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import CrossEntropyLoss 5 | import random 6 | import torch 7 | 8 | class Multi_linear_layer(nn.Module): 9 | def __init__(self, 10 | n_layers, 11 | input_size, 12 | hidden_size, 13 | output_size, 14 | activation=None): 15 | super(Multi_linear_layer, self).__init__() 16 | self.linears = nn.ModuleList() 17 | self.linears.append(nn.Linear(input_size, hidden_size)) 18 | for _ in range(1, n_layers - 1): 19 | self.linears.append(nn.Linear(hidden_size, hidden_size)) 20 | self.linears.append(nn.Linear(hidden_size, output_size)) 21 | self.activation = getattr(F, activation) 22 | 23 | def forward(self, x): 
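# Applies the activation after every hidden layer and leaves the final layer as a plain
# linear projection. For example (bert-base sizes assumed, as used by BertForCoQA below):
# Multi_linear_layer(2, 768, 768, 1, 'relu') maps a (batch, seq_len, 768) tensor to
# (batch, seq_len, 1) scores.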
24 | for linear in self.linears[:-1]: 25 | x = self.activation(linear(x)) 26 | linear = self.linears[-1] 27 | x = linear(x) 28 | return x 29 | 30 | class BertForCoQA(BertPreTrainedModel): 31 | def __init__( 32 | self, 33 | config, 34 | output_attentions=False, 35 | keep_multihead_output=False, 36 | n_layers=2, 37 | activation='relu', 38 | beta=100, 39 | ): 40 | super(BertForCoQA, self).__init__(config) 41 | self.output_attentions = output_attentions 42 | self.bert = BertModel(config) 43 | hidden_size = config.hidden_size 44 | self.rational_l = Multi_linear_layer(n_layers, hidden_size, 45 | hidden_size, 1, activation) 46 | self.logits_l = Multi_linear_layer(n_layers, hidden_size, hidden_size, 47 | 2, activation) 48 | self.unk_l = Multi_linear_layer(n_layers, hidden_size, hidden_size, 1, 49 | activation) 50 | self.attention_l = Multi_linear_layer(n_layers, hidden_size, 51 | hidden_size, 1, activation) 52 | self.yn_l = Multi_linear_layer(n_layers, hidden_size, hidden_size, 2, 53 | activation) 54 | self.beta = beta 55 | 56 | self.init_weights() 57 | 58 | def forward( 59 | self, 60 | input_ids, 61 | token_type_ids=None, 62 | attention_mask=None, 63 | start_positions=None, 64 | end_positions=None, 65 | rational_mask=None, 66 | cls_idx = None, 67 | head_mask=None, 68 | ): 69 | # mask some words on inputs_ids 70 | # if self.training and self.mask_p > 0: 71 | # batch_size = input_ids.size(0) 72 | # for i in range(batch_size): 73 | # len_c, len_qc = token_type_ids[i].sum( 74 | # dim=0).detach().item(), attention_mask[i].sum( 75 | # dim=0).detach().item() 76 | # masked_idx = random.sample(range(len_qc - len_c, len_qc), 77 | # int(len_c * self.mask_p)) 78 | # input_ids[i, masked_idx] = 100 79 | 80 | outputs = self.bert( 81 | input_ids, 82 | token_type_ids=token_type_ids, 83 | attention_mask=attention_mask, 84 | # output_all_encoded_layers=False, 85 | head_mask=head_mask, 86 | ) 87 | if self.output_attentions: 88 | all_attentions, sequence_output, cls_outputs = outputs 89 | else: 90 | final_hidden, pooled_output = outputs 91 | 92 | rational_logits = self.rational_l(final_hidden) 93 | rational_logits = torch.sigmoid(rational_logits) 94 | 95 | final_hidden = final_hidden * rational_logits 96 | 97 | logits = self.logits_l(final_hidden) 98 | 99 | start_logits, end_logits = logits.split(1, dim=-1) 100 | 101 | start_logits, end_logits = start_logits.squeeze( 102 | -1), end_logits.squeeze(-1) 103 | 104 | segment_mask = token_type_ids.type(final_hidden.dtype) 105 | 106 | rational_logits = rational_logits.squeeze(-1) * segment_mask 107 | 108 | start_logits = start_logits * rational_logits 109 | 110 | end_logits = end_logits * rational_logits 111 | 112 | unk_logits = self.unk_l(pooled_output) 113 | 114 | attention = self.attention_l(final_hidden).squeeze(-1) 115 | 116 | attention.data.masked_fill_(attention_mask.eq(0), -float('inf')) 117 | 118 | attention = F.softmax(attention, dim=-1) 119 | 120 | attention_pooled_output = (attention.unsqueeze(-1) * 121 | final_hidden).sum(dim=-2) 122 | 123 | yn_logits = self.yn_l(attention_pooled_output) 124 | 125 | yes_logits, no_logits = yn_logits.split(1, dim=-1) 126 | 127 | start_logits.data.masked_fill_(attention_mask.eq(0), -float('inf')) 128 | end_logits.data.masked_fill_(attention_mask.eq(0), -float('inf')) 129 | 130 | new_start_logits = torch.cat( 131 | (yes_logits, no_logits, unk_logits, start_logits), dim=-1) 132 | new_end_logits = torch.cat( 133 | (yes_logits, no_logits, unk_logits, end_logits), dim=-1) 134 | 135 | if start_positions is not None and end_positions is 
not None: 136 | 137 | start_positions, end_positions = start_positions + cls_idx, end_positions + cls_idx 138 | 139 | # If we are on multi-GPU, split add a dimension 140 | if len(start_positions.size()) > 1: 141 | start_positions = start_positions.squeeze(-1) 142 | if len(end_positions.size()) > 1: 143 | end_positions = end_positions.squeeze(-1) 144 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 145 | ignored_index = new_start_logits.size(1) 146 | start_positions.clamp_(0, ignored_index) 147 | end_positions.clamp_(0, ignored_index) 148 | 149 | span_loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 150 | 151 | start_loss = span_loss_fct(new_start_logits, start_positions) 152 | end_loss = span_loss_fct(new_end_logits, end_positions) 153 | 154 | # rational part 155 | alpha = 0.25 156 | gamma = 2. 157 | rational_mask = rational_mask.type(final_hidden.dtype) 158 | 159 | rational_loss = -alpha * ( 160 | (1 - rational_logits)**gamma 161 | ) * rational_mask * torch.log(rational_logits + 1e-7) - ( 162 | 1 - alpha) * (rational_logits**gamma) * ( 163 | 1 - rational_mask) * torch.log(1 - rational_logits + 1e-7) 164 | 165 | rational_loss = (rational_loss * 166 | segment_mask).sum() / segment_mask.sum() 167 | # end 168 | 169 | assert not torch.isnan(rational_loss) 170 | 171 | total_loss = (start_loss + 172 | end_loss) / 2 + rational_loss * self.beta 173 | return total_loss 174 | 175 | return start_logits, end_logits, yes_logits, no_logits, unk_logits 176 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.32.1 2 | requests==2.21.0 3 | botocore==1.12.183 4 | apex==0.1 5 | spacy==2.0.16 6 | numpy==1.15.4 7 | six==1.12.0 8 | boto3==1.9.183 9 | torch==1.1.0 10 | ptvsd==4.3.2 11 | 12 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=3 \ 2 | python run_coqa.py \ 3 | --type bert \ 4 | --bert_model ../bert-base-uncased/ \ 5 | --do_train \ 6 | --do_predict \ 7 | --output_dir tmp2 \ 8 | --train_file coqa-train-v1.0.json \ 9 | --predict_file coqa-dev-v1.0.json \ 10 | --train_batch_size 12 \ 11 | --learning_rate 3e-5 \ 12 | --warmup_proportion 0.1 \ 13 | --max_grad_norm -1 \ 14 | --weight_decay 0.01 \ 15 | --fp16 \ 16 | --do_lower_case \ 17 | 18 | -------------------------------------------------------------------------------- /run_coqa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
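# Overview: this script parses the command-line arguments defined below, builds or loads
# cached CoqaExample / InputFeatures pickles, fine-tunes BertForCoQA, and writes
# predictions.json (plus the n-best and null-odds files) to --output_dir for scoring with
# the official evaluate-v1.0.py.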
16 | """Run BERT on CoQA.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import logging 22 | import os 23 | import random 24 | import sys 25 | from io import open 26 | import json 27 | 28 | import numpy as np 29 | import torch 30 | import torch.nn as nn 31 | from torch.nn import CrossEntropyLoss 32 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 33 | TensorDataset) 34 | from torch.utils.data.distributed import DistributedSampler 35 | from tqdm import tqdm, trange 36 | 37 | from bert.file_utils import WEIGHTS_NAME, CONFIG_NAME 38 | from modeling import BertForCoQA 39 | from bert import AdamW, WarmupLinearSchedule, BertTokenizer 40 | 41 | from run_coqa_dataset_utils import read_coqa_examples, convert_examples_to_features, RawResult, write_predictions, score 42 | # from parallel import DataParallelCriterion, DataParallelModel, gather 43 | 44 | if sys.version_info[0] == 2: 45 | import cPickle as pickle 46 | else: 47 | import pickle 48 | 49 | logger = logging.getLogger(__name__) 50 | 51 | 52 | def main(): 53 | parser = argparse.ArgumentParser() 54 | 55 | ## Required parameters 56 | parser.add_argument("--type", 57 | default=None, 58 | type=str, 59 | required=True, 60 | help=".") 61 | parser.add_argument( 62 | "--bert_model", 63 | default=None, 64 | type=str, 65 | required=True, 66 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 67 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 68 | "bert-base-multilingual-cased, bert-base-chinese.") 69 | parser.add_argument( 70 | "--output_dir", 71 | default=None, 72 | type=str, 73 | required=True, 74 | help= 75 | "The output directory where the model checkpoints and predictions will be written." 76 | ) 77 | 78 | ## Other parameters 79 | parser.add_argument( 80 | "--train_file", 81 | default=None, 82 | type=str, 83 | help="CoQA json for training. E.g., coqa-train-v1.0.json") 84 | parser.add_argument( 85 | "--predict_file", 86 | default=None, 87 | type=str, 88 | help="CoQA json for predictions. E.g., coqa-dev-v1.0.json") 89 | parser.add_argument( 90 | "--max_seq_length", 91 | default=512, 92 | type=int, 93 | help= 94 | "The maximum total input sequence length after WordPiece tokenization. Sequences " 95 | "longer than this will be truncated, and sequences shorter than this will be padded." 96 | ) 97 | parser.add_argument( 98 | "--doc_stride", 99 | default=128, 100 | type=int, 101 | help= 102 | "When splitting up a long document into chunks, how much stride to take between chunks." 103 | ) 104 | parser.add_argument( 105 | "--max_query_length", 106 | default=64, 107 | type=int, 108 | help= 109 | "The maximum number of tokens for the question. Questions longer than this will " 110 | "be truncated to this length.") 111 | parser.add_argument("--do_train", 112 | action='store_true', 113 | help="Whether to run training.") 114 | parser.add_argument("--do_predict", 115 | action='store_true', 116 | help="Whether to run eval on the dev set.") 117 | # parser.add_argument("--do_F1", 118 | # action='store_true', 119 | # help="Whether to calculating F1 score") # we don't talk anymore. 
please use official evaluation scripts 120 | parser.add_argument("--train_batch_size", 121 | default=48, 122 | type=int, 123 | help="Total batch size for training.") 124 | parser.add_argument("--predict_batch_size", 125 | default=48, 126 | type=int, 127 | help="Total batch size for predictions.") 128 | parser.add_argument("--learning_rate", 129 | default=5e-5, 130 | type=float, 131 | help="The initial learning rate for Adam.") 132 | parser.add_argument("--num_train_epochs", 133 | default=2.0, 134 | type=float, 135 | help="Total number of training epochs to perform.") 136 | parser.add_argument( 137 | "--warmup_proportion", 138 | default=0.06, 139 | type=float, 140 | help= 141 | "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% " 142 | "of training.") 143 | parser.add_argument( 144 | "--n_best_size", 145 | default=20, 146 | type=int, 147 | help= 148 | "The total number of n-best predictions to generate in the nbest_predictions.json " 149 | "output file.") 150 | parser.add_argument( 151 | "--max_answer_length", 152 | default=30, 153 | type=int, 154 | help= 155 | "The maximum length of an answer that can be generated. This is needed because the start " 156 | "and end predictions are not conditioned on one another.") 157 | parser.add_argument( 158 | "--verbose_logging", 159 | action='store_true', 160 | help= 161 | "If true, all of the warnings related to data processing will be printed. " 162 | "A number of warnings are expected for a normal CoQA evaluation.") 163 | parser.add_argument("--no_cuda", 164 | action='store_true', 165 | help="Whether not to use CUDA when available") 166 | parser.add_argument('--seed', 167 | type=int, 168 | default=42, 169 | help="random seed for initialization") 170 | parser.add_argument( 171 | '--gradient_accumulation_steps', 172 | type=int, 173 | default=1, 174 | help= 175 | "Number of updates steps to accumulate before performing a backward/update pass." 176 | ) 177 | parser.add_argument( 178 | "--do_lower_case", 179 | action='store_true', 180 | help= 181 | "Whether to lower case the input text. True for uncased models, False for cased models." 182 | ) 183 | parser.add_argument("--local_rank", 184 | type=int, 185 | default=-1, 186 | help="local_rank for distributed training on gpus") 187 | parser.add_argument( 188 | '--fp16', 189 | action='store_true', 190 | help="Whether to use 16-bit float precision instead of 32-bit") 191 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 192 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 193 | "See details at https://nvidia.github.io/apex/amp.html") 194 | parser.add_argument('--overwrite_output_dir', 195 | action='store_true', 196 | help="Overwrite the content of the output directory") 197 | parser.add_argument( 198 | '--loss_scale', 199 | type=float, 200 | default=0, 201 | help= 202 | "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" 203 | "0 (default value): dynamic loss scaling.\n" 204 | "Positive power of 2: static loss scaling value.\n") 205 | parser.add_argument( 206 | '--weight_decay', 207 | type=float, 208 | default=0, 209 | help="") 210 | parser.add_argument( 211 | '--null_score_diff_threshold', 212 | type=float, 213 | default=0.0, 214 | help= 215 | "If null_score - best_non_null is greater than the threshold predict null." 
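# i.e. a larger threshold makes a null ('unknown') prediction less likely.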
216 | ) 217 | parser.add_argument('--server_ip', 218 | type=str, 219 | default='', 220 | help="Can be used for distant debugging.") 221 | parser.add_argument('--server_port', 222 | type=str, 223 | default='', 224 | help="Can be used for distant debugging.") 225 | parser.add_argument('--logfile', 226 | type=str, 227 | default=None, 228 | help='Which file to keep log.') 229 | parser.add_argument('--logmode', 230 | type=str, 231 | default=None, 232 | help='logging mode, `w` or `a`') 233 | parser.add_argument('--tensorboard', 234 | action='store_true', 235 | help='no tensor board') 236 | parser.add_argument('--qa_tag', 237 | action='store_true', 238 | help='add qa tag or not') 239 | parser.add_argument('--history_len', 240 | type=int, 241 | default=2, 242 | help='length of history') 243 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 244 | help="Epsilon for Adam optimizer.") 245 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 246 | help="Max gradient norm.") 247 | parser.add_argument('--logging_steps', type=int, default=50, 248 | help="Log every X updates steps.") 249 | args = parser.parse_args() 250 | print(args) 251 | 252 | if args.server_ip and args.server_port: 253 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 254 | import ptvsd 255 | print("Waiting for debugger attach") 256 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), 257 | redirect_output=True) 258 | ptvsd.wait_for_attach() 259 | 260 | if args.local_rank == -1 or args.no_cuda: 261 | device = torch.device("cuda" if torch.cuda.is_available() 262 | and not args.no_cuda else "cpu") 263 | n_gpu = torch.cuda.device_count() 264 | else: 265 | torch.cuda.set_device(args.local_rank) 266 | device = torch.device("cuda", args.local_rank) 267 | n_gpu = 1 268 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 269 | torch.distributed.init_process_group(backend='nccl') 270 | 271 | logging.basicConfig( 272 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 273 | datefmt='%m/%d/%Y %H:%M:%S', 274 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 275 | filename=args.logfile, 276 | filemode=args.logmode) 277 | 278 | logger.info( 279 | "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 280 | format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) 281 | 282 | if args.gradient_accumulation_steps < 1: 283 | raise ValueError( 284 | "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" 285 | .format(args.gradient_accumulation_steps)) 286 | 287 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 288 | 289 | random.seed(args.seed) 290 | np.random.seed(args.seed) 291 | torch.manual_seed(args.seed) 292 | if n_gpu > 0: 293 | torch.cuda.manual_seed_all(args.seed) 294 | 295 | if not args.do_train and not args.do_predict: 296 | raise ValueError( 297 | "At least one of `do_train` or `do_predict` must be True." 298 | ) 299 | 300 | if args.do_train: 301 | if not args.train_file: 302 | raise ValueError( 303 | "If `do_train` is True, then `train_file` must be specified.") 304 | if args.do_predict: 305 | if not args.predict_file: 306 | raise ValueError( 307 | "If `do_predict` is True, then `predict_file` must be specified." 
308 | ) 309 | 310 | if os.path.exists(args.output_dir) and os.listdir( 311 | args.output_dir 312 | ) and args.do_train and not args.overwrite_output_dir: 313 | raise ValueError( 314 | "Output directory () already exists and is not empty.") 315 | if not os.path.exists(args.output_dir): 316 | os.makedirs(args.output_dir) 317 | 318 | if args.local_rank not in [-1, 0]: 319 | torch.distributed.barrier( 320 | ) # Make sure only the first process in distributed training will download model & vocab 321 | 322 | if args.do_train or args.do_predict: 323 | tokenizer = BertTokenizer.from_pretrained( 324 | args.bert_model, do_lower_case=args.do_lower_case) 325 | model = BertForCoQA.from_pretrained(args.bert_model) 326 | if args.local_rank == 0: 327 | torch.distributed.barrier() 328 | 329 | model.to(device) 330 | 331 | if args.do_train: 332 | if args.local_rank in [-1, 0] and args.tensorboard: 333 | from tensorboardX import SummaryWriter 334 | tb_writer = SummaryWriter() 335 | # Prepare data loader 336 | cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format( 337 | args.type, str(args.max_seq_length), str(args.doc_stride), 338 | str(args.max_query_length), str(args.max_answer_length), 339 | str(args.history_len), str(args.qa_tag)) 340 | cached_train_examples_file = args.train_file + '_examples_{0}_{1}.pk'.format( 341 | str(args.history_len), str(args.qa_tag)) 342 | 343 | # try train_examples 344 | try: 345 | with open(cached_train_examples_file, "rb") as reader: 346 | train_examples = pickle.load(reader) 347 | except: 348 | logger.info(" No cached file %s", cached_train_examples_file) 349 | train_examples = read_coqa_examples(input_file=args.train_file, 350 | history_len=args.history_len, 351 | add_QA_tag=args.qa_tag) 352 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 353 | logger.info(" Saving train examples into cached file %s", 354 | cached_train_examples_file) 355 | with open(cached_train_examples_file, "wb") as writer: 356 | pickle.dump(train_examples, writer) 357 | 358 | # print('DEBUG') 359 | # exit() 360 | 361 | # try train_features 362 | try: 363 | with open(cached_train_features_file, "rb") as reader: 364 | train_features = pickle.load(reader) 365 | except: 366 | logger.info(" No cached file %s", cached_train_features_file) 367 | train_features = convert_examples_to_features( 368 | examples=train_examples, 369 | tokenizer=tokenizer, 370 | max_seq_length=args.max_seq_length, 371 | doc_stride=args.doc_stride, 372 | max_query_length=args.max_query_length, 373 | ) 374 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 375 | logger.info(" Saving train features into cached file %s", 376 | cached_train_features_file) 377 | with open(cached_train_features_file, "wb") as writer: 378 | pickle.dump(train_features, writer) 379 | 380 | # print('DEBUG') 381 | # exit() 382 | 383 | all_input_ids = torch.tensor([f.input_ids for f in train_features], 384 | dtype=torch.long) 385 | all_input_mask = torch.tensor([f.input_mask for f in train_features], 386 | dtype=torch.long) 387 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], 388 | dtype=torch.long) 389 | all_start_positions = torch.tensor( 390 | [f.start_position for f in train_features], dtype=torch.long) 391 | all_end_positions = torch.tensor( 392 | [f.end_position for f in train_features], dtype=torch.long) 393 | all_rational_mask = torch.tensor( 394 | [f.rational_mask for f in train_features], dtype=torch.long) 395 | all_cls_idx = torch.tensor([f.cls_idx for f in 
train_features], 396 | dtype=torch.long) 397 | train_data = TensorDataset(all_input_ids, all_input_mask, 398 | all_segment_ids, all_start_positions, 399 | all_end_positions, all_rational_mask, 400 | all_cls_idx) 401 | if args.local_rank == -1: 402 | train_sampler = RandomSampler(train_data) 403 | else: 404 | train_sampler = DistributedSampler(train_data) 405 | 406 | train_dataloader = DataLoader(train_data, 407 | sampler=train_sampler, 408 | batch_size=args.train_batch_size) 409 | num_train_optimization_steps = len( 410 | train_dataloader 411 | ) // args.gradient_accumulation_steps * args.num_train_epochs 412 | # if args.local_rank != -1: 413 | # num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 414 | 415 | # Prepare optimizer 416 | param_optimizer = list(model.named_parameters()) 417 | 418 | # hack to remove pooler, which is not used 419 | # thus it produce None grad that break apex 420 | # param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]] 421 | 422 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 423 | optimizer_grouped_parameters = [{ 424 | 'params': [ 425 | p for n, p in param_optimizer 426 | if not any(nd in n for nd in no_decay) 427 | ], 428 | 'weight_decay': 429 | args.weight_decay 430 | }, { 431 | 'params': 432 | [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 433 | 'weight_decay': 434 | 0.0 435 | }] 436 | 437 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 438 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=int(args.warmup_proportion * num_train_optimization_steps), t_total=num_train_optimization_steps) 439 | 440 | if args.fp16: 441 | try: 442 | from apex import amp 443 | except ImportError: 444 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 445 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 446 | 447 | if n_gpu > 1: 448 | model = torch.nn.DataParallel(model) 449 | 450 | if args.local_rank != -1: 451 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 452 | output_device=args.local_rank, 453 | find_unused_parameters=True) 454 | 455 | global_step = 0 456 | tr_loss, logging_loss = 0.0, 0.0 457 | 458 | logger.info("***** Running training *****") 459 | logger.info(" Num orig examples = %d", len(train_examples)) 460 | logger.info(" Num split examples = %d", len(train_features)) 461 | logger.info(" Batch size = %d", args.train_batch_size) 462 | logger.info(" Num steps = %d", num_train_optimization_steps) 463 | 464 | model.train() 465 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 466 | for step, batch in enumerate( 467 | tqdm(train_dataloader, 468 | desc="Iteration", 469 | disable=args.local_rank not in [-1, 0])): 470 | batch = tuple( 471 | t.to(device) 472 | for t in batch) # multi-gpu does scattering it-self 473 | input_ids, input_mask, segment_ids, start_positions, end_positions, rational_mask, cls_idx = batch 474 | loss = model(input_ids, segment_ids, input_mask, 475 | start_positions, end_positions, rational_mask, 476 | cls_idx) 477 | # loss = gather(loss, 0) 478 | if n_gpu > 1: 479 | loss = loss.mean() # mean() to average on multi-gpu. 
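# Scale the loss so that gradients accumulated over several small batches match the
# gradient of a single full-sized batch before each optimizer step.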
480 | if args.gradient_accumulation_steps > 1: 481 | loss = loss / args.gradient_accumulation_steps 482 | 483 | if args.fp16: 484 | with amp.scale_loss(loss, optimizer) as scaled_loss: 485 | scaled_loss.backward() 486 | if args.max_grad_norm > 0: 487 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 488 | else: 489 | loss.backward() 490 | if args.max_grad_norm > 0: 491 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 492 | 493 | tr_loss += loss.item() 494 | 495 | if (step + 1) % args.gradient_accumulation_steps == 0: 496 | optimizer.step() 497 | scheduler.step() # Update learning rate schedule 498 | model.zero_grad() 499 | global_step += 1 500 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 501 | if args.tensorboard: 502 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 503 | tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) 504 | else: 505 | logger.info('Step: {}\tLearning rate: {}\tLoss: {}\t'.format(global_step, scheduler.get_lr()[0], (tr_loss - logging_loss)/args.logging_steps)) 506 | logging_loss = tr_loss 507 | 508 | if args.do_train and (args.local_rank == -1 509 | or torch.distributed.get_rank() == 0): 510 | # Save a trained model, configuration and tokenizer 511 | model_to_save = model.module if hasattr( 512 | model, 'module') else model # Only save the model it-self 513 | 514 | # If we save using the predefined names, we can load using `from_pretrained` 515 | output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) 516 | output_config_file = os.path.join(args.output_dir, CONFIG_NAME) 517 | 518 | torch.save(model_to_save.state_dict(), output_model_file) 519 | model_to_save.config.to_json_file(output_config_file) 520 | tokenizer.save_vocabulary(args.output_dir) 521 | 522 | # Load a trained model and vocabulary that you have fine-tuned 523 | model = BertForCoQA.from_pretrained(args.output_dir) 524 | tokenizer = BertTokenizer.from_pretrained( 525 | args.output_dir, do_lower_case=args.do_lower_case) 526 | 527 | # Good practice: save your training arguments together with the trained model 528 | output_args_file = os.path.join(args.output_dir, 'training_args.bin') 529 | torch.save(args, output_args_file) 530 | else: 531 | model = BertForCoQA.from_pretrained(args.bert_model) 532 | 533 | model.to(device) 534 | 535 | if args.do_predict and (args.local_rank == -1 536 | or torch.distributed.get_rank() == 0): 537 | cached_eval_features_file = args.predict_file + '_{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format( 538 | args.type, str(args.max_seq_length), str(args.doc_stride), 539 | str(args.max_query_length), str(args.max_answer_length), 540 | str(args.history_len), str(args.qa_tag)) 541 | cached_eval_examples_file = args.predict_file + '_examples_{0}_{1}.pk'.format( 542 | str(args.history_len), str(args.qa_tag)) 543 | 544 | # try eval_examples 545 | try: 546 | with open(cached_eval_examples_file, 'rb') as reader: 547 | eval_examples = pickle.load(reader) 548 | except: 549 | logger.info("No cached file: %s", cached_eval_examples_file) 550 | eval_examples = read_coqa_examples(input_file=args.predict_file, 551 | history_len=args.history_len, 552 | add_QA_tag=args.qa_tag) 553 | logger.info(" Saving eval examples into cached file %s", 554 | cached_eval_examples_file) 555 | with open(cached_eval_examples_file, 'wb') as writer: 556 | pickle.dump(eval_examples, writer) 557 | 558 | # try eval_features 559 | try: 560 | with 
open(cached_eval_features_file, "rb") as reader: 561 | eval_features = pickle.load(reader) 562 | except: 563 | logger.info("No cached file: %s", cached_eval_features_file) 564 | eval_features = convert_examples_to_features( 565 | examples=eval_examples, 566 | tokenizer=tokenizer, 567 | max_seq_length=args.max_seq_length, 568 | doc_stride=args.doc_stride, 569 | max_query_length=args.max_query_length, 570 | ) 571 | logger.info(" Saving eval features into cached file %s", 572 | cached_eval_features_file) 573 | with open(cached_eval_features_file, "wb") as writer: 574 | pickle.dump(eval_features, writer) 575 | 576 | # print('DEBUG') 577 | # exit() 578 | 579 | logger.info("***** Running predictions *****") 580 | logger.info(" Num orig examples = %d", len(eval_examples)) 581 | logger.info(" Num split examples = %d", len(eval_features)) 582 | logger.info(" Batch size = %d", args.predict_batch_size) 583 | 584 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], 585 | dtype=torch.long) 586 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], 587 | dtype=torch.long) 588 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], 589 | dtype=torch.long) 590 | all_example_index = torch.arange(all_input_ids.size(0), 591 | dtype=torch.long) 592 | eval_data = TensorDataset(all_input_ids, all_input_mask, 593 | all_segment_ids, all_example_index) 594 | # Run prediction for full data 595 | eval_sampler = SequentialSampler(eval_data) 596 | eval_dataloader = DataLoader(eval_data, 597 | sampler=eval_sampler, 598 | batch_size=args.predict_batch_size) 599 | 600 | model.eval() 601 | all_results = [] 602 | logger.info("Start evaluating") 603 | for input_ids, input_mask, segment_ids, example_indices in tqdm( 604 | eval_dataloader, 605 | desc="Evaluating", 606 | disable=args.local_rank not in [-1, 0]): 607 | # if len(all_results) % 1000 == 0: 608 | # logger.info("Processing example: %d" % (len(all_results))) 609 | input_ids = input_ids.to(device) 610 | input_mask = input_mask.to(device) 611 | segment_ids = segment_ids.to(device) 612 | with torch.no_grad(): 613 | batch_start_logits, batch_end_logits, batch_yes_logits, batch_no_logits, batch_unk_logits = model( 614 | input_ids, segment_ids, input_mask) 615 | for i, example_index in enumerate(example_indices): 616 | start_logits = batch_start_logits[i].detach().cpu().tolist() 617 | end_logits = batch_end_logits[i].detach().cpu().tolist() 618 | yes_logits = batch_yes_logits[i].detach().cpu().tolist() 619 | no_logits = batch_no_logits[i].detach().cpu().tolist() 620 | unk_logits = batch_unk_logits[i].detach().cpu().tolist() 621 | eval_feature = eval_features[example_index.item()] 622 | unique_id = int(eval_feature.unique_id) 623 | all_results.append( 624 | RawResult(unique_id=unique_id, 625 | start_logits=start_logits, 626 | end_logits=end_logits, 627 | yes_logits=yes_logits, 628 | no_logits=no_logits, 629 | unk_logits=unk_logits)) 630 | output_prediction_file = os.path.join(args.output_dir, 631 | "predictions.json") 632 | output_nbest_file = os.path.join(args.output_dir, 633 | "nbest_predictions.json") 634 | output_null_log_odds_file = os.path.join(args.output_dir, 635 | "null_odds.json") 636 | write_predictions(eval_examples, eval_features, all_results, 637 | args.n_best_size, args.max_answer_length, 638 | args.do_lower_case, output_prediction_file, 639 | output_nbest_file, output_null_log_odds_file, 640 | args.verbose_logging, args.null_score_diff_threshold) 641 | 642 | # we don't do F1 any more 643 | 644 | # if args.do_F1 
and (args.local_rank == -1 645 | # or torch.distributed.get_rank() == 0): 646 | # logger.info("Start calculating F1") 647 | # cached_eval_examples_file = args.predict_file + '_examples.pk' 648 | # try: 649 | # with open(cached_eval_examples_file, 'rb') as reader: 650 | # eval_examples = pickle.load(reader) 651 | # except: 652 | # eval_examples = read_coqa_examples(input_file=args.predict_file) 653 | # pred_dict = json.load( 654 | # open(os.path.join(args.output_dir, "predictions.json"), 'rb')) 655 | # truth_dict = {} 656 | # for i in range(len(eval_examples)): 657 | # answers = eval_examples[i].additional_answers 658 | # tmp = eval_examples[i].orig_answer_text 659 | # if tmp not in answers: 660 | # answers.append(tmp) 661 | # truth_dict[eval_examples[i].qas_id] = answers 662 | # with open(os.path.join(args.output_dir, "truths.json"), 'w') as writer: 663 | # writer.write(json.dumps(truth_dict, indent=4) + '\n') 664 | # result, all_f1s = score(pred_dict, truth_dict) 665 | # logger.info(str(result)) 666 | 667 | 668 | if __name__ == "__main__": 669 | main() 670 | -------------------------------------------------------------------------------- /run_coqa_dataset_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Load CoQA dataset. """ 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import json 21 | import logging 22 | import math 23 | import collections 24 | from io import open 25 | from tqdm import tqdm 26 | # from GeneralUtils import * 27 | import spacy 28 | import re 29 | from collections import Counter 30 | import string 31 | 32 | # from pytorch_pretrained_bert.tokenization import BasicTokenizer, whitespace_tokenize 33 | from bert.tokenization_bert import BasicTokenizer, whitespace_tokenize 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | class CoqaExample(object): 39 | """ 40 | A single training/test example for the CoQA dataset. 41 | For examples without an answer, the start and end position are -1. 
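Yes/no/unknown answers are kept in `orig_answer_text` and are later mapped to a
classification index (cls_idx) instead of a span during feature conversion.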
42 | """ 43 | 44 | def __init__( 45 | self, 46 | qas_id, 47 | question_text, 48 | doc_tokens, 49 | orig_answer_text=None, 50 | start_position=None, 51 | end_position=None, 52 | rational_start_position=None, 53 | rational_end_position=None, 54 | additional_answers=None, 55 | ): 56 | self.qas_id = qas_id 57 | self.question_text = question_text 58 | self.doc_tokens = doc_tokens 59 | self.orig_answer_text = orig_answer_text 60 | self.start_position = start_position 61 | self.end_position = end_position 62 | self.additional_answers = additional_answers 63 | self.rational_start_position = rational_start_position 64 | self.rational_end_position = rational_end_position 65 | 66 | def __str__(self): 67 | return self.__repr__() 68 | 69 | def __repr__(self): 70 | s = "" 71 | s += "qas_id: %s" % (self.qas_id) 72 | s += ", question_text: %s" % (self.question_text) 73 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 74 | if self.start_position: 75 | s += ", start_position: %d" % (self.start_position) 76 | if self.end_position: 77 | s += ", end_position: %d" % (self.end_position) 78 | return s 79 | 80 | 81 | class InputFeatures(object): 82 | """A single set of features of data.""" 83 | 84 | def __init__(self, 85 | unique_id, 86 | example_index, 87 | doc_span_index, 88 | tokens, 89 | token_to_orig_map, 90 | token_is_max_context, 91 | input_ids, 92 | input_mask, 93 | segment_ids, 94 | start_position=None, 95 | end_position=None, 96 | rational_mask=None, 97 | cls_idx=None): 98 | self.unique_id = unique_id 99 | self.example_index = example_index 100 | self.doc_span_index = doc_span_index 101 | self.tokens = tokens 102 | self.token_to_orig_map = token_to_orig_map 103 | self.token_is_max_context = token_is_max_context 104 | self.input_ids = input_ids 105 | self.input_mask = input_mask 106 | self.segment_ids = segment_ids 107 | self.start_position = start_position 108 | self.end_position = end_position 109 | self.cls_idx = cls_idx 110 | self.rational_mask = rational_mask 111 | 112 | 113 | def read_coqa_examples(input_file, history_len=2, add_QA_tag=False): 114 | """Read a CoQA json file into a list of CoqaExample.""" 115 | """Useful Function""" 116 | 117 | def is_whitespace(c): 118 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 119 | return True 120 | return False 121 | 122 | def _str(s): 123 | """ Convert PTB tokens to normal tokens """ 124 | if (s.lower() == '-lrb-'): 125 | s = '(' 126 | elif (s.lower() == '-rrb-'): 127 | s = ')' 128 | elif (s.lower() == '-lsb-'): 129 | s = '[' 130 | elif (s.lower() == '-rsb-'): 131 | s = ']' 132 | elif (s.lower() == '-lcb-'): 133 | s = '{' 134 | elif (s.lower() == '-rcb-'): 135 | s = '}' 136 | return s 137 | 138 | def space_extend(matchobj): 139 | return ' ' + matchobj.group(0) + ' ' 140 | 141 | def pre_proc(text): 142 | text = re.sub( 143 | u'-|\u2010|\u2011|\u2012|\u2013|\u2014|\u2015|%|\[|\]|:|\(|\)|/|\t', 144 | space_extend, text) 145 | text = text.strip(' \n') 146 | text = re.sub('\s+', ' ', text) 147 | return text 148 | 149 | def process(parsed_text): 150 | output = {'word': [], 'offsets': [], 'sentences': []} 151 | 152 | for token in parsed_text: 153 | #[(token.text,token.idx) for token in parsed_sentence] 154 | output['word'].append(_str(token.text)) 155 | # pos = token.tag_ 156 | # output['pos'].append(pos) 157 | # output['pos_id'].append(token2id(pos, POS, 0)) 158 | 159 | # ent = 'O' if token.ent_iob_ == 'O' else (token.ent_iob_ + '-' + token.ent_type_) 160 | # output['ent'].append(ent) 161 | # 
output['ent_id'].append(token2id(ent, ENT, 0)) 162 | 163 | # output['lemma'].append(token.lemma_ if token.lemma_ != '-PRON-' else token.text.lower()) 164 | output['offsets'].append((token.idx, token.idx + len(token.text))) 165 | 166 | word_idx = 0 167 | for sent in parsed_text.sents: 168 | output['sentences'].append((word_idx, word_idx + len(sent))) 169 | word_idx += len(sent) 170 | 171 | assert word_idx == len(output['word']) 172 | return output 173 | 174 | def get_raw_context_offsets(words, raw_text): 175 | raw_context_offsets = [] 176 | p = 0 177 | for token in words: 178 | while p < len(raw_text) and re.match('\s', raw_text[p]): 179 | p += 1 180 | if raw_text[p:p + len(token)] != token: 181 | print('something is wrong! token', token, 'raw_text:', 182 | raw_text) 183 | 184 | raw_context_offsets.append((p, p + len(token))) 185 | p += len(token) 186 | 187 | return raw_context_offsets 188 | 189 | def find_span(offsets, start, end): 190 | start_index = -1 191 | end_index = -1 192 | for i, offset in enumerate(offsets): 193 | if (start_index < 0) or (start >= offset[0]): 194 | start_index = i 195 | if (end_index < 0) and (end <= offset[1]): 196 | end_index = i 197 | return (start_index, end_index) 198 | 199 | def normalize_answer(s): 200 | """Lower text and remove punctuation, storys and extra whitespace.""" 201 | 202 | def remove_articles(text): 203 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 204 | return re.sub(regex, ' ', text) 205 | 206 | def white_space_fix(text): 207 | return ' '.join(text.split()) 208 | 209 | def remove_punc(text): 210 | exclude = set(string.punctuation) 211 | return ''.join(ch for ch in text if ch not in exclude) 212 | 213 | def lower(text): 214 | return text.lower() 215 | 216 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 217 | 218 | def find_span_with_gt(context, offsets, ground_truth): 219 | best_f1 = 0.0 220 | best_span = (len(offsets) - 1, len(offsets) - 1) 221 | gt = normalize_answer(pre_proc(ground_truth)).split() 222 | 223 | ls = [ 224 | i for i in range(len(offsets)) 225 | if context[offsets[i][0]:offsets[i][1]].lower() in gt 226 | ] 227 | 228 | for i in range(len(ls)): 229 | for j in range(i, len(ls)): 230 | pred = normalize_answer( 231 | pre_proc( 232 | context[offsets[ls[i]][0]:offsets[ls[j]][1]])).split() 233 | common = Counter(pred) & Counter(gt) 234 | num_same = sum(common.values()) 235 | if num_same > 0: 236 | precision = 1.0 * num_same / len(pred) 237 | recall = 1.0 * num_same / len(gt) 238 | f1 = (2 * precision * recall) / (precision + recall) 239 | if f1 > best_f1: 240 | best_f1 = f1 241 | best_span = (ls[i], ls[j]) 242 | return best_span 243 | 244 | def find_span(offsets, start, end): 245 | start_index = -1 246 | end_index = -1 247 | for i, offset in enumerate(offsets): 248 | if (start_index < 0) or (start >= offset[0]): 249 | start_index = i 250 | if (end_index < 0) and (end <= offset[1]): 251 | end_index = i 252 | return (start_index, end_index) 253 | 254 | """Main stream""" 255 | nlp = spacy.load('en', parser=False) 256 | with open(input_file, "r", encoding='utf-8') as reader: 257 | input_data = json.load(reader)["data"] 258 | examples = [] 259 | input_data = input_data # careful 260 | for data_idx in tqdm(range(len(input_data)), desc='Generating examples'): 261 | datum = input_data[data_idx] 262 | context_str = datum['story'] 263 | _datum = { 264 | 'context': context_str, 265 | 'source': datum['source'], 266 | 'id': datum['id'], 267 | 'filename': datum['filename'] 268 | } 269 | nlp_context = nlp(pre_proc(context_str)) 
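# Tokenize the passage with spaCy and keep per-token character offsets so the
# character-level answer spans in the CoQA json can be mapped back to word indices.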
270 | _datum['annotated_context'] = process(nlp_context) 271 | _datum['raw_context_offsets'] = get_raw_context_offsets( 272 | _datum['annotated_context']['word'], context_str) 273 | # _datum['qas'] = [] 274 | assert len(datum['questions']) == len(datum['answers']) 275 | additional_answers = {} 276 | if 'additional_answers' in datum: 277 | for k, answer in datum['additional_answers'].items(): 278 | if len(answer) == len(datum['answers']): 279 | for ex in answer: 280 | idx = ex['turn_id'] 281 | if idx not in additional_answers: 282 | additional_answers[idx] = [] 283 | additional_answers[idx].append(ex['input_text']) 284 | for i in range(len(datum['questions'])): 285 | question, answer = datum['questions'][i], datum['answers'][i] 286 | assert question['turn_id'] == answer['turn_id'] 287 | 288 | idx = question['turn_id'] 289 | _qas = { 290 | 'turn_id': idx, 291 | 'question': question['input_text'], 292 | 'answer': answer['input_text'] 293 | } 294 | if idx in additional_answers: 295 | _qas['additional_answers'] = additional_answers[idx] 296 | 297 | # _qas['annotated_question'] = process( 298 | # nlp(pre_proc(question['input_text']))) 299 | # _qas['annotated_answer'] = process( 300 | # nlp(pre_proc(answer['input_text']))) 301 | _qas['raw_answer'] = answer['input_text'] 302 | 303 | if _qas['raw_answer'].lower() in ['yes', 'yes.']: 304 | _qas['raw_answer'] = 'yes' 305 | if _qas['raw_answer'].lower() in ['no', 'no.']: 306 | _qas['raw_answer'] = 'no' 307 | if _qas['raw_answer'].lower() in ['unknown', 'unknown.']: 308 | _qas['raw_answer'] = 'unknown' 309 | 310 | _qas['answer_span_start'] = answer['span_start'] 311 | _qas['answer_span_end'] = answer['span_end'] 312 | start = answer['span_start'] 313 | end = answer['span_end'] 314 | chosen_text = _datum['context'][start:end].lower() 315 | while len(chosen_text) > 0 and is_whitespace(chosen_text[0]): 316 | chosen_text = chosen_text[1:] 317 | start += 1 318 | while len(chosen_text) > 0 and is_whitespace(chosen_text[-1]): 319 | chosen_text = chosen_text[:-1] 320 | end -= 1 321 | r_start, r_end = find_span(_datum['raw_context_offsets'], start, 322 | end) 323 | input_text = _qas['answer'].strip().lower() 324 | if input_text in chosen_text: 325 | p = chosen_text.find(input_text) 326 | _qas['answer_span'] = find_span(_datum['raw_context_offsets'], 327 | start + p, 328 | start + p + len(input_text)) 329 | else: 330 | _qas['answer_span'] = find_span_with_gt( 331 | _datum['context'], _datum['raw_context_offsets'], 332 | input_text) 333 | long_questions = [] 334 | for j in range(i - history_len, i + 1): 335 | long_question = '' 336 | if j < 0: 337 | continue 338 | long_question += (' ' if add_QA_tag else 339 | ' ') + datum['questions'][j]['input_text'] 340 | if j < i: 341 | long_question += (' ' if add_QA_tag else 342 | ' ') + datum['answers'][j]['input_text'] + ' [SEP]' 343 | long_question = long_question.strip() 344 | long_questions.append(long_question) 345 | 346 | # long_question = long_question.strip() 347 | # _qas['raw_long_question'] = long_question 348 | # _qas['annotated_long_question'] = process( 349 | # nlp(pre_proc(long_question))) 350 | # _datum['qas'].append(_qas) 351 | example = CoqaExample( 352 | qas_id=_datum['id'] + ' ' + str(_qas['turn_id']), 353 | question_text=long_questions, 354 | doc_tokens=_datum['annotated_context']['word'], 355 | orig_answer_text=_qas['raw_answer'], 356 | start_position=_qas['answer_span'][0], 357 | end_position=_qas['answer_span'][1], 358 | rational_start_position=r_start, 359 | rational_end_position=r_end, 360 | 
additional_answers=_qas['additional_answers'] 361 | if 'additional_answers' in _qas else None, 362 | ) 363 | examples.append(example) 364 | 365 | return examples 366 | 367 | 368 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 369 | doc_stride, max_query_length): 370 | """Loads a data file into a list of `InputBatch`s.""" 371 | 372 | unique_id = 1000000000 373 | 374 | features = [] 375 | for (example_index, 376 | example) in enumerate(tqdm(examples, desc="Generating features")): 377 | query_tokens = [] 378 | for qa in example.question_text: 379 | query_tokens.extend(tokenizer.tokenize(qa)) 380 | 381 | cls_idx = 3 382 | if example.orig_answer_text == 'yes': 383 | cls_idx = 0 # yes 384 | elif example.orig_answer_text == 'no': 385 | cls_idx = 1 # no 386 | elif example.orig_answer_text == 'unknown': 387 | cls_idx = 2 # unknown 388 | 389 | if len(query_tokens) > max_query_length: # keep tail, not head 390 | query_tokens.reverse() 391 | query_tokens = query_tokens[0:max_query_length] 392 | query_tokens.reverse() 393 | 394 | tok_to_orig_index = [] 395 | orig_to_tok_index = [] 396 | all_doc_tokens = [] 397 | for (i, token) in enumerate(example.doc_tokens): 398 | orig_to_tok_index.append(len(all_doc_tokens)) 399 | sub_tokens = tokenizer.tokenize(token) 400 | for sub_token in sub_tokens: 401 | tok_to_orig_index.append(i) 402 | all_doc_tokens.append(sub_token) 403 | 404 | tok_start_position = None 405 | tok_end_position = None 406 | tok_r_start_position, tok_r_end_position = None, None 407 | 408 | # rational part 409 | tok_r_start_position = orig_to_tok_index[ 410 | example.rational_start_position] 411 | if example.rational_end_position < len(example.doc_tokens) - 1: 412 | tok_r_end_position = orig_to_tok_index[ 413 | example.rational_end_position + 1] - 1 414 | else: 415 | tok_r_end_position = len(all_doc_tokens) - 1 416 | # rational part end 417 | 418 | # if tok_r_end_position is None: 419 | # print('DEBUG') 420 | 421 | if cls_idx < 3: 422 | tok_start_position, tok_end_position = 0, 0 423 | else: 424 | tok_start_position = orig_to_tok_index[example.start_position] 425 | if example.end_position < len(example.doc_tokens) - 1: 426 | tok_end_position = orig_to_tok_index[example.end_position + 427 | 1] - 1 428 | else: 429 | tok_end_position = len(all_doc_tokens) - 1 430 | (tok_start_position, tok_end_position) = _improve_answer_span( 431 | all_doc_tokens, tok_start_position, tok_end_position, 432 | tokenizer, example.orig_answer_text) 433 | # The -3 accounts for [CLS], [SEP] and [SEP] 434 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 435 | 436 | # We can have documents that are longer than the maximum sequence length. 437 | # To deal with this we do a sliding window approach, where we take chunks 438 | # of the up to our max length with a stride of `doc_stride`. 
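# For example (illustrative numbers): a 600-token passage with max_tokens_for_doc=384 and
# doc_stride=128 yields spans starting at tokens 0, 128 and 256; tokens that appear in
# several overlapping spans are resolved later with the "max context" check.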
439 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 440 | "DocSpan", ["start", "length"]) 441 | doc_spans = [] 442 | start_offset = 0 443 | while start_offset < len(all_doc_tokens): 444 | length = len(all_doc_tokens) - start_offset 445 | if length > max_tokens_for_doc: 446 | length = max_tokens_for_doc 447 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 448 | if start_offset + length == len(all_doc_tokens): 449 | break 450 | start_offset += min(length, doc_stride) 451 | 452 | for (doc_span_index, doc_span) in enumerate(doc_spans): 453 | slice_cls_idx = cls_idx 454 | tokens = [] 455 | token_to_orig_map = {} 456 | token_is_max_context = {} 457 | segment_ids = [] 458 | 459 | # cur_id = 2 - query_tokens.count('[SEP]') 460 | 461 | # assert cur_id >= 0 462 | 463 | tokens.append("[CLS]") 464 | segment_ids.append(0) 465 | for token in query_tokens: 466 | tokens.append(token) 467 | segment_ids.append(0) 468 | # if token == '[SEP]': 469 | # cur_id += 1 470 | tokens.append("[SEP]") 471 | segment_ids.append(0) 472 | # cur_id += 1 473 | 474 | # assert cur_id <= 3 475 | 476 | for i in range(doc_span.length): 477 | split_token_index = doc_span.start + i 478 | token_to_orig_map[len( 479 | tokens)] = tok_to_orig_index[split_token_index] 480 | 481 | is_max_context = _check_is_max_context(doc_spans, 482 | doc_span_index, 483 | split_token_index) 484 | token_is_max_context[len(tokens)] = is_max_context 485 | tokens.append(all_doc_tokens[split_token_index]) 486 | segment_ids.append(1) 487 | tokens.append("[SEP]") 488 | segment_ids.append(1) 489 | 490 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 491 | 492 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 493 | # tokens are attended to. 494 | input_mask = [1] * len(input_ids) 495 | 496 | # Zero-pad up to the sequence length. 497 | while len(input_ids) < max_seq_length: 498 | input_ids.append(0) 499 | input_mask.append(0) 500 | segment_ids.append(0) 501 | 502 | assert len(input_ids) == max_seq_length 503 | assert len(input_mask) == max_seq_length 504 | assert len(segment_ids) == max_seq_length 505 | 506 | start_position = None 507 | end_position = None 508 | rational_start_position = None 509 | rational_end_position = None 510 | 511 | # rational_part 512 | doc_start = doc_span.start 513 | doc_end = doc_span.start + doc_span.length - 1 514 | out_of_span = False 515 | if example.rational_start_position == -1 or not ( 516 | tok_r_start_position >= doc_start 517 | and tok_r_end_position <= doc_end): 518 | out_of_span = True 519 | if out_of_span: 520 | rational_start_position = 0 521 | rational_end_position = 0 522 | else: 523 | doc_offset = len(query_tokens) + 2 524 | rational_start_position = tok_r_start_position - doc_start + doc_offset 525 | rational_end_position = tok_r_end_position - doc_start + doc_offset 526 | # rational_part_end 527 | 528 | rational_mask = [0] * len(input_ids) 529 | if not out_of_span: 530 | rational_mask[rational_start_position:rational_end_position + 531 | 1] = [1] * (rational_end_position - 532 | rational_start_position + 1) 533 | 534 | if cls_idx >= 3: 535 | # For training, if our document chunk does not contain an annotation 536 | # we throw it out, since there is nothing to predict. 
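# (The chunk is not literally dropped here: out-of-span chunks are labelled with position 0
# and slice_cls_idx=2, i.e. 'unknown'.) Otherwise the answer's token positions are shifted
# from whole-document coordinates into this chunk's [CLS] question [SEP] passage [SEP]
# input using doc_offset = len(query_tokens) + 2.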
537 | doc_start = doc_span.start 538 | doc_end = doc_span.start + doc_span.length - 1 539 | out_of_span = False 540 | if not (tok_start_position >= doc_start 541 | and tok_end_position <= doc_end): 542 | out_of_span = True 543 | if out_of_span: 544 | start_position = 0 545 | end_position = 0 546 | slice_cls_idx = 2 547 | else: 548 | doc_offset = len(query_tokens) + 2 549 | start_position = tok_start_position - doc_start + doc_offset 550 | end_position = tok_end_position - doc_start + doc_offset 551 | else: 552 | start_position = 0 553 | end_position = 0 554 | 555 | if example_index < 5: 556 | logger.info("*** Example ***") 557 | logger.info("unique_id: %s" % (unique_id)) 558 | logger.info("example_index: %s" % (example_index)) 559 | logger.info("doc_span_index: %s" % (doc_span_index)) 560 | logger.info("tokens: %s" % " ".join(tokens)) 561 | logger.info("token_to_orig_map: %s" % " ".join( 562 | ["%d:%d" % (x, y) 563 | for (x, y) in token_to_orig_map.items()])) 564 | logger.info("token_is_max_context: %s" % " ".join([ 565 | "%d:%s" % (x, y) 566 | for (x, y) in token_is_max_context.items() 567 | ])) 568 | logger.info("input_ids: %s" % 569 | " ".join([str(x) for x in input_ids])) 570 | logger.info("input_mask: %s" % 571 | " ".join([str(x) for x in input_mask])) 572 | logger.info("segment_ids: %s" % 573 | " ".join([str(x) for x in segment_ids])) 574 | 575 | if slice_cls_idx >= 3: 576 | answer_text = " ".join( 577 | tokens[start_position:(end_position + 1)]) 578 | else: 579 | tmp = ['yes', 'no', 'unknown'] 580 | answer_text = tmp[slice_cls_idx] 581 | 582 | rational_text = " ".join( 583 | tokens[rational_start_position:(rational_end_position + 584 | 1)]) 585 | logger.info("start_position: %d" % (start_position)) 586 | logger.info("end_position: %d" % (end_position)) 587 | logger.info("rational_start_position: %d" % 588 | (rational_start_position)) 589 | logger.info("rational_end_position: %d" % 590 | (rational_end_position)) 591 | logger.info("answer: %s" % (answer_text)) 592 | logger.info("rational: %s" % (rational_text)) 593 | 594 | features.append( 595 | InputFeatures(unique_id=unique_id, 596 | example_index=example_index, 597 | doc_span_index=doc_span_index, 598 | tokens=tokens, 599 | token_to_orig_map=token_to_orig_map, 600 | token_is_max_context=token_is_max_context, 601 | input_ids=input_ids, 602 | input_mask=input_mask, 603 | segment_ids=segment_ids, 604 | start_position=start_position, 605 | end_position=end_position, 606 | rational_mask=rational_mask, 607 | cls_idx=slice_cls_idx)) 608 | unique_id += 1 609 | 610 | return features 611 | 612 | 613 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 614 | orig_answer_text): 615 | """Returns tokenized answer spans that better match the annotated answer.""" 616 | 617 | # The SQuAD annotations are character based. We first project them to 618 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 619 | # often find a "better match". For example: 620 | # 621 | # Question: What year was John Smith born? 622 | # Context: The leader was John Smith (1895-1943). 623 | # Answer: 1895 624 | # 625 | # The original whitespace-tokenized answer will be "(1895-1943).". However 626 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 627 | # the exact answer, 1895. 628 | # 629 | # However, this is not always possible. Consider the following: 630 | # 631 | # Question: What country is the top exporter of electornics? 
632 | # Context: The Japanese electronics industry is the largest in the world.
633 | # Answer: Japan
634 | #
635 | # In this case, the annotator chose "Japan" as a character sub-span of
636 | # the word "Japanese". Since our WordPiece tokenizer does not split
637 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
638 | # in SQuAD, but does happen.
639 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
640 |
641 | for new_start in range(input_start, input_end + 1):
642 | for new_end in range(input_end, new_start - 1, -1):
643 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
644 | if text_span == tok_answer_text:
645 | return (new_start, new_end)
646 |
647 | return (input_start, input_end)
648 |
649 |
650 | def _check_is_max_context(doc_spans, cur_span_index, position):
651 | """Check if this is the 'max context' doc span for the token."""
652 |
653 | # Because of the sliding window approach taken to scoring documents, a single
654 | # token can appear in multiple doc spans. E.g.
655 | # Doc: the man went to the store and bought a gallon of milk
656 | # Span A: the man went to the
657 | # Span B: to the store and bought
658 | # Span C: and bought a gallon of
659 | # ...
660 | #
661 | # Now the word 'bought' will have two scores from spans B and C. We only
662 | # want to consider the score with "maximum context", which we define as
663 | # the *minimum* of its left and right context (the *sum* of left and
664 | # right context will always be the same, of course).
665 | #
666 | # In the example the maximum context for 'bought' would be span C since
667 | # it has 1 left context and 3 right context, while span B has 4 left context
668 | # and 0 right context.
669 | best_score = None
670 | best_span_index = None
671 | for (span_index, doc_span) in enumerate(doc_spans):
672 | end = doc_span.start + doc_span.length - 1
673 | if position < doc_span.start:
674 | continue
675 | if position > end:
676 | continue
677 | num_left_context = position - doc_span.start
678 | num_right_context = end - position
679 | score = min(num_left_context,
680 | num_right_context) + 0.01 * doc_span.length
681 | if best_score is None or score > best_score:
682 | best_score = score
683 | best_span_index = span_index
684 |
685 | return cur_span_index == best_span_index
686 |
687 |
688 | RawResult = collections.namedtuple("RawResult", [
689 | "unique_id", "start_logits", "end_logits", "yes_logits", "no_logits",
690 | "unk_logits"
691 | ])
692 |
693 |
694 | def write_predictions(all_examples, all_features, all_results, n_best_size,
695 | max_answer_length, do_lower_case, output_prediction_file,
696 | output_nbest_file, output_null_log_odds_file,
697 | verbose_logging, null_score_diff_threshold):
698 | """Write final predictions to the json file and log-odds of null if needed."""
699 | logger.info("Writing predictions to: %s" % (output_prediction_file))
700 | logger.info("Writing nbest to: %s" % (output_nbest_file))
701 |
702 | example_index_to_features = collections.defaultdict(list)
703 | for feature in all_features:
704 | example_index_to_features[feature.example_index].append(feature)
705 |
706 | unique_id_to_result = {}
707 | for result in all_results:
708 | unique_id_to_result[result.unique_id] = result
709 |
710 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
711 | "PrelimPrediction", [
712 | "feature_index",
713 | "start_index",
714 | "end_index",
715 | "score",
716 | "cls_idx",
717 | ])
718 |
719 | # all_predictions = collections.OrderedDict()
720 | all_predictions = []
721 | all_nbest_json = collections.OrderedDict()
722 | scores_diff_json = collections.OrderedDict()
723 |
724 | for (example_index,
725 | example) in enumerate(tqdm(all_examples, desc="Writing predictions")):
726 | features = example_index_to_features[example_index]
727 |
728 | prelim_predictions = []
729 | # keep track of the minimum score of null start+end of position 0
730 | part_prelim_predictions = []
731 |
732 | score_yes, score_no, score_span, score_unk = -float('INF'), -float(
733 | 'INF'), -float('INF'), float('INF')
734 | min_unk_feature_index, max_yes_feature_index, max_no_feature_index, max_span_feature_index = - \
735 | 1, -1, -1, -1 # the paragraph slice with min null score
736 | max_span_start_indexes, max_span_end_indexes = [], []
737 | max_start_index, max_end_index = -1, -1
738 | # null_start_logit = 0 # the start logit at the slice with min null score
739 | # null_end_logit = 0 # the end logit at the slice with min null score
740 |
741 | for (feature_index, feature) in enumerate(features):
742 | result = unique_id_to_result[feature.unique_id]
743 | # if we could have irrelevant answers, get the min score of irrelevant
744 | # feature_null_score = result.start_logits[0] + result.end_logits[0]
745 |
746 | # feature_yes_score, feature_no_score, feature_unk_score, feature_span_score = result.cls_logits
747 |
748 | feature_yes_score, feature_no_score, feature_unk_score = result.yes_logits[
749 | 0] * 2, result.no_logits[0] * 2, result.unk_logits[0] * 2
750 | start_indexes, end_indexes = _get_best_indexes(
751 | result.start_logits,
752 | n_best_size), _get_best_indexes(result.end_logits, n_best_size)
753 |
754 | for start_index in start_indexes:
755 | for end_index in end_indexes:
756 | if start_index >= len(feature.tokens):
757 | continue
758 | if end_index >= len(feature.tokens):
759 | continue
760 | if start_index not in feature.token_to_orig_map:
761 | continue
762 | if end_index not in feature.token_to_orig_map:
763 | continue
764 | if not feature.token_is_max_context.get(
765 | start_index, False):
766 | continue
767 | if end_index < start_index:
768 | continue
769 | length = end_index - start_index + 1
770 | if length > max_answer_length:
771 | continue
772 | feature_span_score = result.start_logits[
773 | start_index] + result.end_logits[end_index]
774 | prelim_predictions.append(
775 | _PrelimPrediction(feature_index=feature_index,
776 | start_index=start_index,
777 | end_index=end_index,
778 | score=feature_span_score,
779 | cls_idx=3))
780 |
781 | if feature_unk_score < score_unk: # find min score_noanswer
782 | score_unk = feature_unk_score
783 | min_unk_feature_index = feature_index
784 | if feature_yes_score > score_yes: # find max score_yes
785 | score_yes = feature_yes_score
786 | max_yes_feature_index = feature_index
787 | if feature_no_score > score_no: # find max score_no
788 | score_no = feature_no_score
789 | max_no_feature_index = feature_index
790 |
791 | prelim_predictions.append(
792 | _PrelimPrediction(feature_index=min_unk_feature_index,
793 | start_index=0,
794 | end_index=0,
795 | score=score_unk,
796 | cls_idx=2))
797 | prelim_predictions.append(
798 | _PrelimPrediction(feature_index=max_yes_feature_index,
799 | start_index=0,
800 | end_index=0,
801 | score=score_yes,
802 | cls_idx=0))
803 | prelim_predictions.append(
804 | _PrelimPrediction(feature_index=max_no_feature_index,
805 | start_index=0,
806 | end_index=0,
807 | score=score_no,
808 | cls_idx=1))
809 |
810 | prelim_predictions =
sorted(prelim_predictions, 811 | key=lambda p: p.score, 812 | reverse=True) 813 | 814 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 815 | "NbestPrediction", ["text", "score", "cls_idx"]) 816 | 817 | seen_predictions = {} 818 | nbest = [] 819 | cls_rank = [] 820 | for pred in prelim_predictions: 821 | if len(nbest) >= n_best_size: # including yes/no/noanswer pred 822 | break 823 | feature = features[pred.feature_index] 824 | if pred.cls_idx == 3: # this is a non-null prediction 825 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 826 | 1)] 827 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 828 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 829 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 830 | 1)] 831 | tok_text = " ".join(tok_tokens) 832 | 833 | # De-tokenize WordPieces that have been split off. 834 | tok_text = tok_text.replace(" ##", "") 835 | tok_text = tok_text.replace("##", "") 836 | 837 | # Clean whitespace 838 | tok_text = tok_text.strip() 839 | tok_text = " ".join(tok_text.split()) 840 | orig_text = " ".join(orig_tokens) 841 | 842 | final_text = get_final_text(tok_text, orig_text, do_lower_case, 843 | verbose_logging) 844 | if final_text in seen_predictions: 845 | continue 846 | 847 | seen_predictions[final_text] = True 848 | nbest.append( 849 | _NbestPrediction(text=final_text, 850 | score=pred.score, 851 | cls_idx=pred.cls_idx)) 852 | else: 853 | text = ['yes', 'no', 'unknown'] 854 | nbest.append( 855 | _NbestPrediction(text=text[pred.cls_idx], 856 | score=pred.score, 857 | cls_idx=pred.cls_idx)) 858 | 859 | # if we didn't include the empty option in the n-best, include it 860 | # if "" not in seen_predictions: 861 | # nbest.append( 862 | # _NbestPrediction(text=final_text, 863 | # noanswer_logit=pred.noanswer_logit, 864 | # cls_idx=pred.cls_idx)) 865 | # In very rare edge cases we could only have single null prediction. 866 | # So we just create a nonce prediction in this case to avoid failure. 867 | # if len(nbest) == 1: 868 | # nbest.insert( 869 | # 0, 870 | # _NbestPrediction(text="empty", 871 | # start_logit=0.0, 872 | # end_logit=0.0)) 873 | 874 | # In very rare edge cases we could have no valid predictions. So we 875 | # just create a nonce prediction in this case to avoid failure. 
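# Note: the fallback below guarantees that nbest is never empty, so the
# softmax over candidate scores a few lines further down always has at least
# one entry.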
876 |
877 | if len(nbest) < 1:
878 | nbest.append(
879 | _NbestPrediction(text='unknown',
880 | score=-float('inf'),
881 | cls_idx=2))
882 |
883 | assert len(nbest) >= 1
884 |
885 | probs = _compute_softmax([p.score for p in nbest])
886 |
887 | # total_scores = []
888 | # cls_scores = []
889 | # for entry in nbest:
890 | # total_scores.append(entry.start_logit + entry.end_logit)
891 | # for entry in cls_rank:
892 | # cls_scores.append(entry.cls_logit)
893 |
894 | # span_probs = _compute_softmax(total_scores)
895 | # cls_probs = _compute_softmax(cls_scores)
896 | nbest_json = []
897 |
898 | # # two diff nbest: for cls and for answer span
899 | # cur_rank, cur_probs, cur_scores = (
900 | # nbest, span_probs,
901 | # total_scores) if cls_rank[0].cls_idx == 3 and len(nbest) > 1 else (
902 | # cls_rank, cls_probs, cls_scores)
903 |
904 | for i, entry in enumerate(nbest):
905 | output = collections.OrderedDict()
906 | output["text"] = entry.text
907 | output["probability"] = probs[i]
908 | # output["start_logit"] = entry.start_logit
909 | # output["end_logit"] = entry.end_logit
910 | output["score"] = entry.score
911 | nbest_json.append(output)
912 |
913 | assert len(nbest_json) >= 1
914 |
915 | _id, _turn_id = example.qas_id.split()
916 | all_predictions.append({
917 | 'id': _id,
918 | 'turn_id': int(_turn_id),
919 | 'answer': confirm_preds(nbest_json)
920 | })
921 | # if not version_2_with_negative:
922 | # all_predictions[example.qas_id] = nbest_json[0]["text"]
923 | # else:
924 | # # predict "" iff the null score - the score of best non-null > threshold
925 | # score_diff = score_null - best_non_null_entry.start_logit - (
926 | # best_non_null_entry.end_logit)
927 | # scores_diff_json[example.qas_id] = score_diff
928 | # if score_diff > null_score_diff_threshold:
929 | # all_predictions[example.qas_id] = ""
930 | # else:
931 | # all_predictions[example.qas_id] = best_non_null_entry.text
932 | all_nbest_json[example.qas_id] = nbest_json
933 |
934 | with open(output_prediction_file, "w") as writer:
935 | writer.write(json.dumps(all_predictions, indent=4) + "\n")
936 |
937 | with open(output_nbest_file, "w") as writer:
938 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
939 |
940 | # if version_2_with_negative:
941 | # with open(output_null_log_odds_file, "w") as writer:
942 | # writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
943 |
944 |
945 | def confirm_preds(nbest_json):
946 | # Correct some obviously wrong predictions
947 | subs = [
948 | 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
949 | 'ten', 'eleven', 'twelve', 'true', 'false'
950 | ] # quite hard-coded; can be extended.
951 | ori = nbest_json[0]['text']
952 | if len(ori) < 2: # i.e. a span like '.' or '!'
953 | for e in nbest_json[1:]:
954 | if _normalize_answer(e['text']) in subs:
955 | return e['text']
956 | return 'unknown'
957 | return ori
958 |
959 |
960 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
961 | """Project the tokenized prediction back to the original text."""
962 |
963 | # When we created the data, we kept track of the alignment between original
964 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
965 | # now `orig_text` contains the span of our original text corresponding to the
966 | # span that we predicted.
967 | #
968 | # However, `orig_text` may contain extra characters that we don't want in
969 | # our prediction.
970 | # 971 | # For example, let's say: 972 | # pred_text = steve smith 973 | # orig_text = Steve Smith's 974 | # 975 | # We don't want to return `orig_text` because it contains the extra "'s". 976 | # 977 | # We don't want to return `pred_text` because it's already been normalized 978 | # (the SQuAD eval script also does punctuation stripping/lower casing but 979 | # our tokenizer does additional normalization like stripping accent 980 | # characters). 981 | # 982 | # What we really want to return is "Steve Smith". 983 | # 984 | # Therefore, we have to apply a semi-complicated alignment heuristic between 985 | # `pred_text` and `orig_text` to get a character-to-character alignment. This 986 | # can fail in certain cases in which case we just return `orig_text`. 987 | 988 | def _strip_spaces(text): 989 | ns_chars = [] 990 | ns_to_s_map = collections.OrderedDict() 991 | for (i, c) in enumerate(text): 992 | if c == " ": 993 | continue 994 | ns_to_s_map[len(ns_chars)] = i 995 | ns_chars.append(c) 996 | ns_text = "".join(ns_chars) 997 | return (ns_text, ns_to_s_map) 998 | 999 | # We first tokenize `orig_text`, strip whitespace from the result 1000 | # and `pred_text`, and check if they are the same length. If they are 1001 | # NOT the same length, the heuristic has failed. If they are the same 1002 | # length, we assume the characters are one-to-one aligned. 1003 | tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 1004 | 1005 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 1006 | 1007 | start_position = tok_text.find(pred_text) 1008 | if start_position == -1: 1009 | if verbose_logging: 1010 | logger.info("Unable to find text: '%s' in '%s'" % 1011 | (pred_text, orig_text)) 1012 | return orig_text 1013 | end_position = start_position + len(pred_text) - 1 1014 | 1015 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 1016 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 1017 | 1018 | if len(orig_ns_text) != len(tok_ns_text): 1019 | if verbose_logging: 1020 | logger.info( 1021 | "Length not equal after stripping spaces: '%s' vs '%s'", 1022 | orig_ns_text, tok_ns_text) 1023 | return orig_text 1024 | 1025 | # We then project the characters in `pred_text` back to `orig_text` using 1026 | # the character-to-character alignment. 
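# Note: tok_ns_to_s_map maps positions in the whitespace-stripped string back
# to positions in tok_text; inverting it below lets us take a position in
# tok_text to the stripped string, and from there into orig_text via
# orig_ns_to_s_map.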
1027 | tok_s_to_ns_map = {} 1028 | for (i, tok_index) in tok_ns_to_s_map.items(): 1029 | tok_s_to_ns_map[tok_index] = i 1030 | 1031 | orig_start_position = None 1032 | if start_position in tok_s_to_ns_map: 1033 | ns_start_position = tok_s_to_ns_map[start_position] 1034 | if ns_start_position in orig_ns_to_s_map: 1035 | orig_start_position = orig_ns_to_s_map[ns_start_position] 1036 | 1037 | if orig_start_position is None: 1038 | if verbose_logging: 1039 | logger.info("Couldn't map start position") 1040 | return orig_text 1041 | 1042 | orig_end_position = None 1043 | if end_position in tok_s_to_ns_map: 1044 | ns_end_position = tok_s_to_ns_map[end_position] 1045 | if ns_end_position in orig_ns_to_s_map: 1046 | orig_end_position = orig_ns_to_s_map[ns_end_position] 1047 | 1048 | if orig_end_position is None: 1049 | if verbose_logging: 1050 | logger.info("Couldn't map end position") 1051 | return orig_text 1052 | 1053 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 1054 | return output_text 1055 | 1056 | 1057 | def _get_best_indexes(logits, n_best_size): 1058 | """Get the n-best logits from a list.""" 1059 | index_and_score = sorted(enumerate(logits), 1060 | key=lambda x: x[1], 1061 | reverse=True) 1062 | 1063 | best_indexes = [] 1064 | for i in range(len(index_and_score)): 1065 | if i >= n_best_size: 1066 | break 1067 | best_indexes.append(index_and_score[i][0]) 1068 | return best_indexes 1069 | 1070 | 1071 | def _compute_softmax(scores): 1072 | """Compute softmax probability over raw logits.""" 1073 | if not scores: 1074 | return [] 1075 | 1076 | max_score = None 1077 | for score in scores: 1078 | if max_score is None or score > max_score: 1079 | max_score = score 1080 | 1081 | exp_scores = [] 1082 | total_sum = 0.0 1083 | for score in scores: 1084 | x = math.exp(score - max_score) 1085 | exp_scores.append(x) 1086 | total_sum += x 1087 | 1088 | probs = [] 1089 | for score in exp_scores: 1090 | probs.append(score / total_sum) 1091 | return probs 1092 | 1093 | 1094 | def _normalize_answer(s): 1095 | def remove_articles(text): 1096 | return re.sub(r'\b(a|an|the)\b', ' ', text) 1097 | 1098 | def white_space_fix(text): 1099 | return ' '.join(text.split()) 1100 | 1101 | def remove_punc(text): 1102 | exclude = set(string.punctuation) 1103 | return ''.join(ch for ch in text if ch not in exclude) 1104 | 1105 | def lower(text): 1106 | return text.lower() 1107 | 1108 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 1109 | 1110 | 1111 | def score(pred, truth): 1112 | def _f1_score(pred, answers): 1113 | def _score(g_tokens, a_tokens): 1114 | common = Counter(g_tokens) & Counter(a_tokens) 1115 | num_same = sum(common.values()) 1116 | if num_same == 0: 1117 | return 0 1118 | precision = 1. * num_same / len(g_tokens) 1119 | recall = 1. * num_same / len(a_tokens) 1120 | f1 = (2 * precision * recall) / (precision + recall) 1121 | return f1 1122 | 1123 | if pred is None or answers is None: 1124 | return 0 1125 | 1126 | if len(answers) == 0: 1127 | return 1. if len(pred) == 0 else 0. 
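# Note: with several reference answers the value returned below is the average
# of the leave-one-out maxima (for each held-out reference, the best F1 of the
# prediction against the remaining references), which appears to mirror the
# official CoQA evaluation.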
1128 |
1129 | g_tokens = _normalize_answer(pred).split()
1130 | ans_tokens = [_normalize_answer(answer).split() for answer in answers]
1131 | scores = [_score(g_tokens, a) for a in ans_tokens]
1132 | if len(ans_tokens) == 1:
1133 | score = scores[0]
1134 | else:
1135 | score = 0
1136 | for i in range(len(ans_tokens)):
1137 | scores_one_out = scores[:i] + scores[(i + 1):]
1138 | score += max(scores_one_out)
1139 | score /= len(ans_tokens)
1140 | return score
1141 |
1142 | # Main evaluation loop
1143 | assert len(pred) == len(truth)
1144 | pred, truth = pred.items(), truth.items()
1145 | no_ans_total = no_total = yes_total = normal_total = total = 0
1146 | no_ans_f1 = no_f1 = yes_f1 = normal_f1 = f1 = 0
1147 | all_f1s = []
1148 | for (p_id, p), (t_id, t) in zip(pred, truth):
1149 | assert p_id == t_id
1150 | total += 1
1151 | this_f1 = _f1_score(p, t)
1152 | f1 += this_f1
1153 | all_f1s.append(this_f1)
1154 | if t[0].lower() == 'no':
1155 | no_total += 1
1156 | no_f1 += this_f1
1157 | elif t[0].lower() == 'yes':
1158 | yes_total += 1
1159 | yes_f1 += this_f1
1160 | elif t[0].lower() == 'unknown':
1161 | no_ans_total += 1
1162 | no_ans_f1 += this_f1
1163 | else:
1164 | normal_total += 1
1165 | normal_f1 += this_f1
1166 |
1167 | f1 = 100. * f1 / total
1168 | if no_total == 0:
1169 | no_f1 = 0.
1170 | else:
1171 | no_f1 = 100. * no_f1 / no_total
1172 | if yes_total == 0:
1173 | yes_f1 = 0.
1174 | else:
1175 | yes_f1 = 100. * yes_f1 / yes_total
1176 | if no_ans_total == 0:
1177 | no_ans_f1 = 0.
1178 | else:
1179 | no_ans_f1 = 100. * no_ans_f1 / no_ans_total
1180 | normal_f1 = 0. if normal_total == 0 else 100. * normal_f1 / normal_total
1181 | result = {
1182 | 'total': total,
1183 | 'f1': f1,
1184 | 'no_total': no_total,
1185 | 'no_f1': no_f1,
1186 | 'yes_total': yes_total,
1187 | 'yes_f1': yes_f1,
1188 | 'no_ans_total': no_ans_total,
1189 | 'no_ans_f1': no_ans_f1,
1190 | 'normal_total': normal_total,
1191 | 'normal_f1': normal_f1,
1192 | }
1193 | return result, all_f1s
1194 |
--------------------------------------------------------------------------------