├── LICENSE ├── README.md ├── data ├── finetune │ ├── situated_qa │ │ ├── dev.truncate-1500.json │ │ ├── test.truncate-1500.json │ │ └── train.truncate-1500.json │ ├── tsqa_easy │ │ ├── dev.truncate-1500.json │ │ ├── test.truncate-1500.json │ │ └── train.truncate-1500.json │ └── tsqa_hard │ │ ├── dev.truncate-1500.json │ │ ├── test.truncate-1500.json │ │ └── train.truncate-1500.json ├── pretrain │ ├── sample.non_temporal.json │ └── sample.temporal.json └── time_expression │ └── ner_task │ ├── train.json │ └── val.json ├── requirements.txt └── src ├── odqa_t5 ├── run_finetuning.sh ├── run_seq2seq_qa.py ├── squad │ ├── compute_score.py │ └── squad.py ├── squad_v2 │ ├── compute_score.py │ └── squad_v2.py ├── trainer_seq2seq_qa.py └── utils_qa.py ├── preprocess ├── analyze_time_normalize.py ├── sample_100.annotated.enwiki-20221101_temporal-sentences_special-token-prefix.json └── sample_100.enwiki-20221101_temporal-sentences_special-token-prefix.json ├── pretrain_t5 ├── data_collator_for_t5.py ├── generate_date_text.py ├── lr_scheduler.py ├── modeling_t5_temporal.py ├── optimizer.py ├── run_pretrain.sh ├── run_seq2seq.py ├── sampler.py ├── t5_dataset.py ├── trainer_temporal.py └── utils.py └── time_expression ├── inference_time_identification.sh ├── token-classification ├── README.md ├── requirements.txt ├── run.sh ├── run_ner.py ├── run_ner_no_trainer.py └── run_no_trainer.sh └── train_time_identification_model.sh /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Multilingual NLP Team at Alibaba DAMO Academy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RemeMo 2 | 3 | This repo contains the resources used in our paper: [Once Upon a *Time* in *Graph*: Relative-Time Pretraining for Complex Temporal Reasoning](https://arxiv.org/abs/2310.14709) (*EMNLP 2023*). 
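The pre-trained RemeMo checkpoints listed below share T5's standard seq2seq interface, so they can be loaded directly with 🤗 Transformers. A minimal sketch (the question/context strings are toy examples, and the raw pre-trained checkpoints are meant to be fine-tuned first; see Usage below):

```
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/rememo-base")
model = AutoModelForSeq2SeqLM.from_pretrained("DAMO-NLP-SG/rememo-base")

# Same "question: ... context: ..." input format as the fine-tuning script
text = ("question: When was the festival last held? "
        "context: The festival was first held in 1998 and ran until 2004.")
inputs = tokenizer(text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```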
4 | 5 | Checkpoints hosted on 🤗 HuggingFace: 6 | - Pre-trained RemeMo 👉 [[rememo-base](https://huggingface.co/DAMO-NLP-SG/rememo-base)] [[rememo-large](https://huggingface.co/DAMO-NLP-SG/rememo-large)] 7 | - Fine-tuned time expression extractor 👉 [[roberta-time_identification](https://huggingface.co/DAMO-NLP-SG/roberta-time_identification)] 8 | 9 | ## Table of Contents 10 | 11 | - [Overview](https://github.com/DAMO-NLP-SG/RemeMo#overview) 12 | - [Environment Setup](https://github.com/DAMO-NLP-SG/RemeMo#environment-setup) 13 | - [Usage](https://github.com/DAMO-NLP-SG/RemeMo#usage) 14 | - [Repo Structure](https://github.com/DAMO-NLP-SG/RemeMo#repo-structure) 15 | - [Checkpoints](https://github.com/DAMO-NLP-SG/RemeMo#checkpoints) 16 | - [Citation](https://github.com/DAMO-NLP-SG/RemeMo#citation) 17 | - [Acknowledgments](https://github.com/DAMO-NLP-SG/RemeMo#acknowledgments) 18 | 19 | ## Overview 20 | ![rememo_example](https://github.com/DAMO-NLP-SG/RemeMo/assets/18526640/6d1af421-11f7-4ded-9cbd-342316bd5c43) 21 | 22 | - **What Is RemeMo?** RemeMo is a T5-based language model that acquires stronger complex temporal reasoning abilities through pre-training with a novel time-relation classification (TRC) objective. 23 | As shown in the figure above, the time relation between every pair of facts serves as a TRC pre-training label, so the complex temporal dependencies among all facts are modeled as a fully-connected directed graph. 24 | 25 | - **When to Use RemeMo?** We recommend RemeMo as a drop-in replacement for T5 (or other seq2seq models) in downstream tasks that require complex temporal reasoning, e.g., temporal question answering. 26 | 27 | ## Environment Setup 28 | 29 | ``` 30 | conda create -n rememo python=3.8 31 | conda activate rememo 32 | 33 | git clone https://github.com/DAMO-NLP-SG/RemeMo.git 34 | cd RemeMo 35 | pip install torch==1.13.1+cu116 -f https://download.pytorch.org/whl/cu116/torch_stable.html 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | ## Usage 40 | 41 | ### Fine-tune 42 | 43 | 1. Customize the configurations in `src/odqa_t5/run_finetuning.sh`; 44 | 2. Run 45 | ``` 46 | mkdir src/odqa_t5/log 47 | sh src/odqa_t5/run_finetuning.sh 48 | ``` 49 | 50 | ### Time-identification 51 | 52 | 1. To run inference: 53 | 54 | see `src/time_expression/inference_time_identification.sh`. 55 | 56 | 2. To train a new model: 57 | 58 | see `src/time_expression/train_time_identification_model.sh`. 59 | 60 | 61 | ### Pre-train 62 | 63 | 1. Customize the configurations in `src/pretrain_t5/run_pretrain.sh`: 64 | - `NSP_MODE`: choices include { `mlm` (T5+LM), `mlm_trelation` (RemeMo) }. 65 | - See `data/pretrain` for examples of the pre-training data. Prepare your own pre-training data following the same format. 66 | - Modify other arguments if needed. 67 | 2. 
Run 68 | ``` 69 | mkdir src/pretrain_t5/log 70 | sh src/pretrain_t5/run_pretrain.sh 71 | ``` 72 | 73 | ## Repo Structure 74 | 75 | - Code: 76 | - `src/time_expression`: time-identification; 77 | - `src/preprocess`: analysis of the pre-processing pipeline; 78 | - `src/pretrain_t5`: pre-training code; 79 | - `src/odqa_t5`: fine-tuning code; 80 | - Data: 81 | - `data/time_expression/ner_task`: obtained by pre-processing TimeBank and used to train the roberta-time_identification model; 82 | - `data/pretrain`: examples of the pre-training data; 83 | - `data/finetune`: temporal question answering data; 84 | 85 | - Checkpoints: 86 | - `model_checkpoints/time_expression`: to store the roberta-time_identification model checkpoint; 87 | - `model_checkpoints/rememo_ckpt`: to store RemeMo-{base/large} model checkpoints; 88 | 89 | ## Checkpoints 90 | 91 | | Model | Size (# parameters) | 🤗 Link | 92 | |----------|----------|----------| 93 | | rememo-base | ~250M | [DAMO-NLP-SG/rememo-base](https://huggingface.co/DAMO-NLP-SG/rememo-base) | 94 | | rememo-large | ~800M | [DAMO-NLP-SG/rememo-large](https://huggingface.co/DAMO-NLP-SG/rememo-large) | 95 | 96 | 97 | ## Citation 98 | If you find our project useful, please consider starring our repo and citing our paper as follows: 99 | ``` 100 | @inproceedings{yang-etal-2023-once, 101 | title = "Once Upon a $\textit{Time}$ in $\textit{Graph}$: Relative-Time Pretraining for Complex Temporal Reasoning", 102 | author = "Yang, Sen and 103 | Li, Xin and 104 | Bing, Lidong and 105 | Lam, Wai", 106 | booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing", 107 | year = "2023", 108 | } 109 | ``` 110 | 111 | ## Acknowledgments 112 | 113 | This project uses code from: 114 | - [HuggingFace Transformers](https://github.com/huggingface/transformers/) 115 | - [timex-normaliser](https://github.com/filannim/timex-normaliser) 116 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.17.0 2 | apex==0.9.10dev 3 | bitsandbytes==0.41.1 4 | datasets==2.7.1 5 | DateTimeRange==1.2.0 6 | evaluate==0.3.0 7 | filelock==3.12.4 8 | huggingface_hub==0.12.0 9 | matplotlib==3.8.0 10 | nltk==3.8.1 11 | numpy==1.23.5 12 | # numpy==1.24.2 13 | pathos==0.3.0 14 | # torch==1.13.1+cu116 15 | torch_xla==2.1.0 16 | tqdm==4.64.1 17 | transformers==4.28.1 18 | -------------------------------------------------------------------------------- /src/odqa_t5/run_finetuning.sh: -------------------------------------------------------------------------------- 1 | export TRUNCATION_LEN=1500 # 1500, 1000 or 500 2 | export USE_SQUAD_V2=true # set to true if the dataset contains unanswerable questions, otherwise false 3 | 4 | export DATASET="tsqa_hard" 5 | export HALF_NUM_TRAIN_EXAMPLES=7340 # (number_of_training_instances / 2) 6 | 7 | ########################### Dataset Paths ########################### 8 | 9 | export DATASET_PATH=data/finetune/${DATASET} 10 | export TRAIN_FILE=${DATASET_PATH}/train.truncate-${TRUNCATION_LEN}.json 11 | export VAL_FILE=${DATASET_PATH}/dev.truncate-${TRUNCATION_LEN}.json 12 | export TEST_FILE=${DATASET_PATH}/test.truncate-${TRUNCATION_LEN}.json 13 | 14 | ########################### Model Path ########################### 15 | 16 | export MODEL="rememo-base" 17 | export MODEL_PATH=DAMO-NLP-SG/${MODEL} 18 | 19 | # export MODEL="t5-v1_1-base" 20 | # export MODEL_PATH=google/${MODEL} 21 | 22 | 
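# To fine-tune from a locally saved checkpoint instead of the Hub, point MODEL_PATH at the
# checkpoint directory (run_seq2seq_qa.py accepts either a Hub id or a local path).
# The directory below is only a hypothetical example; the README reserves
# model_checkpoints/rememo_ckpt for downloaded RemeMo checkpoints.
# export MODEL="rememo-large"
# export MODEL_PATH=model_checkpoints/rememo_ckpt/${MODEL}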
########################### wandb ########################### 23 | export WANDB_PROJECT=RemeMo-finetune 24 | 25 | # {'wandb', 'none'} 26 | export REPORT_TO=none 27 | 28 | ########################### GPU, Batch-Size, Learning-Rate ########################### 29 | for BATCH_SIZE_PER_GPU in 8 16 30 | do 31 | for LR in 3e-5 1e-4 3e-4 32 | do 33 | 34 | export NUM_TRAIN_EPOCHS=10 35 | 36 | export GRADIENT_ACCUMULATION_STEPS=2 37 | 38 | export EVAL_STEPS=$((HALF_NUM_TRAIN_EXAMPLES / BATCH_SIZE_PER_GPU)) 39 | 40 | export FINAL_BATCH_SIZE=$((BATCH_SIZE_PER_GPU * GRADIENT_ACCUMULATION_STEPS)) 41 | 42 | ####################################################################################### 43 | 44 | export RUN_NAME=${MODEL}.${DATASET}.truncate-${TRUNCATION_LEN}.lr-${LR}.bsz-${FINAL_BATCH_SIZE}.epoch-${NUM_TRAIN_EPOCHS} 45 | 46 | python run_seq2seq_qa.py \ 47 | --model_name_or_path ${MODEL_PATH} \ 48 | --train_file ${TRAIN_FILE} \ 49 | --validation_file ${VAL_FILE} \ 50 | --test_file ${TEST_FILE} \ 51 | --context_column context \ 52 | --question_column question \ 53 | --answer_column answers \ 54 | --do_train \ 55 | --do_eval \ 56 | --do_predict \ 57 | --per_device_train_batch_size ${BATCH_SIZE_PER_GPU} \ 58 | --per_device_eval_batch_size 64 \ 59 | --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \ 60 | --learning_rate ${LR} \ 61 | --num_train_epochs ${NUM_TRAIN_EPOCHS} \ 62 | --max_seq_length 512 \ 63 | --output_dir log/${RUN_NAME} \ 64 | --report_to ${REPORT_TO} \ 65 | --logging_steps 10 \ 66 | --evaluation_strategy steps \ 67 | --eval_steps ${EVAL_STEPS} \ 68 | --save_steps ${EVAL_STEPS} \ 69 | --logging_first_step true \ 70 | --load_best_model_at_end true \ 71 | --metric_for_best_model eval_f1 \ 72 | --greater_is_better true \ 73 | --save_total_limit 2 \ 74 | --overwrite_output_dir true \ 75 | --remove_unused_columns true \ 76 | --predict_with_generate \ 77 | --run_name ${RUN_NAME} \ 78 | --version_2_with_negative ${USE_SQUAD_V2} \ 79 | 2>&1 | tee log/${RUN_NAME}.txt 80 | 81 | done 82 | done 83 | 84 | 85 | -------------------------------------------------------------------------------- /src/odqa_t5/run_seq2seq_qa.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library's seq2seq models for question answering using the 🤗 Seq2SeqTrainer. 18 | """ 19 | # You can also adapt this script on your own question answering task. Pointers for this are left as comments. 
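# In this repo the script is launched by src/odqa_t5/run_finetuning.sh. The --train_file / --validation_file /
# --test_file arguments expect SQuAD-style JSON whose examples live under a top-level "data" field
# (note the `field="data"` passed to `load_dataset` below), with "id", "question", "context" and
# "answers" columns.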
20 | 21 | import logging 22 | import os 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import List, Optional, Tuple 26 | 27 | import datasets 28 | import evaluate 29 | import numpy as np 30 | from datasets import load_dataset 31 | from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer 32 | 33 | import transformers 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModelForSeq2SeqLM, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | Seq2SeqTrainingArguments, 41 | set_seed, 42 | ) 43 | from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, get_last_checkpoint 44 | from transformers.utils import check_min_version, send_example_telemetry 45 | from transformers.utils.versions import require_version 46 | 47 | 48 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 49 | # check_min_version("4.29.0.dev0") 50 | 51 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") 52 | 53 | logger = logging.getLogger(__name__) 54 | 55 | 56 | @dataclass 57 | class ModelArguments: 58 | """ 59 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 60 | """ 61 | 62 | model_name_or_path: str = field( 63 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 64 | ) 65 | config_name: Optional[str] = field( 66 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 67 | ) 68 | tokenizer_name: Optional[str] = field( 69 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 70 | ) 71 | cache_dir: Optional[str] = field( 72 | default=None, 73 | metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"}, 74 | ) 75 | use_fast_tokenizer: bool = field( 76 | default=True, 77 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 78 | ) 79 | model_revision: str = field( 80 | default="main", 81 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 82 | ) 83 | use_auth_token: bool = field( 84 | default=False, 85 | metadata={ 86 | "help": ( 87 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 88 | "with private models)." 89 | ) 90 | }, 91 | ) 92 | 93 | 94 | @dataclass 95 | class DataTrainingArguments: 96 | """ 97 | Arguments pertaining to what data we are going to input our model for training and eval. 
98 | """ 99 | 100 | dataset_name: Optional[str] = field( 101 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 102 | ) 103 | dataset_config_name: Optional[str] = field( 104 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 105 | ) 106 | context_column: Optional[str] = field( 107 | default="context", 108 | metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."}, 109 | ) 110 | question_column: Optional[str] = field( 111 | default="question", 112 | metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, 113 | ) 114 | answer_column: Optional[str] = field( 115 | default="answers", 116 | metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, 117 | ) 118 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 119 | validation_file: Optional[str] = field( 120 | default=None, 121 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 122 | ) 123 | test_file: Optional[str] = field( 124 | default=None, 125 | metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."}, 126 | ) 127 | overwrite_cache: bool = field( 128 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 129 | ) 130 | preprocessing_num_workers: Optional[int] = field( 131 | default=None, 132 | metadata={"help": "The number of processes to use for the preprocessing."}, 133 | ) 134 | max_seq_length: int = field( 135 | default=384, 136 | metadata={ 137 | "help": ( 138 | "The maximum total input sequence length after tokenization. Sequences longer " 139 | "than this will be truncated, sequences shorter will be padded." 140 | ) 141 | }, 142 | ) 143 | max_answer_length: int = field( 144 | default=30, 145 | metadata={ 146 | "help": ( 147 | "The maximum length of an answer that can be generated. This is needed because the start " 148 | "and end predictions are not conditioned on one another." 149 | ) 150 | }, 151 | ) 152 | val_max_answer_length: Optional[int] = field( 153 | default=None, 154 | metadata={ 155 | "help": ( 156 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 157 | "than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`." 158 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 159 | "during ``evaluate`` and ``predict``." 160 | ) 161 | }, 162 | ) 163 | pad_to_max_length: bool = field( 164 | default=True, 165 | metadata={ 166 | "help": ( 167 | "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when" 168 | " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)." 169 | ) 170 | }, 171 | ) 172 | max_train_samples: Optional[int] = field( 173 | default=None, 174 | metadata={ 175 | "help": ( 176 | "For debugging purposes or quicker training, truncate the number of training examples to this " 177 | "value if set." 
178 | ) 179 | }, 180 | ) 181 | max_eval_samples: Optional[int] = field( 182 | default=None, 183 | metadata={ 184 | "help": ( 185 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 186 | "value if set." 187 | ) 188 | }, 189 | ) 190 | max_predict_samples: Optional[int] = field( 191 | default=None, 192 | metadata={ 193 | "help": ( 194 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 195 | "value if set." 196 | ) 197 | }, 198 | ) 199 | version_2_with_negative: bool = field( 200 | default=False, metadata={"help": "If true, some of the examples do not have an answer."} 201 | ) 202 | null_score_diff_threshold: float = field( 203 | default=0.0, 204 | metadata={ 205 | "help": ( 206 | "The threshold used to select the null answer: if the best answer has a score that is less than " 207 | "the score of the null answer minus this threshold, the null answer is selected for this example. " 208 | "Only useful when `version_2_with_negative=True`." 209 | ) 210 | }, 211 | ) 212 | doc_stride: int = field( 213 | default=128, 214 | metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."}, 215 | ) 216 | n_best_size: int = field( 217 | default=20, 218 | metadata={"help": "The total number of n-best predictions to generate when looking for an answer."}, 219 | ) 220 | num_beams: Optional[int] = field( 221 | default=None, 222 | metadata={ 223 | "help": ( 224 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 225 | "which is used during ``evaluate`` and ``predict``." 226 | ) 227 | }, 228 | ) 229 | ignore_pad_token_for_loss: bool = field( 230 | default=True, 231 | metadata={ 232 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 233 | }, 234 | ) 235 | 236 | def __post_init__(self): 237 | if ( 238 | self.dataset_name is None 239 | and self.train_file is None 240 | and self.validation_file is None 241 | and self.test_file is None 242 | ): 243 | raise ValueError("Need either a dataset name or a training/validation file/test_file.") 244 | else: 245 | if self.train_file is not None: 246 | extension = self.train_file.split(".")[-1] 247 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 248 | if self.validation_file is not None: 249 | extension = self.validation_file.split(".")[-1] 250 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 251 | if self.test_file is not None: 252 | extension = self.test_file.split(".")[-1] 253 | assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." 254 | if self.val_max_answer_length is None: 255 | self.val_max_answer_length = self.max_answer_length 256 | 257 | 258 | question_answering_column_name_mapping = { 259 | "squad_v2": ("question", "context", "answer"), 260 | } 261 | 262 | 263 | def main(): 264 | # See all possible arguments in src/transformers/training_args.py 265 | # or by passing the --help flag to this script. 266 | # We now keep distinct sets of args, for a cleaner separation of concerns. 267 | 268 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 269 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 270 | # If we pass only one argument to the script and it's the path to a json file, 271 | # let's parse it to get our arguments. 
272 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 273 | else: 274 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 275 | 276 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 277 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 278 | send_example_telemetry("run_seq2seq_qa", model_args, data_args) 279 | 280 | # Setup logging 281 | logging.basicConfig( 282 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 283 | datefmt="%m/%d/%Y %H:%M:%S", 284 | handlers=[logging.StreamHandler(sys.stdout)], 285 | ) 286 | 287 | if training_args.should_log: 288 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 289 | transformers.utils.logging.set_verbosity_info() 290 | 291 | log_level = training_args.get_process_log_level() 292 | logger.setLevel(log_level) 293 | datasets.utils.logging.set_verbosity(log_level) 294 | transformers.utils.logging.set_verbosity(log_level) 295 | transformers.utils.logging.enable_default_handler() 296 | transformers.utils.logging.enable_explicit_format() 297 | 298 | # Log on each process the small summary: 299 | logger.warning( 300 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 301 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 302 | ) 303 | logger.info(f"Training/evaluation parameters {training_args}") 304 | 305 | # Detecting last checkpoint. 306 | last_checkpoint = None 307 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 308 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 309 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 310 | raise ValueError( 311 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 312 | "Use --overwrite_output_dir to overcome." 313 | ) 314 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 315 | logger.info( 316 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 317 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 318 | ) 319 | 320 | # Set seed before initializing model. 321 | set_seed(training_args.seed) 322 | 323 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 324 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 325 | # (the dataset will be downloaded automatically from the datasets Hub). 326 | # 327 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 328 | # 'text' is found. You can easily tweak this behavior (see below). 329 | # 330 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 331 | # download the dataset. 332 | if data_args.dataset_name is not None: 333 | # Downloading and loading a dataset from the hub. 
334 | raw_datasets = load_dataset( 335 | data_args.dataset_name, 336 | data_args.dataset_config_name, 337 | cache_dir=model_args.cache_dir, 338 | use_auth_token=True if model_args.use_auth_token else None, 339 | ) 340 | else: 341 | data_files = {} 342 | if data_args.train_file is not None: 343 | data_files["train"] = data_args.train_file 344 | extension = data_args.train_file.split(".")[-1] 345 | if data_args.validation_file is not None: 346 | data_files["validation"] = data_args.validation_file 347 | extension = data_args.validation_file.split(".")[-1] 348 | if data_args.test_file is not None: 349 | data_files["test"] = data_args.test_file 350 | extension = data_args.test_file.split(".")[-1] 351 | raw_datasets = load_dataset( 352 | extension, 353 | data_files=data_files, 354 | field="data", 355 | cache_dir=model_args.cache_dir, 356 | use_auth_token=True if model_args.use_auth_token else None, 357 | ) 358 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 359 | # https://huggingface.co/docs/datasets/loading_datasets.html. 360 | 361 | # Load pretrained model and tokenizer 362 | # 363 | # Distributed training: 364 | # The .from_pretrained methods guarantee that only one local process can concurrently 365 | # download model & vocab. 366 | config = AutoConfig.from_pretrained( 367 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 368 | cache_dir=model_args.cache_dir, 369 | revision=model_args.model_revision, 370 | use_auth_token=True if model_args.use_auth_token else None, 371 | ) 372 | tokenizer = AutoTokenizer.from_pretrained( 373 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 374 | cache_dir=model_args.cache_dir, 375 | use_fast=model_args.use_fast_tokenizer, 376 | revision=model_args.model_revision, 377 | use_auth_token=True if model_args.use_auth_token else None, 378 | ) 379 | model = AutoModelForSeq2SeqLM.from_pretrained( 380 | model_args.model_name_or_path, 381 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 382 | config=config, 383 | cache_dir=model_args.cache_dir, 384 | revision=model_args.model_revision, 385 | use_auth_token=True if model_args.use_auth_token else None, 386 | ) 387 | 388 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 389 | # on a small vocab and want a smaller embedding size, remove this test. 390 | embedding_size = model.get_input_embeddings().weight.shape[0] 391 | if len(tokenizer) > embedding_size: 392 | model.resize_token_embeddings(len(tokenizer)) 393 | 394 | if model.config.decoder_start_token_id is None: 395 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 396 | 397 | # Preprocessing the datasets. 398 | # We need to generate and tokenize inputs and targets. 399 | if training_args.do_train: 400 | column_names = raw_datasets["train"].column_names 401 | elif training_args.do_eval: 402 | column_names = raw_datasets["validation"].column_names 403 | elif training_args.do_predict: 404 | column_names = raw_datasets["test"].column_names 405 | else: 406 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 407 | return 408 | 409 | # Get the column names for input/target. 
410 | dataset_columns = question_answering_column_name_mapping.get(data_args.dataset_name, None) 411 | if data_args.question_column is None: 412 | question_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 413 | else: 414 | question_column = data_args.question_column 415 | if question_column not in column_names: 416 | raise ValueError( 417 | f"--question_column' value '{data_args.question_column}' needs to be one of: {', '.join(column_names)}" 418 | ) 419 | if data_args.context_column is None: 420 | context_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 421 | else: 422 | context_column = data_args.context_column 423 | if context_column not in column_names: 424 | raise ValueError( 425 | f"--context_column' value '{data_args.context_column}' needs to be one of: {', '.join(column_names)}" 426 | ) 427 | if data_args.answer_column is None: 428 | answer_column = dataset_columns[2] if dataset_columns is not None else column_names[2] 429 | else: 430 | answer_column = data_args.answer_column 431 | if answer_column not in column_names: 432 | raise ValueError( 433 | f"--answer_column' value '{data_args.answer_column}' needs to be one of: {', '.join(column_names)}" 434 | ) 435 | 436 | # Temporarily set max_answer_length for training. 437 | max_answer_length = data_args.max_answer_length 438 | padding = "max_length" if data_args.pad_to_max_length else False 439 | 440 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 441 | logger.warning( 442 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 443 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 444 | ) 445 | 446 | if data_args.max_seq_length > tokenizer.model_max_length: 447 | logger.warning( 448 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 449 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 450 | ) 451 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 452 | 453 | def preprocess_squad_batch( 454 | examples, 455 | question_column: str, 456 | context_column: str, 457 | answer_column: str, 458 | ) -> Tuple[List[str], List[str]]: 459 | questions = examples[question_column] 460 | contexts = examples[context_column] 461 | answers = examples[answer_column] 462 | 463 | def generate_input(_question, _context): 464 | return " ".join(["question:", _question.lstrip(), "context:", _context.lstrip()]) 465 | 466 | inputs = [generate_input(question, context) for question, context in zip(questions, contexts)] 467 | targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] 468 | return inputs, targets 469 | 470 | def preprocess_function(examples): 471 | inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column) 472 | 473 | model_inputs = tokenizer(inputs, max_length=max_seq_length, padding=padding, truncation=True) 474 | # Tokenize targets with text_target=... 475 | labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True) 476 | 477 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 478 | # padding in the loss. 
479 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 480 | labels["input_ids"] = [ 481 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 482 | ] 483 | 484 | model_inputs["labels"] = labels["input_ids"] 485 | return model_inputs 486 | 487 | # Validation preprocessing 488 | def preprocess_validation_function(examples): 489 | inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column) 490 | 491 | model_inputs = tokenizer( 492 | inputs, 493 | max_length=max_seq_length, 494 | padding=padding, 495 | truncation=True, 496 | return_overflowing_tokens=True, 497 | return_offsets_mapping=True, 498 | ) 499 | # Tokenize targets with the `text_target` keyword argument 500 | labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True) 501 | 502 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 503 | # padding in the loss. 504 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 505 | labels["input_ids"] = [ 506 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 507 | ] 508 | 509 | # Since one example might give us several features if it has a long context, we need a map from a feature to 510 | # its corresponding example. This key gives us just that. 511 | sample_mapping = model_inputs.pop("overflow_to_sample_mapping") 512 | 513 | # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the 514 | # corresponding example_id and we will store the offset mappings. 515 | model_inputs["example_id"] = [] 516 | # Augment the overflowing tokens to the labels 517 | labels_out = [] 518 | 519 | for i in range(len(model_inputs["input_ids"])): 520 | # One example can give several spans, this is the index of the example containing this span of text. 
521 | sample_index = sample_mapping[i] 522 | model_inputs["example_id"].append(examples["id"][sample_index]) 523 | labels_out.append(labels["input_ids"][sample_index]) 524 | 525 | model_inputs["labels"] = labels_out 526 | return model_inputs 527 | 528 | if training_args.do_train: 529 | if "train" not in raw_datasets: 530 | raise ValueError("--do_train requires a train dataset") 531 | train_dataset = raw_datasets["train"] 532 | if data_args.max_train_samples is not None: 533 | # We will select sample from whole data if agument is specified 534 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 535 | train_dataset = train_dataset.select(range(max_train_samples)) 536 | # Create train feature from dataset 537 | with training_args.main_process_first(desc="train dataset map pre-processing"): 538 | train_dataset = train_dataset.map( 539 | preprocess_function, 540 | batched=True, 541 | num_proc=data_args.preprocessing_num_workers, 542 | remove_columns=column_names, 543 | load_from_cache_file=not data_args.overwrite_cache, 544 | desc="Running tokenizer on train dataset", 545 | ) 546 | if data_args.max_train_samples is not None: 547 | # Number of samples might increase during Feature Creation, We select only specified max samples 548 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 549 | train_dataset = train_dataset.select(range(max_train_samples)) 550 | 551 | if training_args.do_eval: 552 | if "validation" not in raw_datasets: 553 | raise ValueError("--do_eval requires a validation dataset") 554 | eval_examples = raw_datasets["validation"] 555 | if data_args.max_eval_samples is not None: 556 | # We will select sample from whole data 557 | max_eval_samples = min(len(eval_examples), data_args.max_eval_samples) 558 | eval_examples = eval_examples.select(range(max_eval_samples)) 559 | # Validation Feature Creation 560 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 561 | eval_dataset = eval_examples.map( 562 | preprocess_validation_function, 563 | batched=True, 564 | num_proc=data_args.preprocessing_num_workers, 565 | remove_columns=column_names, 566 | load_from_cache_file=not data_args.overwrite_cache, 567 | desc="Running tokenizer on validation dataset", 568 | ) 569 | if data_args.max_eval_samples is not None: 570 | # During Feature creation dataset samples might increase, we will select required samples again 571 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 572 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 573 | 574 | if training_args.do_predict: 575 | if "test" not in raw_datasets: 576 | raise ValueError("--do_predict requires a test dataset") 577 | predict_examples = raw_datasets["test"] 578 | if data_args.max_predict_samples is not None: 579 | # We will select sample from whole data 580 | predict_examples = predict_examples.select(range(data_args.max_predict_samples)) 581 | # Predict Feature Creation 582 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 583 | predict_dataset = predict_examples.map( 584 | preprocess_validation_function, 585 | batched=True, 586 | num_proc=data_args.preprocessing_num_workers, 587 | remove_columns=column_names, 588 | load_from_cache_file=not data_args.overwrite_cache, 589 | desc="Running tokenizer on prediction dataset", 590 | ) 591 | if data_args.max_predict_samples is not None: 592 | # During Feature creation dataset samples might increase, we will select required samples again 593 | 
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 594 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 595 | 596 | # Data collator 597 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 598 | data_collator = DataCollatorForSeq2Seq( 599 | tokenizer, 600 | model=model, 601 | label_pad_token_id=label_pad_token_id, 602 | pad_to_multiple_of=8 if training_args.fp16 else None, 603 | ) 604 | 605 | metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad") 606 | 607 | def compute_metrics(p: EvalPrediction): 608 | return metric.compute(predictions=p.predictions, references=p.label_ids) 609 | 610 | # Post-processing: 611 | def post_processing_function( 612 | examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval" 613 | ): 614 | # Decode the predicted tokens. 615 | preds = outputs.predictions 616 | if isinstance(preds, tuple): 617 | preds = preds[0] 618 | # Replace -100s used for padding as we can't decode them 619 | preds = np.where(preds != -100, preds, tokenizer.pad_token_id) 620 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 621 | 622 | # Build a map example to its corresponding features. 623 | example_id_to_index = {k: i for i, k in enumerate(examples["id"])} 624 | feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)} 625 | predictions = {} 626 | # Let's loop over all the examples! 627 | for example_index, example in enumerate(examples): 628 | # This is the index of the feature associated to the current example. 629 | feature_index = feature_per_example[example_index] 630 | predictions[example["id"]] = decoded_preds[feature_index] 631 | 632 | # Format the result to the format the metric expects. 
633 | if data_args.version_2_with_negative: 634 | formatted_predictions = [ 635 | {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() 636 | ] 637 | else: 638 | formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] 639 | 640 | references = [{"id": ex["id"], "answers": ex[answer_column]} for ex in examples] 641 | return EvalPrediction(predictions=formatted_predictions, label_ids=references) 642 | 643 | # Initialize our Trainer 644 | trainer = QuestionAnsweringSeq2SeqTrainer( 645 | model=model, 646 | args=training_args, 647 | train_dataset=train_dataset if training_args.do_train else None, 648 | eval_dataset=eval_dataset if training_args.do_eval else None, 649 | eval_examples=eval_examples if training_args.do_eval else None, 650 | tokenizer=tokenizer, 651 | data_collator=data_collator, 652 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 653 | post_process_function=post_processing_function, 654 | ) 655 | 656 | # Training 657 | if training_args.do_train: 658 | checkpoint = None 659 | if training_args.resume_from_checkpoint is not None: 660 | checkpoint = training_args.resume_from_checkpoint 661 | elif last_checkpoint is not None: 662 | checkpoint = last_checkpoint 663 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 664 | trainer.save_model() # Saves the tokenizer too for easy upload 665 | 666 | metrics = train_result.metrics 667 | max_train_samples = ( 668 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 669 | ) 670 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 671 | 672 | trainer.log_metrics("train", metrics) 673 | trainer.save_metrics("train", metrics) 674 | trainer.save_state() 675 | 676 | # Evaluation 677 | results = {} 678 | max_length = ( 679 | training_args.generation_max_length 680 | if training_args.generation_max_length is not None 681 | else data_args.val_max_answer_length 682 | ) 683 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 684 | if training_args.do_eval: 685 | logger.info("*** Evaluate ***") 686 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 687 | 688 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 689 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 690 | 691 | trainer.log_metrics("eval", metrics) 692 | trainer.save_metrics("eval", metrics) 693 | 694 | # Prediction 695 | if training_args.do_predict: 696 | logger.info("*** Predict ***") 697 | results = trainer.predict(predict_dataset, predict_examples) 698 | metrics = results.metrics 699 | 700 | max_predict_samples = ( 701 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 702 | ) 703 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 704 | 705 | trainer.log_metrics("predict", metrics) 706 | trainer.save_metrics("predict", metrics) 707 | 708 | if training_args.push_to_hub: 709 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"} 710 | if data_args.dataset_name is not None: 711 | kwargs["dataset_tags"] = data_args.dataset_name 712 | if data_args.dataset_config_name is not None: 713 | kwargs["dataset_args"] = data_args.dataset_config_name 714 | kwargs["dataset"] = f"{data_args.dataset_name} 
{data_args.dataset_config_name}" 715 | else: 716 | kwargs["dataset"] = data_args.dataset_name 717 | 718 | trainer.push_to_hub(**kwargs) 719 | 720 | 721 | def _mp_fn(index): 722 | # For xla_spawn (TPUs) 723 | main() 724 | 725 | 726 | if __name__ == "__main__": 727 | main() 728 | -------------------------------------------------------------------------------- /src/odqa_t5/squad/compute_score.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. """ 2 | 3 | import argparse 4 | import json 5 | import re 6 | import string 7 | import sys 8 | from collections import Counter 9 | 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | 14 | def remove_articles(text): 15 | return re.sub(r"\b(a|an|the)\b", " ", text) 16 | 17 | def white_space_fix(text): 18 | return " ".join(text.split()) 19 | 20 | def remove_punc(text): 21 | exclude = set(string.punctuation) 22 | return "".join(ch for ch in text if ch not in exclude) 23 | 24 | def lower(text): 25 | return text.lower() 26 | 27 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 28 | 29 | 30 | def f1_score(prediction, ground_truth): 31 | prediction_tokens = normalize_answer(prediction).split() 32 | ground_truth_tokens = normalize_answer(ground_truth).split() 33 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 34 | num_same = sum(common.values()) 35 | if num_same == 0: 36 | return 0 37 | precision = 1.0 * num_same / len(prediction_tokens) 38 | recall = 1.0 * num_same / len(ground_truth_tokens) 39 | f1 = (2 * precision * recall) / (precision + recall) 40 | return f1 41 | 42 | 43 | def exact_match_score(prediction, ground_truth): 44 | return normalize_answer(prediction) == normalize_answer(ground_truth) 45 | 46 | 47 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 48 | scores_for_ground_truths = [] 49 | for ground_truth in ground_truths: 50 | score = metric_fn(prediction, ground_truth) 51 | scores_for_ground_truths.append(score) 52 | return max(scores_for_ground_truths) 53 | 54 | 55 | def compute_score(dataset, predictions): 56 | f1 = exact_match = total = 0 57 | for article in dataset: 58 | for paragraph in article["paragraphs"]: 59 | for qa in paragraph["qas"]: 60 | total += 1 61 | if qa["id"] not in predictions: 62 | message = "Unanswered question " + qa["id"] + " will receive score 0." 
63 | print(message, file=sys.stderr) 64 | continue 65 | ground_truths = list(map(lambda x: x["text"], qa["answers"])) 66 | prediction = predictions[qa["id"]] 67 | exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) 68 | f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths) 69 | 70 | exact_match = 100.0 * exact_match / total 71 | f1 = 100.0 * f1 / total 72 | 73 | return {"exact_match": exact_match, "f1": f1} 74 | 75 | 76 | if __name__ == "__main__": 77 | expected_version = "1.1" 78 | parser = argparse.ArgumentParser(description="Evaluation for SQuAD " + expected_version) 79 | parser.add_argument("dataset_file", help="Dataset file") 80 | parser.add_argument("prediction_file", help="Prediction File") 81 | args = parser.parse_args() 82 | with open(args.dataset_file) as dataset_file: 83 | dataset_json = json.load(dataset_file) 84 | if dataset_json["version"] != expected_version: 85 | print( 86 | "Evaluation expects v-" + expected_version + ", but got dataset with v-" + dataset_json["version"], 87 | file=sys.stderr, 88 | ) 89 | dataset = dataset_json["data"] 90 | with open(args.prediction_file) as prediction_file: 91 | predictions = json.load(prediction_file) 92 | print(json.dumps(compute_score(dataset, predictions))) 93 | -------------------------------------------------------------------------------- /src/odqa_t5/squad/squad.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Evaluate Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ SQuAD metric. """ 15 | 16 | import datasets 17 | 18 | import evaluate 19 | 20 | from .compute_score import compute_score 21 | 22 | 23 | _CITATION = """\ 24 | @inproceedings{Rajpurkar2016SQuAD10, 25 | title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text}, 26 | author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang}, 27 | booktitle={EMNLP}, 28 | year={2016} 29 | } 30 | """ 31 | 32 | _DESCRIPTION = """ 33 | This metric wrap the official scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD). 34 | 35 | Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by 36 | crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, 37 | from the corresponding reading passage, or the question might be unanswerable. 38 | """ 39 | 40 | _KWARGS_DESCRIPTION = """ 41 | Computes SQuAD scores (F1 and EM). 
42 | Args: 43 | predictions: List of question-answers dictionaries with the following key-values: 44 | - 'id': id of the question-answer pair as given in the references (see below) 45 | - 'prediction_text': the text of the answer 46 | references: List of question-answers dictionaries with the following key-values: 47 | - 'id': id of the question-answer pair (see above), 48 | - 'answers': a Dict in the SQuAD dataset format 49 | { 50 | 'text': list of possible texts for the answer, as a list of strings 51 | 'answer_start': list of start positions for the answer, as a list of ints 52 | } 53 | Note that answer_start values are not taken into account to compute the metric. 54 | Returns: 55 | 'exact_match': Exact match (the normalized answer exactly match the gold answer) 56 | 'f1': The F-score of predicted tokens versus the gold answer 57 | Examples: 58 | 59 | >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}] 60 | >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}] 61 | >>> squad_metric = evaluate.load("squad") 62 | >>> results = squad_metric.compute(predictions=predictions, references=references) 63 | >>> print(results) 64 | {'exact_match': 100.0, 'f1': 100.0} 65 | """ 66 | 67 | 68 | @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 69 | class Squad(evaluate.Metric): 70 | def _info(self): 71 | return evaluate.MetricInfo( 72 | description=_DESCRIPTION, 73 | citation=_CITATION, 74 | inputs_description=_KWARGS_DESCRIPTION, 75 | features=datasets.Features( 76 | { 77 | "predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")}, 78 | "references": { 79 | "id": datasets.Value("string"), 80 | "answers": datasets.features.Sequence( 81 | { 82 | "text": datasets.Value("string"), 83 | "answer_start": datasets.Value("int32"), 84 | } 85 | ), 86 | }, 87 | } 88 | ), 89 | codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], 90 | reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], 91 | ) 92 | 93 | def _compute(self, predictions, references): 94 | pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions} 95 | dataset = [ 96 | { 97 | "paragraphs": [ 98 | { 99 | "qas": [ 100 | { 101 | "answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]], 102 | "id": ref["id"], 103 | } 104 | for ref in references 105 | ] 106 | } 107 | ] 108 | } 109 | ] 110 | score = compute_score(dataset=dataset, predictions=pred_dict) 111 | return score 112 | -------------------------------------------------------------------------------- /src/odqa_t5/squad_v2/compute_score.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for SQuAD version 2.0. 2 | 3 | In addition to basic functionality, we also compute additional statistics and 4 | plot precision-recall curves if an additional na_prob.json file is provided. 5 | This file is expected to map question ID's to the model's predicted probability 6 | that a question is unanswerable. 
7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import os 12 | import re 13 | import string 14 | import sys 15 | 16 | import numpy as np 17 | 18 | 19 | ARTICLES_REGEX = re.compile(r"\b(a|an|the)\b", re.UNICODE) 20 | 21 | OPTS = None 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser("Official evaluation script for SQuAD version 2.0.") 26 | parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.") 27 | parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.") 28 | parser.add_argument( 29 | "--out-file", "-o", metavar="eval.json", help="Write accuracy metrics to file (default is stdout)." 30 | ) 31 | parser.add_argument( 32 | "--na-prob-file", "-n", metavar="na_prob.json", help="Model estimates of probability of no answer." 33 | ) 34 | parser.add_argument( 35 | "--na-prob-thresh", 36 | "-t", 37 | type=float, 38 | default=1.0, 39 | help='Predict "" if no-answer probability exceeds this (default = 1.0).', 40 | ) 41 | parser.add_argument( 42 | "--out-image-dir", "-p", metavar="out_images", default=None, help="Save precision-recall curves to directory." 43 | ) 44 | parser.add_argument("--verbose", "-v", action="store_true") 45 | if len(sys.argv) == 1: 46 | parser.print_help() 47 | sys.exit(1) 48 | return parser.parse_args() 49 | 50 | 51 | def make_qid_to_has_ans(dataset): 52 | qid_to_has_ans = {} 53 | for article in dataset: 54 | for p in article["paragraphs"]: 55 | for qa in p["qas"]: 56 | qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"]) 57 | return qid_to_has_ans 58 | 59 | 60 | def normalize_answer(s): 61 | """Lower text and remove punctuation, articles and extra whitespace.""" 62 | 63 | def remove_articles(text): 64 | return ARTICLES_REGEX.sub(" ", text) 65 | 66 | def white_space_fix(text): 67 | return " ".join(text.split()) 68 | 69 | def remove_punc(text): 70 | exclude = set(string.punctuation) 71 | return "".join(ch for ch in text if ch not in exclude) 72 | 73 | def lower(text): 74 | return text.lower() 75 | 76 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 77 | 78 | 79 | def get_tokens(s): 80 | if not s: 81 | return [] 82 | return normalize_answer(s).split() 83 | 84 | 85 | def compute_exact(a_gold, a_pred): 86 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 87 | 88 | 89 | def compute_f1(a_gold, a_pred): 90 | gold_toks = get_tokens(a_gold) 91 | pred_toks = get_tokens(a_pred) 92 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 93 | num_same = sum(common.values()) 94 | if len(gold_toks) == 0 or len(pred_toks) == 0: 95 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 96 | return int(gold_toks == pred_toks) 97 | if num_same == 0: 98 | return 0 99 | precision = 1.0 * num_same / len(pred_toks) 100 | recall = 1.0 * num_same / len(gold_toks) 101 | f1 = (2 * precision * recall) / (precision + recall) 102 | return f1 103 | 104 | 105 | def get_raw_scores(dataset, preds): 106 | exact_scores = {} 107 | f1_scores = {} 108 | for article in dataset: 109 | for p in article["paragraphs"]: 110 | for qa in p["qas"]: 111 | qid = qa["id"] 112 | gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)] 113 | if not gold_answers: 114 | # For unanswerable questions, only correct answer is empty string 115 | gold_answers = [""] 116 | if qid not in preds: 117 | print(f"Missing prediction for {qid}") 118 | continue 119 | a_pred = preds[qid] 120 | # Take max over all gold answers 121 | exact_scores[qid] = max(compute_exact(a, 
a_pred) for a in gold_answers) 122 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 123 | return exact_scores, f1_scores 124 | 125 | 126 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 127 | new_scores = {} 128 | for qid, s in scores.items(): 129 | pred_na = na_probs[qid] > na_prob_thresh 130 | if pred_na: 131 | new_scores[qid] = float(not qid_to_has_ans[qid]) 132 | else: 133 | new_scores[qid] = s 134 | return new_scores 135 | 136 | 137 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 138 | if not qid_list: 139 | total = len(exact_scores) 140 | return collections.OrderedDict( 141 | [ 142 | ("exact", 100.0 * sum(exact_scores.values()) / total), 143 | ("f1", 100.0 * sum(f1_scores.values()) / total), 144 | ("total", total), 145 | ] 146 | ) 147 | else: 148 | total = len(qid_list) 149 | return collections.OrderedDict( 150 | [ 151 | ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total), 152 | ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total), 153 | ("total", total), 154 | ] 155 | ) 156 | 157 | 158 | def merge_eval(main_eval, new_eval, prefix): 159 | for k in new_eval: 160 | main_eval[f"{prefix}_{k}"] = new_eval[k] 161 | 162 | 163 | def plot_pr_curve(precisions, recalls, out_image, title): 164 | plt.step(recalls, precisions, color="b", alpha=0.2, where="post") 165 | plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b") 166 | plt.xlabel("Recall") 167 | plt.ylabel("Precision") 168 | plt.xlim([0.0, 1.05]) 169 | plt.ylim([0.0, 1.05]) 170 | plt.title(title) 171 | plt.savefig(out_image) 172 | plt.clf() 173 | 174 | 175 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None): 176 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 177 | true_pos = 0.0 178 | cur_p = 1.0 179 | cur_r = 0.0 180 | precisions = [1.0] 181 | recalls = [0.0] 182 | avg_prec = 0.0 183 | for i, qid in enumerate(qid_list): 184 | if qid_to_has_ans[qid]: 185 | true_pos += scores[qid] 186 | cur_p = true_pos / float(i + 1) 187 | cur_r = true_pos / float(num_true_pos) 188 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]: 189 | # i.e., if we can put a threshold after this point 190 | avg_prec += cur_p * (cur_r - recalls[-1]) 191 | precisions.append(cur_p) 192 | recalls.append(cur_r) 193 | if out_image: 194 | plot_pr_curve(precisions, recalls, out_image, title) 195 | return {"ap": 100.0 * avg_prec} 196 | 197 | 198 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir): 199 | if out_image_dir and not os.path.exists(out_image_dir): 200 | os.makedirs(out_image_dir) 201 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 202 | if num_true_pos == 0: 203 | return 204 | pr_exact = make_precision_recall_eval( 205 | exact_raw, 206 | na_probs, 207 | num_true_pos, 208 | qid_to_has_ans, 209 | out_image=os.path.join(out_image_dir, "pr_exact.png"), 210 | title="Precision-Recall curve for Exact Match score", 211 | ) 212 | pr_f1 = make_precision_recall_eval( 213 | f1_raw, 214 | na_probs, 215 | num_true_pos, 216 | qid_to_has_ans, 217 | out_image=os.path.join(out_image_dir, "pr_f1.png"), 218 | title="Precision-Recall curve for F1 score", 219 | ) 220 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 221 | pr_oracle = make_precision_recall_eval( 222 | oracle_scores, 223 | na_probs, 224 | num_true_pos, 225 | qid_to_has_ans, 226 | out_image=os.path.join(out_image_dir, "pr_oracle.png"), 227 | 
title="Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)", 228 | ) 229 | merge_eval(main_eval, pr_exact, "pr_exact") 230 | merge_eval(main_eval, pr_f1, "pr_f1") 231 | merge_eval(main_eval, pr_oracle, "pr_oracle") 232 | 233 | 234 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 235 | if not qid_list: 236 | return 237 | x = [na_probs[k] for k in qid_list] 238 | weights = np.ones_like(x) / float(len(x)) 239 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 240 | plt.xlabel("Model probability of no-answer") 241 | plt.ylabel("Proportion of dataset") 242 | plt.title(f"Histogram of no-answer probability: {name}") 243 | plt.savefig(os.path.join(image_dir, f"na_prob_hist_{name}.png")) 244 | plt.clf() 245 | 246 | 247 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 248 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 249 | cur_score = num_no_ans 250 | best_score = cur_score 251 | best_thresh = 0.0 252 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 253 | for i, qid in enumerate(qid_list): 254 | if qid not in scores: 255 | continue 256 | if qid_to_has_ans[qid]: 257 | diff = scores[qid] 258 | else: 259 | if preds[qid]: 260 | diff = -1 261 | else: 262 | diff = 0 263 | cur_score += diff 264 | if cur_score > best_score: 265 | best_score = cur_score 266 | best_thresh = na_probs[qid] 267 | return 100.0 * best_score / len(scores), best_thresh 268 | 269 | 270 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 271 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 272 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 273 | main_eval["best_exact"] = best_exact 274 | main_eval["best_exact_thresh"] = exact_thresh 275 | main_eval["best_f1"] = best_f1 276 | main_eval["best_f1_thresh"] = f1_thresh 277 | 278 | 279 | def main(): 280 | with open(OPTS.data_file) as f: 281 | dataset_json = json.load(f) 282 | dataset = dataset_json["data"] 283 | with open(OPTS.pred_file) as f: 284 | preds = json.load(f) 285 | if OPTS.na_prob_file: 286 | with open(OPTS.na_prob_file) as f: 287 | na_probs = json.load(f) 288 | else: 289 | na_probs = {k: 0.0 for k in preds} 290 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 291 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 292 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 293 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 294 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh) 295 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh) 296 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 297 | if has_ans_qids: 298 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 299 | merge_eval(out_eval, has_ans_eval, "HasAns") 300 | if no_ans_qids: 301 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 302 | merge_eval(out_eval, no_ans_eval, "NoAns") 303 | if OPTS.na_prob_file: 304 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 305 | if OPTS.na_prob_file and OPTS.out_image_dir: 306 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir) 307 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns") 308 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns") 309 | if OPTS.out_file: 310 | with 
open(OPTS.out_file, "w") as f: 311 | json.dump(out_eval, f) 312 | else: 313 | print(json.dumps(out_eval, indent=2)) 314 | 315 | 316 | if __name__ == "__main__": 317 | OPTS = parse_args() 318 | if OPTS.out_image_dir: 319 | import matplotlib 320 | 321 | matplotlib.use("Agg") 322 | import matplotlib.pyplot as plt 323 | main() 324 | -------------------------------------------------------------------------------- /src/odqa_t5/squad_v2/squad_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Evaluate Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ SQuAD v2 metric. """ 15 | 16 | import datasets 17 | 18 | import evaluate 19 | 20 | from .compute_score import ( 21 | apply_no_ans_threshold, 22 | find_all_best_thresh, 23 | get_raw_scores, 24 | make_eval_dict, 25 | make_qid_to_has_ans, 26 | merge_eval, 27 | ) 28 | 29 | 30 | _CITATION = """\ 31 | @inproceedings{Rajpurkar2016SQuAD10, 32 | title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text}, 33 | author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang}, 34 | booktitle={EMNLP}, 35 | year={2016} 36 | } 37 | """ 38 | 39 | _DESCRIPTION = """ 40 | This metric wrap the official scoring script for version 2 of the Stanford Question 41 | Answering Dataset (SQuAD). 42 | 43 | Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by 44 | crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, 45 | from the corresponding reading passage, or the question might be unanswerable. 46 | 47 | SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions 48 | written adversarially by crowdworkers to look similar to answerable ones. 49 | To do well on SQuAD2.0, systems must not only answer questions when possible, but also 50 | determine when no answer is supported by the paragraph and abstain from answering. 51 | """ 52 | 53 | _KWARGS_DESCRIPTION = """ 54 | Computes SQuAD v2 scores (F1 and EM). 55 | Args: 56 | predictions: List of triple for question-answers to score with the following elements: 57 | - the question-answer 'id' field as given in the references (see below) 58 | - the text of the answer 59 | - the probability that the question has no answer 60 | references: List of question-answers dictionaries with the following key-values: 61 | - 'id': id of the question-answer pair (see above), 62 | - 'answers': a list of Dict {'text': text of the answer as a string} 63 | no_answer_threshold: float 64 | Probability threshold to decide that a question has no answer. 
65 | Returns: 66 | 'exact': Exact match (the normalized answer exactly matches the gold answer) 67 | 'f1': The F-score of predicted tokens versus the gold answer 68 | 'total': Number of scores considered 69 | 'HasAns_exact': Exact match (the normalized answer exactly matches the gold answer) 70 | 'HasAns_f1': The F-score of predicted tokens versus the gold answer 71 | 'HasAns_total': Number of scores considered 72 | 'NoAns_exact': Exact match (the normalized answer exactly matches the gold answer) 73 | 'NoAns_f1': The F-score of predicted tokens versus the gold answer 74 | 'NoAns_total': Number of scores considered 75 | 'best_exact': Best exact match (with varying threshold) 76 | 'best_exact_thresh': No-answer probability threshold associated to the best exact match 77 | 'best_f1': Best F1 (with varying threshold) 78 | 'best_f1_thresh': No-answer probability threshold associated to the best F1 79 | Examples: 80 | 81 | >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22', 'no_answer_probability': 0.}] 82 | >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}] 83 | >>> squad_v2_metric = evaluate.load("squad_v2") 84 | >>> results = squad_v2_metric.compute(predictions=predictions, references=references) 85 | >>> print(results) 86 | {'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0} 87 | """ 88 | 89 | 90 | @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 91 | class SquadV2(evaluate.Metric): 92 | def _info(self): 93 | return evaluate.MetricInfo( 94 | description=_DESCRIPTION, 95 | citation=_CITATION, 96 | inputs_description=_KWARGS_DESCRIPTION, 97 | features=datasets.Features( 98 | { 99 | "predictions": { 100 | "id": datasets.Value("string"), 101 | "prediction_text": datasets.Value("string"), 102 | "no_answer_probability": datasets.Value("float32"), 103 | }, 104 | "references": { 105 | "id": datasets.Value("string"), 106 | "answers": datasets.features.Sequence( 107 | {"text": datasets.Value("string"), "answer_start": datasets.Value("int32")} 108 | ), 109 | }, 110 | } 111 | ), 112 | codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], 113 | reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], 114 | ) 115 | 116 | def _compute(self, predictions, references, no_answer_threshold=1.0): 117 | no_answer_probabilities = {p["id"]: p["no_answer_probability"] for p in predictions} 118 | dataset = [{"paragraphs": [{"qas": references}]}] 119 | predictions = {p["id"]: p["prediction_text"] for p in predictions} 120 | 121 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 122 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 123 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 124 | 125 | exact_raw, f1_raw = get_raw_scores(dataset, predictions) 126 | exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) 127 | f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) 128 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 129 | 130 | if has_ans_qids: 131 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 132 | merge_eval(out_eval, has_ans_eval, "HasAns") 133 | if no_ans_qids: 134 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 135 |
merge_eval(out_eval, no_ans_eval, "NoAns") 136 | find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans) 137 | return dict(out_eval) 138 | -------------------------------------------------------------------------------- /src/odqa_t5/trainer_seq2seq_qa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Team All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | A subclass of `Trainer` specific to Question-Answering tasks 17 | """ 18 | import math 19 | import time 20 | from typing import Dict, List, Optional 21 | 22 | from torch.utils.data import Dataset 23 | 24 | from transformers import Seq2SeqTrainer, is_torch_tpu_available 25 | from transformers.trainer_utils import PredictionOutput, speed_metrics 26 | 27 | 28 | if is_torch_tpu_available(check_device=False): 29 | import torch_xla.core.xla_model as xm 30 | import torch_xla.debug.metrics as met 31 | 32 | 33 | class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer): 34 | def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs): 35 | super().__init__(*args, **kwargs) 36 | self.eval_examples = eval_examples 37 | self.post_process_function = post_process_function 38 | 39 | # def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"): 40 | def evaluate( 41 | self, 42 | eval_dataset: Optional[Dataset] = None, 43 | eval_examples=None, 44 | ignore_keys: Optional[List[str]] = None, 45 | metric_key_prefix: str = "eval", 46 | **gen_kwargs, 47 | ) -> Dict[str, float]: 48 | gen_kwargs = gen_kwargs.copy() 49 | gen_kwargs["max_length"] = ( 50 | gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length 51 | ) 52 | gen_kwargs["num_beams"] = ( 53 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 54 | ) 55 | self._gen_kwargs = gen_kwargs 56 | 57 | eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset 58 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 59 | eval_examples = self.eval_examples if eval_examples is None else eval_examples 60 | 61 | # Temporarily disable metric computation, we will do it in the loop here. 
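# NOTE: metrics are deferred because the prediction loop only yields raw generated token ids;
# they still need to be decoded and re-aligned with the untokenized `eval_examples` before a
# metric can score them. A minimal sketch of a compatible `post_process_function`, using
# hypothetical names (the actual hook is whatever was passed to __init__, e.g. by run_seq2seq_qa.py):
#
#     def post_processing_function(examples, features, outputs, stage="eval"):
#         preds = np.where(outputs.predictions != -100, outputs.predictions, tokenizer.pad_token_id)
#         decoded = tokenizer.batch_decode(preds, skip_special_tokens=True)
#         formatted = [{"id": ex["id"], "prediction_text": p} for ex, p in zip(examples, decoded)]
#         references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
#         return EvalPrediction(predictions=formatted, label_ids=references)
#
# (`EvalPrediction` comes from transformers.trainer_utils; `tokenizer` and `np` are assumed to
# be in scope in the script that defines the hook.)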
62 | compute_metrics = self.compute_metrics 63 | self.compute_metrics = None 64 | start_time = time.time() 65 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop 66 | try: 67 | output = eval_loop( 68 | eval_dataloader, 69 | description="Evaluation", 70 | # No point gathering the predictions if there are no metrics, otherwise we defer to 71 | # self.args.prediction_loss_only 72 | prediction_loss_only=True if compute_metrics is None else None, 73 | ignore_keys=ignore_keys, 74 | metric_key_prefix=metric_key_prefix, 75 | ) 76 | finally: 77 | self.compute_metrics = compute_metrics 78 | total_batch_size = self.args.eval_batch_size * self.args.world_size 79 | if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: 80 | start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] 81 | output.metrics.update( 82 | speed_metrics( 83 | metric_key_prefix, 84 | start_time, 85 | num_samples=output.num_samples, 86 | num_steps=math.ceil(output.num_samples / total_batch_size), 87 | ) 88 | ) 89 | 90 | if self.post_process_function is not None and self.compute_metrics is not None and self.args.should_save: 91 | # Only the main node write the results by default 92 | eval_preds = self.post_process_function(eval_examples, eval_dataset, output) 93 | metrics = self.compute_metrics(eval_preds) 94 | 95 | # Prefix all keys with metric_key_prefix + '_' 96 | for key in list(metrics.keys()): 97 | if not key.startswith(f"{metric_key_prefix}_"): 98 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) 99 | 100 | metrics.update(output.metrics) 101 | else: 102 | metrics = output.metrics 103 | 104 | if self.args.should_log: 105 | # Only the main node log the results by default 106 | self.log(metrics) 107 | 108 | if self.args.tpu_metrics_debug or self.args.debug: 109 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 110 | xm.master_print(met.metrics_report()) 111 | 112 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) 113 | return metrics 114 | 115 | def predict( 116 | self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test", **gen_kwargs 117 | ): 118 | self._gen_kwargs = gen_kwargs.copy() 119 | 120 | predict_dataloader = self.get_test_dataloader(predict_dataset) 121 | 122 | # Temporarily disable metric computation, we will do it in the loop here. 
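# NOTE: `predict` mirrors `evaluate` above, but keeps the post-processed predictions and returns
# them in a `PredictionOutput`, with metric keys prefixed by `metric_key_prefix` ("test" by default).
# A hedged usage sketch, with hypothetical variable names:
#
#     results = trainer.predict(predict_dataset, predict_examples, max_length=32, num_beams=4)
#     print(results.metrics.get("test_exact"), results.metrics.get("test_f1"))
#
# Extra keyword arguments such as max_length / num_beams are collected into **gen_kwargs and
# forwarded to generation through self._gen_kwargs.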
123 | compute_metrics = self.compute_metrics 124 | self.compute_metrics = None 125 | start_time = time.time() 126 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop 127 | try: 128 | output = eval_loop( 129 | predict_dataloader, 130 | description="Prediction", 131 | # No point gathering the predictions if there are no metrics, otherwise we defer to 132 | # self.args.prediction_loss_only 133 | prediction_loss_only=True if compute_metrics is None else None, 134 | ignore_keys=ignore_keys, 135 | metric_key_prefix=metric_key_prefix, 136 | ) 137 | finally: 138 | self.compute_metrics = compute_metrics 139 | 140 | total_batch_size = self.args.eval_batch_size * self.args.world_size 141 | if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: 142 | start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] 143 | output.metrics.update( 144 | speed_metrics( 145 | metric_key_prefix, 146 | start_time, 147 | num_samples=output.num_samples, 148 | num_steps=math.ceil(output.num_samples / total_batch_size), 149 | ) 150 | ) 151 | if self.post_process_function is None or self.compute_metrics is None: 152 | return output 153 | 154 | predictions = self.post_process_function(predict_examples, predict_dataset, output, "predict") 155 | metrics = self.compute_metrics(predictions) 156 | 157 | # Prefix all keys with metric_key_prefix + '_' 158 | for key in list(metrics.keys()): 159 | if not key.startswith(f"{metric_key_prefix}_"): 160 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) 161 | metrics.update(output.metrics) 162 | return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics) 163 | -------------------------------------------------------------------------------- /src/odqa_t5/utils_qa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Team All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Post-processing utilities for question answering. 17 | """ 18 | import collections 19 | import json 20 | import logging 21 | import os 22 | from typing import Optional, Tuple 23 | 24 | import numpy as np 25 | from tqdm.auto import tqdm 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def postprocess_qa_predictions( 32 | examples, 33 | features, 34 | predictions: Tuple[np.ndarray, np.ndarray], 35 | version_2_with_negative: bool = False, 36 | n_best_size: int = 20, 37 | max_answer_length: int = 30, 38 | null_score_diff_threshold: float = 0.0, 39 | output_dir: Optional[str] = None, 40 | prefix: Optional[str] = None, 41 | log_level: Optional[int] = logging.WARNING, 42 | ): 43 | """ 44 | Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the 45 | original contexts. This is the base postprocessing functions for models that only return start and end logits. 
46 | 47 | Args: 48 | examples: The non-preprocessed dataset (see the main script for more information). 49 | features: The processed dataset (see the main script for more information). 50 | predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): 51 | The predictions of the model: two arrays containing the start logits and the end logits respectively. Its 52 | first dimension must match the number of elements of :obj:`features`. 53 | version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): 54 | Whether or not the underlying dataset contains examples with no answers. 55 | n_best_size (:obj:`int`, `optional`, defaults to 20): 56 | The total number of n-best predictions to generate when looking for an answer. 57 | max_answer_length (:obj:`int`, `optional`, defaults to 30): 58 | The maximum length of an answer that can be generated. This is needed because the start and end predictions 59 | are not conditioned on one another. 60 | null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): 61 | The threshold used to select the null answer: if the best answer has a score that is less than the score of 62 | the null answer minus this threshold, the null answer is selected for this example (note that the score of 63 | the null answer for an example giving several features is the minimum of the scores for the null answer on 64 | each feature: all features must be aligned on the fact they `want` to predict a null answer). 65 | 66 | Only useful when :obj:`version_2_with_negative` is :obj:`True`. 67 | output_dir (:obj:`str`, `optional`): 68 | If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if 69 | :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null 70 | answers, are saved in `output_dir`. 71 | prefix (:obj:`str`, `optional`): 72 | If provided, the dictionaries mentioned above are saved with `prefix` added to their names. 73 | log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): 74 | ``logging`` log level (e.g., ``logging.WARNING``) 75 | """ 76 | if len(predictions) != 2: 77 | raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).") 78 | all_start_logits, all_end_logits = predictions 79 | 80 | if len(predictions[0]) != len(features): 81 | raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") 82 | 83 | # Build a map example to its corresponding features. 84 | example_id_to_index = {k: i for i, k in enumerate(examples["id"])} 85 | features_per_example = collections.defaultdict(list) 86 | for i, feature in enumerate(features): 87 | features_per_example[example_id_to_index[feature["example_id"]]].append(i) 88 | 89 | # The dictionaries we have to fill. 90 | all_predictions = collections.OrderedDict() 91 | all_nbest_json = collections.OrderedDict() 92 | if version_2_with_negative: 93 | scores_diff_json = collections.OrderedDict() 94 | 95 | # Logging. 96 | logger.setLevel(log_level) 97 | logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") 98 | 99 | # Let's loop over all the examples! 100 | for example_index, example in enumerate(tqdm(examples)): 101 | # Those are the indices of the features associated to the current example. 102 | feature_indices = features_per_example[example_index] 103 | 104 | min_null_prediction = None 105 | prelim_predictions = [] 106 | 107 | # Looping through all the features associated to the current example. 
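# NOTE: inside the loop below, the start/end logits of each feature are combined into candidate
# spans. As a toy illustration (numbers invented), with n_best_size = 2 and
#     start_logits = [0.1, 4.0, 0.3]   ->  top start indices [1, 2]
#     end_logits   = [0.2, 0.1, 3.5]   ->  top end indices   [2, 0]
# the pair (start=1, end=2) survives the validity checks and scores 4.0 + 3.5 = 7.5, while pairs
# with end < start are dropped; the winning span is later mapped back to the original context
# through `offset_mapping`.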
108 | for feature_index in feature_indices: 109 | # We grab the predictions of the model for this feature. 110 | start_logits = all_start_logits[feature_index] 111 | end_logits = all_end_logits[feature_index] 112 | # This is what will allow us to map some the positions in our logits to span of texts in the original 113 | # context. 114 | offset_mapping = features[feature_index]["offset_mapping"] 115 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context 116 | # available in the current feature. 117 | token_is_max_context = features[feature_index].get("token_is_max_context", None) 118 | 119 | # Update minimum null prediction. 120 | feature_null_score = start_logits[0] + end_logits[0] 121 | if min_null_prediction is None or min_null_prediction["score"] > feature_null_score: 122 | min_null_prediction = { 123 | "offsets": (0, 0), 124 | "score": feature_null_score, 125 | "start_logit": start_logits[0], 126 | "end_logit": end_logits[0], 127 | } 128 | 129 | # Go through all possibilities for the `n_best_size` greater start and end logits. 130 | start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() 131 | end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() 132 | for start_index in start_indexes: 133 | for end_index in end_indexes: 134 | # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond 135 | # to part of the input_ids that are not in the context. 136 | if ( 137 | start_index >= len(offset_mapping) 138 | or end_index >= len(offset_mapping) 139 | or offset_mapping[start_index] is None 140 | or len(offset_mapping[start_index]) < 2 141 | or offset_mapping[end_index] is None 142 | or len(offset_mapping[end_index]) < 2 143 | ): 144 | continue 145 | # Don't consider answers with a length that is either < 0 or > max_answer_length. 146 | if end_index < start_index or end_index - start_index + 1 > max_answer_length: 147 | continue 148 | # Don't consider answer that don't have the maximum context available (if such information is 149 | # provided). 150 | if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): 151 | continue 152 | 153 | prelim_predictions.append( 154 | { 155 | "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), 156 | "score": start_logits[start_index] + end_logits[end_index], 157 | "start_logit": start_logits[start_index], 158 | "end_logit": end_logits[end_index], 159 | } 160 | ) 161 | if version_2_with_negative and min_null_prediction is not None: 162 | # Add the minimum null prediction 163 | prelim_predictions.append(min_null_prediction) 164 | null_score = min_null_prediction["score"] 165 | 166 | # Only keep the best `n_best_size` predictions. 167 | predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] 168 | 169 | # Add back the minimum null prediction if it was removed because of its low score. 170 | if ( 171 | version_2_with_negative 172 | and min_null_prediction is not None 173 | and not any(p["offsets"] == (0, 0) for p in predictions) 174 | ): 175 | predictions.append(min_null_prediction) 176 | 177 | # Use the offsets to gather the answer text in the original context. 
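# NOTE: `offsets` are character positions in the raw context string, e.g. (values invented)
#     context = "Construction finished in 1889."
#     offsets = (25, 29)
# yields context[25:29] == "1889" as the prediction text below.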
178 | context = example["context"] 179 | for pred in predictions: 180 | offsets = pred.pop("offsets") 181 | pred["text"] = context[offsets[0] : offsets[1]] 182 | 183 | # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid 184 | # failure. 185 | if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): 186 | predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) 187 | 188 | # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using 189 | # the LogSumExp trick). 190 | scores = np.array([pred.pop("score") for pred in predictions]) 191 | exp_scores = np.exp(scores - np.max(scores)) 192 | probs = exp_scores / exp_scores.sum() 193 | 194 | # Include the probabilities in our predictions. 195 | for prob, pred in zip(probs, predictions): 196 | pred["probability"] = prob 197 | 198 | # Pick the best prediction. If the null answer is not possible, this is easy. 199 | if not version_2_with_negative: 200 | all_predictions[example["id"]] = predictions[0]["text"] 201 | else: 202 | # Otherwise we first need to find the best non-empty prediction. 203 | i = 0 204 | while predictions[i]["text"] == "": 205 | i += 1 206 | best_non_null_pred = predictions[i] 207 | 208 | # Then we compare to the null prediction using the threshold. 209 | score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"] 210 | scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable. 211 | if score_diff > null_score_diff_threshold: 212 | all_predictions[example["id"]] = "" 213 | else: 214 | all_predictions[example["id"]] = best_non_null_pred["text"] 215 | 216 | # Make `predictions` JSON-serializable by casting np.float back to float. 217 | all_nbest_json[example["id"]] = [ 218 | {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} 219 | for pred in predictions 220 | ] 221 | 222 | # If we have an output_dir, let's save all those dicts. 
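# NOTE: when `output_dir` is set, the block below writes plain-JSON files whose shapes look
# roughly like the following (content invented for illustration):
#     predictions.json        {"56e10a3be3433e1400422b22": "1976"}
#     nbest_predictions.json  {"56e10a3be3433e1400422b22": [{"text": "1976", "probability": 0.87, ...}]}
#     null_odds.json          {"56e10a3be3433e1400422b22": -3.2}   # only if version_2_with_negative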
223 | if output_dir is not None: 224 | if not os.path.isdir(output_dir): 225 | raise EnvironmentError(f"{output_dir} is not a directory.") 226 | 227 | prediction_file = os.path.join( 228 | output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" 229 | ) 230 | nbest_file = os.path.join( 231 | output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" 232 | ) 233 | if version_2_with_negative: 234 | null_odds_file = os.path.join( 235 | output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" 236 | ) 237 | 238 | logger.info(f"Saving predictions to {prediction_file}.") 239 | with open(prediction_file, "w") as writer: 240 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 241 | logger.info(f"Saving nbest_preds to {nbest_file}.") 242 | with open(nbest_file, "w") as writer: 243 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 244 | if version_2_with_negative: 245 | logger.info(f"Saving null_odds to {null_odds_file}.") 246 | with open(null_odds_file, "w") as writer: 247 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n") 248 | 249 | return all_predictions 250 | 251 | 252 | def postprocess_qa_predictions_with_beam_search( 253 | examples, 254 | features, 255 | predictions: Tuple[np.ndarray, np.ndarray], 256 | version_2_with_negative: bool = False, 257 | n_best_size: int = 20, 258 | max_answer_length: int = 30, 259 | start_n_top: int = 5, 260 | end_n_top: int = 5, 261 | output_dir: Optional[str] = None, 262 | prefix: Optional[str] = None, 263 | log_level: Optional[int] = logging.WARNING, 264 | ): 265 | """ 266 | Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the 267 | original contexts. This is the postprocessing functions for models that return start and end logits, indices, as well as 268 | cls token predictions. 269 | 270 | Args: 271 | examples: The non-preprocessed dataset (see the main script for more information). 272 | features: The processed dataset (see the main script for more information). 273 | predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): 274 | The predictions of the model: two arrays containing the start logits and the end logits respectively. Its 275 | first dimension must match the number of elements of :obj:`features`. 276 | version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): 277 | Whether or not the underlying dataset contains examples with no answers. 278 | n_best_size (:obj:`int`, `optional`, defaults to 20): 279 | The total number of n-best predictions to generate when looking for an answer. 280 | max_answer_length (:obj:`int`, `optional`, defaults to 30): 281 | The maximum length of an answer that can be generated. This is needed because the start and end predictions 282 | are not conditioned on one another. 283 | start_n_top (:obj:`int`, `optional`, defaults to 5): 284 | The number of top start logits too keep when searching for the :obj:`n_best_size` predictions. 285 | end_n_top (:obj:`int`, `optional`, defaults to 5): 286 | The number of top end logits too keep when searching for the :obj:`n_best_size` predictions. 287 | output_dir (:obj:`str`, `optional`): 288 | If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if 289 | :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null 290 | answers, are saved in `output_dir`. 
291 | prefix (:obj:`str`, `optional`): 292 | If provided, the dictionaries mentioned above are saved with `prefix` added to their names. 293 | log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): 294 | ``logging`` log level (e.g., ``logging.WARNING``) 295 | """ 296 | if len(predictions) != 5: 297 | raise ValueError("`predictions` should be a tuple with five elements.") 298 | start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions 299 | 300 | if len(predictions[0]) != len(features): 301 | raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") 302 | 303 | # Build a map example to its corresponding features. 304 | example_id_to_index = {k: i for i, k in enumerate(examples["id"])} 305 | features_per_example = collections.defaultdict(list) 306 | for i, feature in enumerate(features): 307 | features_per_example[example_id_to_index[feature["example_id"]]].append(i) 308 | 309 | # The dictionaries we have to fill. 310 | all_predictions = collections.OrderedDict() 311 | all_nbest_json = collections.OrderedDict() 312 | scores_diff_json = collections.OrderedDict() if version_2_with_negative else None 313 | 314 | # Logging. 315 | logger.setLevel(log_level) 316 | logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") 317 | 318 | # Let's loop over all the examples! 319 | for example_index, example in enumerate(tqdm(examples)): 320 | # Those are the indices of the features associated to the current example. 321 | feature_indices = features_per_example[example_index] 322 | 323 | min_null_score = None 324 | prelim_predictions = [] 325 | 326 | # Looping through all the features associated to the current example. 327 | for feature_index in feature_indices: 328 | # We grab the predictions of the model for this feature. 329 | start_log_prob = start_top_log_probs[feature_index] 330 | start_indexes = start_top_index[feature_index] 331 | end_log_prob = end_top_log_probs[feature_index] 332 | end_indexes = end_top_index[feature_index] 333 | feature_null_score = cls_logits[feature_index] 334 | # This is what will allow us to map some the positions in our logits to span of texts in the original 335 | # context. 336 | offset_mapping = features[feature_index]["offset_mapping"] 337 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context 338 | # available in the current feature. 339 | token_is_max_context = features[feature_index].get("token_is_max_context", None) 340 | 341 | # Update minimum null prediction 342 | if min_null_score is None or feature_null_score < min_null_score: 343 | min_null_score = feature_null_score 344 | 345 | # Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits. 346 | for i in range(start_n_top): 347 | for j in range(end_n_top): 348 | start_index = int(start_indexes[i]) 349 | j_index = i * end_n_top + j 350 | end_index = int(end_indexes[j_index]) 351 | # Don't consider out-of-scope answers (last part of the test should be unnecessary because of the 352 | # p_mask but let's not take any risk) 353 | if ( 354 | start_index >= len(offset_mapping) 355 | or end_index >= len(offset_mapping) 356 | or offset_mapping[start_index] is None 357 | or len(offset_mapping[start_index]) < 2 358 | or offset_mapping[end_index] is None 359 | or len(offset_mapping[end_index]) < 2 360 | ): 361 | continue 362 | 363 | # Don't consider answers with a length negative or > max_answer_length. 
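# NOTE: the span length below is inclusive, e.g. start_index = 10 and end_index = 45 give
# 45 - 10 + 1 = 36 tokens, which is rejected under the default max_answer_length = 30
# (numbers are for illustration only).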
364 | if end_index < start_index or end_index - start_index + 1 > max_answer_length: 365 | continue 366 | # Don't consider answer that don't have the maximum context available (if such information is 367 | # provided). 368 | if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): 369 | continue 370 | prelim_predictions.append( 371 | { 372 | "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), 373 | "score": start_log_prob[i] + end_log_prob[j_index], 374 | "start_log_prob": start_log_prob[i], 375 | "end_log_prob": end_log_prob[j_index], 376 | } 377 | ) 378 | 379 | # Only keep the best `n_best_size` predictions. 380 | predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] 381 | 382 | # Use the offsets to gather the answer text in the original context. 383 | context = example["context"] 384 | for pred in predictions: 385 | offsets = pred.pop("offsets") 386 | pred["text"] = context[offsets[0] : offsets[1]] 387 | 388 | # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid 389 | # failure. 390 | if len(predictions) == 0: 391 | # Without predictions min_null_score is going to be None and None will cause an exception later 392 | min_null_score = -2e-6 393 | predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score}) 394 | 395 | # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using 396 | # the LogSumExp trick). 397 | scores = np.array([pred.pop("score") for pred in predictions]) 398 | exp_scores = np.exp(scores - np.max(scores)) 399 | probs = exp_scores / exp_scores.sum() 400 | 401 | # Include the probabilities in our predictions. 402 | for prob, pred in zip(probs, predictions): 403 | pred["probability"] = prob 404 | 405 | # Pick the best prediction and set the probability for the null answer. 406 | all_predictions[example["id"]] = predictions[0]["text"] 407 | if version_2_with_negative: 408 | scores_diff_json[example["id"]] = float(min_null_score) 409 | 410 | # Make `predictions` JSON-serializable by casting np.float back to float. 411 | all_nbest_json[example["id"]] = [ 412 | {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} 413 | for pred in predictions 414 | ] 415 | 416 | # If we have an output_dir, let's save all those dicts. 
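# NOTE: unlike postprocess_qa_predictions above, this beam-search variant also returns
# `scores_diff_json`, a {question id: null score} mapping. Those values can be supplied to
# compute_score.py via --na-prob-file (or as `no_answer_probability` in the squad_v2 metric) to
# tune the no-answer threshold; they are raw scores rather than calibrated probabilities.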
417 | if output_dir is not None: 418 | if not os.path.isdir(output_dir): 419 | raise EnvironmentError(f"{output_dir} is not a directory.") 420 | 421 | prediction_file = os.path.join( 422 | output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" 423 | ) 424 | nbest_file = os.path.join( 425 | output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" 426 | ) 427 | if version_2_with_negative: 428 | null_odds_file = os.path.join( 429 | output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" 430 | ) 431 | 432 | logger.info(f"Saving predictions to {prediction_file}.") 433 | with open(prediction_file, "w") as writer: 434 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 435 | logger.info(f"Saving nbest_preds to {nbest_file}.") 436 | with open(nbest_file, "w") as writer: 437 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 438 | if version_2_with_negative: 439 | logger.info(f"Saving null_odds to {null_odds_file}.") 440 | with open(null_odds_file, "w") as writer: 441 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n") 442 | 443 | return all_predictions, scores_diff_json 444 | -------------------------------------------------------------------------------- /src/preprocess/analyze_time_normalize.py: -------------------------------------------------------------------------------- 1 | import json, random, os 2 | 3 | SAMPLE_SIZE = 100 4 | 5 | def main(): 6 | sample_file_path = f"./sample_{SAMPLE_SIZE}.enwiki-20221101_temporal-sentences_special-token-prefix.json" 7 | original_file_path = "../../../data/pretrain_t5/enwiki-20221101_temporal-sentences_special-token-prefix.json" 8 | if os.path.exists(sample_file_path): 9 | print(f"Sample file already exists at {sample_file_path}") 10 | with open(sample_file_path, 'r', encoding='utf-8') as f: 11 | data = [json.loads(line) for line in f.readlines() if line.strip() != ''] 12 | else: 13 | with open(original_file_path, 'r', encoding='utf-8') as f: 14 | data = [json.loads(line) for line in f.readlines() if line.strip() != ''] 15 | random.shuffle(data) 16 | data = data[:SAMPLE_SIZE] 17 | with open(sample_file_path, 'w', encoding='utf-8') as f: 18 | f.writelines([json.dumps(d) + '\n' for d in data]) 19 | annotated_data_path = f"./sample_{SAMPLE_SIZE}.annotated.enwiki-20221101_temporal-sentences_special-token-prefix.json" 20 | if os.path.exists(annotated_data_path): 21 | print(f"Annotated data file already exists at {annotated_data_path}") 22 | with open(annotated_data_path, 'r', encoding='utf-8') as f: 23 | annotated_data = json.load(f) 24 | else: 25 | annotated_data = [[] for _ in range(SAMPLE_SIZE)] 26 | # annotated_data = [[]*SAMPLE_SIZE] 27 | start_idx = 0 28 | for i, d in enumerate(annotated_data): 29 | if d == []: 30 | if i > 0: 31 | start_idx = i - 1 32 | break 33 | try: 34 | for i, d in enumerate(data): 35 | if i >= start_idx: 36 | while True: 37 | x = input( 38 | '******\nNo.{}: Enter a normalized time expression (e.g., "[20100101, 20120304]") for the following sentence:\n\n{}\n\nAlready input: {} \n\n(Input "q" to quit the program, "s" to skip the current one.):\n'.format(i+1, d['text'], annotated_data[i]) 39 | ) 40 | if x == 'q': 41 | raise KeyboardInterrupt 42 | elif x == 's': 43 | break 44 | else: 45 | try: 46 | x = json.loads(x) 47 | annotated_data[i].append(x) 48 | except: 49 | print('Invalid input. 
Please try again.\n') 50 | print("\n") 51 | except: 52 | with open(annotated_data_path, 'w', encoding='utf-8') as f: 53 | json.dump(annotated_data, f) 54 | raise KeyboardInterrupt 55 | 56 | if __name__ == '__main__': 57 | main() -------------------------------------------------------------------------------- /src/preprocess/sample_100.annotated.enwiki-20221101_temporal-sentences_special-token-prefix.json: -------------------------------------------------------------------------------- 1 | [[[20070101, 20080101], [20060622, 20060623]], [[19620101, 19630101]], [[20070101, 20110101]], [[20131115, 20131116], [20140101, 20150101]], [[18920101, 18930101], [18950101, 18960101], [19030101, 19890101], [19900101, 20221101]], [[18470101, 18480101], [18570101, 18580101], [19070101, 19080101]], [[19590402, 19590403]], [[20160101, 20170101]], [[20170101, 20180101], [20180101, 20190101]], [[19800101, 19900101], [19990101, 20000101]], [[18921014, 19591227]], [[17780521, 17780522]], [[19280101, 19290101], [19320101, 19330101], [19360101, 19370101], [19400101, 19410101]], [[19900101, 20100101]], [[19730101, 19740101], [19760101, 19770101]], [[19300101, 19310101], [19300101, 19650101]], [[20190101, 20200101]], [[19910101, 19920101]], [[20200826, 20200827]], [[19580101, 19590101], [19950101, 19960101], [20040101, 20050101]], [[20100101, 20110101], [19010527, 19010528]], [[19051224, 19051225], [19290622, 19290623], [19300329, 19300330], [19560420, 19560421], [19710101, 19720101], [19920201, 19920301]], [[19900101, 19910101], [20090101, 20100101]], [[19840101, 19850101]], [[19070101, 19080101], [20190101, 20200101], [19040101, 19050101]], [[18090401, 18090501]], [[19871228, 19871229]], [[19220101, 19230101]], [[20050101, 20060101]], [[18500101, 18510101], [17851225,17851226], [18660910, 18660911]], [[19900101, 20000101], [19920101, 19930101], [20090924, 20090925]], [[18060101, 18070101], [19800101, 19810101]], [[20011201, 20020101], [20020101, 20030101], [20041201, 20041202]], [[20010101, 20070101], [20070101, 20080101]], [[17940301, 17940401]], [[20080101, 20090101], [20150101, 20160101], [20170101, 20180101], [20210101, 20220101]], [[19670101, 19680101], [19740101, 19750101], [19750101, 19760101]], [[19480101, 19490101]], [[20010909, 20010911], [20010101, 20020101]], [[19411221, 19411222], [19420114, 19420115]], [[20020101, 20060101]], [[19590101, 19600101], [19700101, 19760101]], [[20060501, 20060601], [20110101, 20120101]], [[20160101, 20170101], [20170101, 20180101]], [[19450101, 19460101]], [[19200101, 19210101], [20130101, 20140101]], [[18660101, 18670101]], [[19060101, 19070101]], [[20180612, 20180613]], [[19700101, 19800101], [19770101, 19780101]], [[20000101, 21000101]], [[19111004, 19111005], [19390306, 19390307]], [[19750101, 19760101], [19800101, 19810101], [19920101, 19930101]], [[20040101, 20050101]], [[19910101, 19920101], [19930101, 19940101], [19930615, 19930616], [19930618, 19930619], [19930625, 19930626], [19931003, 19931004]], [[20200312, 20200313]], [[19850101, 19860101]], [[19570101,19580101]], [[20170601, 20170701], [20181001, 20181101], [20190201, 20190301], [20200101, 20200201]], [[10540101, 10550101], [10630101, 10640101]], [[19440101, 19450101], [19510101, 19520101], [19650101, 19660101]], [[18640901, 18641001], [18640917, 18640918], [18640919, 18640920]], [[20180101, 20190101]], [[20080101, 20080201]], [[20130301, 20130401], [20130101, 20140101], [20131007, 20131008]], [[19910912, 19910913], [19830101, 19840101]], [[18200101, 18210101], [18610101, 18620101], [18710101, 18720101]], 
[[20130101, 20140101]], [[20080101, 20090101]], [[20170816, 20170817]], [[18680101, 18690101]], [[19700101, 19930101], [20050101, 20060101]], [[18511103, 18511105], [18510304, 18510305], [18511201, 18520101]], [[19930101, 19940101], [19940101, 19950101], [19950101, 19960101]], [[20200101, 20210101]], [[19200101, 19210101], [19960101, 20000101]], [[17800101, 17810101], [17820101, 17830101], [17890101, 17900101]], [[17580101, 17590101]], [[20010101, 20020101]], [[18860101, 18870101], [18861018, 18861023]], [[19840101, 19850101]], [[20070101, 20080101], [20101201, 20110101]], [[19300101, 19310101], [19300101, 19400101]], [[19980531, 19980601], [19980925, 19980927]], [[20050101, 20060101], [20040101, 20050101]], [[20080101, 20090101]], [[18890101,18900101], [18960101, 18970101]], [[20090101, 20100101]], [[19410101, 19420101], [19420101, 19430101]], [[20120501, 20120601], [20121101, 20121201]], [[19700101, 19710101]], [[20080101, 20090101], [20100101, 20110101]], [[20160101, 20200101]], [[20110407, 20110408]], [[20160101, 20170101], [20110101, 20120101], [20160101, 20170101]], [[19400101, 19500101], [19500101, 19600101], [19600101, 19700101], [19430101, 19440101]], [[20211212, 20211213]], [[20110101, 20120101]], [], []] -------------------------------------------------------------------------------- /src/pretrain_t5/data_collator_for_t5.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | PreTrainedTokenizerBase, 3 | BatchEncoding, 4 | ) 5 | import numpy as np 6 | from typing import List, Dict 7 | from dataclasses import dataclass 8 | import math 9 | 10 | 11 | 12 | # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right 13 | def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray: 14 | """ 15 | Shift input ids one token to the right. 16 | """ 17 | shifted_input_ids = np.zeros_like(input_ids) 18 | shifted_input_ids[:, 1:] = input_ids[:, :-1] 19 | shifted_input_ids[:, 0] = decoder_start_token_id 20 | 21 | shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) 22 | return shifted_input_ids 23 | 24 | 25 | @dataclass 26 | class FlaxDataCollatorForT5MLM: 27 | """ 28 | Data collator used for T5 span-masked language modeling. 29 | It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length. 30 | For more information on how T5 span-masked language modeling works, one can take a look 31 | at the `official paper `__ 32 | or the `official code for preprocessing `__ . 33 | Args: 34 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 35 | The tokenizer used for encoding the data. 36 | noise_density (:obj:`float`): 37 | The probability with which to (randomly) mask tokens in the input. 38 | mean_noise_span_length (:obj:`float`): 39 | The average span length of the masked tokens. 40 | input_length (:obj:`int`): 41 | The expected input length after masking. 42 | target_length (:obj:`int`): 43 | The expected target length after masking. 
44 | pad_token_id: (:obj:`int`): 45 | The pad token id of the model 46 | decoder_start_token_id: (:obj:`int): 47 | The decoder start token id of the model 48 | """ 49 | 50 | tokenizer: PreTrainedTokenizerBase 51 | noise_density: float 52 | mean_noise_span_length: float 53 | input_length: int 54 | target_length: int 55 | pad_token_id: int 56 | decoder_start_token_id: int 57 | num_add_special_tokens: int 58 | time_token_masking: bool 59 | time_mask_prob: float 60 | time_relation_prediction: bool 61 | time_cls_token_id: int 62 | scale_num_trelation_labels: float 63 | 64 | def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding: 65 | if "cls_positions" in examples[0].keys(): 66 | _cls_positions = [np.array(examples[i].pop("cls_positions")) for i in range(len(examples))] 67 | if "time_relation_labels" in examples[0].keys(): 68 | _time_relation_labels = [np.array(examples[i].pop("time_relation_labels")) for i in range(len(examples))] 69 | if "time_tokens_mask" in examples[0].keys(): 70 | _time_tokens_mask = [np.array(examples[i].pop("time_tokens_mask")) for i in range(len(examples))] 71 | 72 | # convert list to dict and tensorize input 73 | batch = BatchEncoding( 74 | {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()} 75 | ) 76 | 77 | input_ids = batch["input_ids"] 78 | batch_size, expandend_input_length = input_ids.shape 79 | 80 | mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)]) 81 | 82 | 83 | # if not self.time_token_masking: 84 | # mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)]) 85 | # else: # Time Token Masking 86 | # mask_indices = np.asarray([self.span_corruption_mask(expandend_input_length, _time_tokens_mask[i]) for i in range(batch_size)]) 87 | ######################################## 88 | labels_mask = ~mask_indices 89 | 90 | input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8)) 91 | labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8)) 92 | 93 | batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel) 94 | labels_with_padding_token_ids = self.filter_input_ids(input_ids, labels_sentinel) 95 | batch["labels"] = np.where(labels_with_padding_token_ids == self.pad_token_id, -100, labels_with_padding_token_ids) 96 | 97 | if self.time_token_masking and batch["input_ids"].shape[-1] == self.input_length + 1: 98 | batch["input_ids"] = batch["input_ids"][:, :-1] 99 | 100 | if batch["input_ids"].shape[-1] != self.input_length: 101 | raise ValueError( 102 | f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but" 103 | f" should be {self.input_length}." 104 | ) 105 | 106 | if batch["labels"].shape[-1] != self.target_length: 107 | raise ValueError( 108 | f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be" 109 | f" {self.target_length}." 
110 | ) 111 | 112 | # Time Relation Prediction 113 | if self.time_relation_prediction: 114 | # remove those [TIME] tokens that are masked 115 | mask_indices = mask_indices.astype(np.int8) 116 | for batch_idx in range(len(_cls_positions)): 117 | try: 118 | if (_time_relation_labels[batch_idx] == -100).all(): 119 | _time_relation_labels[batch_idx] = np.array([[-100]]) 120 | _cls_positions[batch_idx] = np.array([0]) 121 | continue 122 | for cls_idx in range(len(_cls_positions[batch_idx])): 123 | if _cls_positions[batch_idx][cls_idx] in np.where(mask_indices[batch_idx] == 1)[0]: 124 | _cls_positions[batch_idx][cls_idx] = -1 125 | truncated_cls_indices = np.where(_cls_positions[batch_idx] != -1)[0] 126 | curr_time_relation_labels = [] 127 | for i in range(len(truncated_cls_indices)): 128 | curr_time_relation_labels.append([_time_relation_labels[batch_idx][truncated_cls_indices[i], j] for j in truncated_cls_indices]) 129 | _time_relation_labels[batch_idx] = np.array(curr_time_relation_labels) 130 | _cls_positions[batch_idx] = np.where(batch["input_ids"][batch_idx] == self.time_cls_token_id)[0] 131 | assert _time_relation_labels[batch_idx].shape[0] == len(_cls_positions[batch_idx]),f"{_time_relation_labels[batch_idx].shape[0]} != {len(_cls_positions[batch_idx])}" 132 | assert _time_relation_labels[batch_idx].shape[1] == len(_cls_positions[batch_idx]),f"{_time_relation_labels[batch_idx].shape[1]} != {len(_cls_positions[batch_idx])}" 133 | except Exception as e: 134 | print(e) 135 | # print() 136 | # import pdb; pdb.set_trace() 137 | _time_relation_labels[batch_idx] = np.array([[-100]]) 138 | _cls_positions[batch_idx] = np.array([0]) 139 | continue 140 | # pad cls_positions 141 | MAX_NUM_CLS = 10 142 | num_cls = [min(len(cls_position), MAX_NUM_CLS) for cls_position in _cls_positions] 143 | max_num_cls = max(num_cls) 144 | max_num_cls = max(max_num_cls, 1) 145 | # max_num_cls = min(max_num_cls, MAX_NUM_CLS) 146 | cls_positions = np.full([len(_cls_positions), max_num_cls], 0) 147 | for i in range(len(_cls_positions)): 148 | cls_positions[i, :min(len(_cls_positions[i]), MAX_NUM_CLS)] = _cls_positions[i][:min(len(_cls_positions[i]), MAX_NUM_CLS)] 149 | time_relation_labels = np.full([len(_time_relation_labels), max_num_cls, max_num_cls], -100) 150 | if self.scale_num_trelation_labels != 0.0: 151 | for i in range(len(_time_relation_labels)): 152 | time_relation_labels[i, : min(_time_relation_labels[i].shape[0], MAX_NUM_CLS), : min(_time_relation_labels[i].shape[1], MAX_NUM_CLS)] \ 153 | = _time_relation_labels[i][: min(_time_relation_labels[i].shape[0], MAX_NUM_CLS), : min(_time_relation_labels[i].shape[1], MAX_NUM_CLS)] 154 | if self.scale_num_trelation_labels != -1.0 and self.scale_num_trelation_labels != 0.0: 155 | if self.scale_num_trelation_labels == 0.1: 156 | print("##################") 157 | print('Warning: "scale_num_trelation_labels = 0.1" means that there is only one TRC label in one instance. 
') 158 | print("##################") 159 | 160 | mask = np.ones(time_relation_labels.shape, dtype=np.bool) 161 | for i in range(len(time_relation_labels)): 162 | for j in range(max_num_cls): 163 | non_null_pairs = [(j,k) for k in range(max_num_cls) if time_relation_labels[i, j, k] != -100 and j != k] 164 | if len(non_null_pairs) > 0: 165 | chosen_pair = np.random.permutation(non_null_pairs)[0] 166 | mask[i, chosen_pair[0], chosen_pair[1]] = False 167 | time_relation_labels[mask] = -100 168 | else: 169 | num_trelation_labels = int(min(self.scale_num_trelation_labels, max_num_cls) * max_num_cls * len(_time_relation_labels)) 170 | time_relation_labels = time_relation_labels.reshape(-1) 171 | sample_indices = np.random.choice(np.arange(time_relation_labels.shape[0]), size=num_trelation_labels, replace=False) 172 | mask = np.ones(time_relation_labels.shape, dtype=np.bool) 173 | mask[sample_indices] = False 174 | time_relation_labels[mask] = -100 175 | time_relation_labels = time_relation_labels.reshape([len(_time_relation_labels), max_num_cls, max_num_cls]) 176 | for i in range(len(_time_relation_labels)): 177 | for j in range(max_num_cls): 178 | time_relation_labels[i, j, j] = -100 179 | 180 | batch["time_cls_indices"] = np.array(cls_positions) 181 | batch["time_relation_labels"] = np.array(time_relation_labels) 182 | ######################################## 183 | 184 | # to check that tokens are correctly preprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here... 185 | batch["decoder_input_ids"] = shift_tokens_right( 186 | labels_with_padding_token_ids, self.pad_token_id, self.decoder_start_token_id 187 | ) 188 | 189 | batch.pop("attention_mask", None) 190 | batch.pop("time_tokens_mask", None) 191 | batch.pop("special_tokens_mask", None) 192 | 193 | return batch.convert_to_tensors(tensor_type="pt") 194 | 195 | def create_sentinel_ids(self, mask_indices): 196 | """ 197 | Sentinel ids creation given the indices that should be masked. 198 | The start indices of each mask are replaced by the sentinel ids in increasing 199 | order. Consecutive mask indices to be deleted are replaced with `-1`. 200 | """ 201 | 202 | start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices 203 | start_indices[:, 0] = mask_indices[:, 0] 204 | 205 | sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices) 206 | sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - self.num_add_special_tokens - sentinel_ids), 0) 207 | sentinel_ids -= mask_indices - start_indices 208 | 209 | return sentinel_ids 210 | 211 | def filter_input_ids(self, input_ids, sentinel_ids): 212 | """ 213 | Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting. 214 | This will reduce the sequence length from `expanded_inputs_length` to `input_length`. 215 | """ 216 | batch_size = input_ids.shape[0] 217 | 218 | input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids) 219 | # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are 220 | # masked tokens coming after sentinel tokens and should be removed 221 | input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1)) 222 | input_ids = np.concatenate( 223 | [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1 224 | ) 225 | return input_ids 226 | 227 | def random_spans_noise_mask(self, length): 228 | """This function is copy of `random_spans_helper `__ . 
229 | Noise mask consisting of random spans of noise tokens. 230 | The number of noise tokens and the number of noise spans and non-noise spans 231 | are determined deterministically as follows: 232 | num_noise_tokens = round(length * noise_density) 233 | num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length) 234 | Spans alternate between non-noise and noise, beginning with non-noise. 235 | Subject to the above restrictions, all masks are equally likely. 236 | Args: 237 | length: an int32 scalar (length of the incoming token sequence) 238 | noise_density: a float - approximate density of output mask 239 | mean_noise_span_length: a number 240 | Returns: 241 | a boolean tensor with shape [length] 242 | """ 243 | 244 | orig_length = length 245 | 246 | num_noise_tokens = int(np.round(length * self.noise_density)) 247 | # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. 248 | num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) 249 | num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length)) 250 | 251 | # avoid degeneracy by ensuring positive number of noise spans 252 | num_noise_spans = max(num_noise_spans, 1) 253 | num_nonnoise_tokens = length - num_noise_tokens 254 | 255 | # pick the lengths of the noise spans and the non-noise spans 256 | def _random_segmentation(num_items, num_segments): 257 | """Partition a sequence of items randomly into non-empty segments. 258 | Args: 259 | num_items: an integer scalar > 0 260 | num_segments: an integer scalar in [1, num_items] 261 | Returns: 262 | a Tensor with shape [num_segments] containing positive integers that add 263 | up to num_items 264 | """ 265 | mask_indices = np.arange(num_items - 1) < (num_segments - 1) 266 | np.random.shuffle(mask_indices) 267 | first_in_segment = np.pad(mask_indices, [[1, 0]]) 268 | segment_id = np.cumsum(first_in_segment) 269 | # count length of sub segments assuming that list is sorted 270 | _, segment_length = np.unique(segment_id, return_counts=True) 271 | return segment_length 272 | 273 | noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans) 274 | nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans) 275 | 276 | interleaved_span_lengths = np.reshape( 277 | np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2] 278 | ) 279 | span_starts = np.cumsum(interleaved_span_lengths)[:-1] 280 | span_start_indicator = np.zeros((length,), dtype=np.int8) 281 | span_start_indicator[span_starts] = True 282 | span_num = np.cumsum(span_start_indicator) 283 | is_noise = np.equal(span_num % 2, 1) 284 | 285 | return is_noise[:orig_length] 286 | 287 | def get_time_token_mask_prob(self, input_ids, time_tokens_mask, time_mask_prob=0.3, max_time_mask_ratio=0.1): 288 | """ 289 | Copy from ../pretrain/data_collator.py 290 | 291 | Get 0/1 labels for masked tokens with time-tokens-mask. 292 | Our goal is: 293 | to keep the total ratio of masks to be {15\%}; 294 | If time-tokens does not account for more than {max_time_mask_ratio}, just sample [MASK] by {time_mask_prob}; 295 | If time-tokens account for more than {max_time_mask_ratio}, we restrict the ratio to be {max_time_mask_ratio}, 296 | so that {0.15-max_time_mask_ratio} normal tokens can get masked; 297 | If the sequence is too short, at least {one} time-token gets masked. 
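        A worked example (hypothetical numbers; assuming self.mlm_probability = 0.15):
            input_length = 200, time_tokens_length = 40 (20% time tokens), time_mask_prob = 0.3, max_time_mask_ratio = 0.1
            -> time tokens exceed the 10% cap, so final_time_mask_prob = 0.3 * (0.1 / 0.2) = 0.15
            -> final_mlm_prob = (0.15 * 200 - 0.15 * 40) / (200 - 40) = 24 / 160 = 0.15
            so the expected total number of masked tokens stays at roughly 15% of the sequence.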
298 | Return: 299 | final_time_mask_prob: 300 | probability to mask time-tokens 301 | final_mlm_prob: 302 | probability to mask normal tokens 303 | """ 304 | if self.tokenizer.pad_token_id in input_ids: 305 | input_ids = input_ids[:input_ids.index(self.tokenizer.pad_token_id)] 306 | # input_length = len(input_ids) - self.tokenizer.num_special_tokens_to_add(pair=False) 307 | input_length = len(input_ids) 308 | # time_tokens_length = sum(time_tokens_mask) 309 | time_tokens_length = np.sum(time_tokens_mask) 310 | if time_tokens_length > 0: 311 | try: 312 | if time_tokens_length / input_length <= max_time_mask_ratio: 313 | final_time_mask_prob = time_mask_prob 314 | else: 315 | final_time_mask_prob = time_mask_prob * (max_time_mask_ratio / (time_tokens_length / input_length)) 316 | if time_tokens_length*final_time_mask_prob < 1: 317 | final_time_mask_prob = 1 / time_tokens_length 318 | if input_length > time_tokens_length: 319 | final_mlm_prob = (self.mlm_probability*input_length - final_time_mask_prob*time_tokens_length) / (input_length - time_tokens_length) 320 | else: 321 | final_mlm_prob = 0.15 322 | return max(min(final_time_mask_prob, 1.0), 0.0), max(min(final_mlm_prob, 1.0), 0.0) 323 | except: 324 | return 0.0, self.mlm_probability 325 | else: 326 | return 0.0, self.mlm_probability 327 | 328 | def span_corruption_mask(self, input_length, time_tokens_mask, max_time_mask_ratio=0.1): 329 | """ 330 | Copy from https://github.com/joeljang/Pretraining_T5_custom_dataset/blob/master/pretrain.py#L398 331 | I use this function to generate ``time-tokens-mask'' & ``normal mask'' for the span corruption task. 332 | 30% of time-tokens should be masked, masked-time-tokens should not be more than 10% of all tokens. The rest should be normal tokens. 333 | """ 334 | if self.mean_noise_span_length != 3: 335 | raise NotImplementedError("Only mean_noise_span_length=3 is supported if time_token_masking=True. ") 336 | max_index = input_length 337 | mask = max_index * [0] 338 | # span_num = math.ceil(( max_index * self.noise_density ) / self.mean_noise_span_length ) 339 | num_noise_tokens = int(np.round(input_length * self.noise_density)) 340 | # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. 341 | num_noise_tokens = min(max(num_noise_tokens, 1), input_length - 1) 342 | span_num = int(np.round(num_noise_tokens / self.mean_noise_span_length)) 343 | 344 | # pick the lengths of the noise spans and the non-noise spans 345 | def _random_segmentation(num_items, num_segments): 346 | """Partition a sequence of items randomly into non-empty segments. 
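            For instance (purely illustrative), _random_segmentation(10, 3) could return [4, 2, 4]:
            three positive integers, chosen at random, that sum to 10.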
347 | Args: 348 | num_items: an integer scalar > 0 349 | num_segments: an integer scalar in [1, num_items] 350 | Returns: 351 | a Tensor with shape [num_segments] containing positive integers that add 352 | up to num_items 353 | """ 354 | mask_indices = np.arange(num_items - 1) < (num_segments - 1) 355 | np.random.shuffle(mask_indices) 356 | first_in_segment = np.pad(mask_indices, [[1, 0]]) 357 | segment_id = np.cumsum(first_in_segment) 358 | # count length of sub segments assuming that list is sorted 359 | _, segment_length = np.unique(segment_id, return_counts=True) 360 | return segment_length 361 | 362 | noise_span_lengths = _random_segmentation(num_noise_tokens, span_num) 363 | 364 | exclude = set([0]) 365 | num_time_tokens = sum(time_tokens_mask) 366 | time_tokens_indices = np.where(time_tokens_mask)[0] 367 | time_tokens_indices_set = set(time_tokens_indices) 368 | for _ in exclude: 369 | if _ in time_tokens_indices_set: 370 | time_tokens_indices_set.remove(_) 371 | time_mask_indices = [] 372 | already_mask_final = False 373 | for i in range(span_num): 374 | curr_noise_span_length = noise_span_lengths[i] 375 | while True: 376 | if not already_mask_final: 377 | # rand_num = max_index - curr_noise_span_length - 1 378 | rand_num = max_index - curr_noise_span_length 379 | already_mask_final = True 380 | elif len(time_tokens_indices_set) == 0 \ 381 | or len(time_mask_indices) / input_length >= max_time_mask_ratio \ 382 | or len(time_mask_indices) / num_time_tokens >= self.time_mask_prob: 383 | rand_num = np.random.randint(low=0, high=max_index) #Getting random number for mask index 384 | else: 385 | rand_num = np.random.choice(time_tokens_indices) 386 | if rand_num not in time_tokens_indices_set: 387 | continue 388 | time_tokens_indices_set.remove(rand_num) 389 | 390 | if all([rand_num + k not in exclude for k in range(curr_noise_span_length)]) and rand_num + curr_noise_span_length < max_index: 391 | span = [rand_num + k for k in range(curr_noise_span_length)] 392 | need_exclude = [] 393 | if rand_num + curr_noise_span_length < max_index: 394 | need_exclude.append(rand_num + curr_noise_span_length) 395 | if rand_num - 1 >= 0: 396 | need_exclude.append(rand_num - 1) 397 | for s in span: 398 | mask[s] = 1 399 | if s in time_tokens_indices: 400 | time_mask_indices.append(s) 401 | need_exclude.append(s) 402 | for _ in need_exclude: 403 | exclude.add(_) #Adding to exclude list 404 | if _ in time_tokens_indices_set: 405 | time_tokens_indices_set.remove(_) 406 | break 407 | # mask[-1] = 1 # not meaningful, just to reach label_length=114 408 | return np.array(mask, dtype=bool) -------------------------------------------------------------------------------- /src/pretrain_t5/generate_date_text.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of Date Formats 3 | """ 4 | 5 | import random 6 | import datetime 7 | 8 | MONTH_DAYS = { 9 | 1: 31, 10 | 2: 28, 11 | 3: 31, 12 | 4: 30, 13 | 5: 31, 14 | 6: 30, 15 | 7: 31, 16 | 8: 31, 17 | 9: 30, 18 | 10: 31, 19 | 11: 30, 20 | 12: 31, 21 | } 22 | 23 | MONTHS = { 24 | 1: ["January", "Jan", "01"], 25 | 2: ["February", "Feb", "02"], 26 | 3: ["March", "Mar", "03"], 27 | 4: ["April", "Apr", "04"], 28 | 5: ["May", "May", "05"], 29 | 6: ["June", "Jun", "06"], 30 | 7: ["July", "Jul", "07"], 31 | 8: ["August", "Aug", "08"], 32 | 9: ["September", "Sep", "09"], 33 | 10: ["October", "Oct", "10"], 34 | 11: ["November", "Nov", "11"], 35 | 12: ["December", "Dec", "12"], 36 | } 37 | 38 | RULES = [ 39 | "MM/DD/YY", 40 | 
"DD/MM/YY", 41 | "YY/MM/DD", 42 | "Month D, Yr", 43 | "M/D/YY", 44 | "D/M/YY", 45 | "YY/M/D", 46 | "bM/bD/YY", 47 | "YY/bM/bD", 48 | "MMDDYY", 49 | "DDMMYY", 50 | 'YYMMDD', 51 | "MonDDYY", 52 | "DDMonYY", 53 | "YYMonDD", 54 | "D Month, Yr", 55 | "Yr, Month D", 56 | "Mon-DD-YYYY", 57 | "DD-Mon-YYYY", 58 | "YYYY-Mon-DD", 59 | "Mon DD, YYYY", 60 | "DD Mon, YYYY", 61 | "YYYY, Mon DD", 62 | ] 63 | # "day/YY", 64 | 65 | YEAR_FORMATS=[ 66 | "YYYY", 67 | "YY", 68 | "Yr", 69 | ] 70 | 71 | MONTH_FORMATS = [ 72 | "MM", 73 | "Month", 74 | "Mon", 75 | "bM", 76 | "M", 77 | ] 78 | 79 | DAY_FORMATS = [ 80 | "DD", 81 | "bD", 82 | "D", 83 | ] 84 | 85 | def date_format_transform(date): 86 | date = str(date) 87 | return f"{date[:4]}-{date[4:6]}-{date[6:]}" 88 | 89 | 90 | def covert_date_to_text(year, month=None, day=None): 91 | assert year is not None 92 | if not month: 93 | month = 0 94 | if not day: 95 | day = 0 96 | year, month, day = int(year), int(month), int(day) 97 | 98 | # assert month >= 1 and month <= 12 and day <= 31 99 | 100 | year_signal, month_signal, day_signal = False, False, False 101 | rule = random.choice(RULES) 102 | 103 | for year_fromat in YEAR_FORMATS: 104 | if year_fromat in rule: 105 | rule = rule.replace(year_fromat, str(year)) 106 | year_signal = True 107 | break 108 | 109 | for day_format in DAY_FORMATS: 110 | if day_format in rule: 111 | if day_format == "DD": 112 | if day == 0 or month == 0: 113 | rule = rule.replace(day_format, "") 114 | else: 115 | day_str = '0'+str(day) if len(str(day)) == 1 else str(day) 116 | rule = rule.replace(day_format, day_str) 117 | elif day_format == "bD": 118 | if day == 0 or month == 0: 119 | rule = rule.replace(day_format, "") 120 | else: 121 | day_str = ' '+str(day) if len(str(day)) == 1 else str(day) 122 | rule = rule.replace(day_format, day_str) 123 | elif day_format == "D": 124 | if day == 0 or month == 0: 125 | rule = rule.replace(day_format, "") 126 | else: 127 | rule = rule.replace(day_format, str(day)) 128 | day_signal = True 129 | break 130 | 131 | 132 | for month_format in MONTH_FORMATS: 133 | if month_format in rule: 134 | if month_format == "MM": 135 | if month == 0: 136 | rule = rule.replace(month_format, "") 137 | else: 138 | rule = rule.replace(month_format, str(MONTHS[month][2])) 139 | elif month_format == "Month": 140 | if month == 0: 141 | rule = rule.replace(month_format, "") 142 | else: 143 | rule = rule.replace(month_format, str(MONTHS[month][0])) 144 | elif month_format == "Mon": 145 | if month == 0: 146 | rule = rule.replace(month_format, "") 147 | else: 148 | rule = rule.replace(month_format, str(MONTHS[month][1])) 149 | elif month_format == "bM": 150 | if month == 0: 151 | rule = rule.replace(month_format, "") 152 | else: 153 | month_str = (" " + str(int(MONTHS[month][2])))[-2:] 154 | rule = rule.replace(month_format, month_str) 155 | elif month_format == "M": 156 | if month == 0: 157 | rule = rule.replace(month_format, "") 158 | else: 159 | rule = rule.replace(month_format, str(int(MONTHS[month][2]))) 160 | month_signal = True 161 | break 162 | 163 | if year_signal and month_signal and day_signal: 164 | return rule 165 | return None 166 | 167 | def generate_random_date_inside_span(span_start, span_end): 168 | # start_date = datetime.date.fromisoformat(date_format_transform(span_start)) 169 | # end_date = datetime.date.fromisoformat(date_format_transform(span_end)) 170 | random_date = datetime.date.fromordinal( 171 | random.randint(span_start, span_end) 172 | ) 173 | if random.random() < 0.5: 174 | input_month = None 175 | 
input_day = None 176 | year, month, day = random_date.isoformat().split("-") 177 | success = False 178 | next_day = random_date.day 179 | while not success: 180 | try: 181 | next_year = random_date.replace(year=int(year)+1, day=next_day).isoformat().split("-")[0] 182 | success = True 183 | except: 184 | next_day -= 1 185 | continue 186 | date_span = (f"{year}0101", f"{next_year}0101") 187 | date_text = covert_date_to_text(year, month=None, day=None) 188 | else: 189 | if random.random() < 0.5: 190 | input_day = None 191 | year, month, day = random_date.isoformat().split("-") 192 | if int(month) < 12: 193 | next_year = year 194 | next_month = random_date.replace(month=int(month)+1, day=min(MONTH_DAYS[int(month)+1], random_date.day)).isoformat().split("-")[1] 195 | elif int(month) == 12: 196 | #next_year = random_date.replace(year=int(year)+1).isoformat().split("-")[0] 197 | success = False 198 | assign_day = random_date.day 199 | while not success: 200 | try: 201 | next_year = random_date.replace(year=int(year)+1, day=assign_day).isoformat().split("-")[0] 202 | success = True 203 | except ValueError: 204 | assign_day -= 1 205 | assign_day = max(1, assign_day) 206 | continue 207 | next_month = "01" 208 | date_span = (f"{year}{month}01", f"{next_year}{next_month}01") 209 | date_text = covert_date_to_text(year, month=month, day=None) 210 | else: 211 | year, month, day = random_date.isoformat().split("-") 212 | span_a = ''.join(random_date.isoformat().split("-")) 213 | span_b = ''.join( 214 | datetime.date.fromordinal(random_date.toordinal()+1).isoformat().split("-") 215 | ) 216 | date_span = (span_a, span_b) 217 | date_text = covert_date_to_text(year, month=month, day=day) 218 | return date_text, date_span 219 | 220 | 221 | 222 | 223 | if __name__ == "__main__": 224 | for i in range(20): 225 | print( 226 | covert_date_to_text( 227 | random.randint(1700, 2022), 228 | month=random.randint(1,12), 229 | day=random.randint(1,31), 230 | ) 231 | ) 232 | -------------------------------------------------------------------------------- /src/pretrain_t5/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_lr_scheduler(optimizer, args, logger): 4 | if args.optim.lr_scheduler == 'cosine': 5 | from torch.optim.lr_scheduler import ( 6 | SequentialLR, 7 | LinearLR, 8 | CosineAnnealingLR, 9 | ) 10 | 11 | scheduler1 = LinearLR( 12 | optimizer, 13 | start_factor=0.5, 14 | end_factor=1, 15 | total_iters=args.optim.warmup_steps, 16 | last_epoch=-1, 17 | ) 18 | 19 | scheduler2 = CosineAnnealingLR( 20 | optimizer, 21 | T_max=args.optim.total_steps - args.optim.warmup_steps, 22 | eta_min=args.optim.final_cosine, 23 | ) 24 | 25 | lr_scheduler = SequentialLR( 26 | optimizer, 27 | schedulers=[scheduler1, scheduler2], 28 | milestones=[args.optim.warmup_steps] 29 | ) 30 | elif args.optim.lr_scheduler == 'legacy': 31 | import math 32 | from torch.optim.lr_scheduler import ( 33 | SequentialLR, 34 | LinearLR, 35 | LambdaLR, 36 | ) 37 | 38 | msg = "You are using T5 legacy LR Schedule, it's independent from the optim.base_lr" 39 | logger.log_message(msg) 40 | 41 | num_steps_optimizer1 = math.ceil(args.optim.total_steps * 0.9) 42 | iters_left_for_optimizer2 = args.optim.total_steps - num_steps_optimizer1 43 | 44 | scheduler1 = LambdaLR( 45 | optimizer, 46 | lambda step: min( 47 | 1e-2, 1.0 / math.sqrt(step) 48 | ) / args.optim.base_lr if step else 1e-2 / args.optim.base_lr 49 | ) 50 | 51 | scheduler2 = LinearLR( 52 | optimizer, 53 | start_factor=( 54 | 
min(1e-2, 1.0 / math.sqrt(num_steps_optimizer1)) / args.optim.base_lr 55 | ), 56 | end_factor=0, 57 | total_iters=iters_left_for_optimizer2, 58 | last_epoch=-1, 59 | ) 60 | 61 | lr_scheduler = SequentialLR( 62 | optimizer, 63 | schedulers=[scheduler1, scheduler2], 64 | milestones=[num_steps_optimizer1] 65 | ) 66 | elif args.optim.lr_scheduler == 'constant': 67 | from transformers import get_scheduler 68 | lr_scheduler = get_scheduler( 69 | name=args.optim.lr_scheduler, 70 | optimizer=optimizer, 71 | ) 72 | else: 73 | raise NotImplementedError 74 | 75 | return -------------------------------------------------------------------------------- /src/pretrain_t5/optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import numpy as np 3 | from transformers import BatchEncoding 4 | from dataclasses import dataclass 5 | from transformers import AutoTokenizer 6 | import torch 7 | import math 8 | from torch.optim import Optimizer 9 | from typing import Iterable, Tuple 10 | from torch import nn 11 | import random 12 | import string 13 | 14 | class AdamWScale(Optimizer): 15 | """ 16 | This AdamW implementation is copied from Huggingface. 17 | We modified it with Adagrad scaling by rms of a weight tensor 18 | Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay 19 | Regularization](https://arxiv.org/abs/1711.05101). 20 | Parameters: 21 | params (`Iterable[nn.parameter.Parameter]`): 22 | Iterable of parameters to optimize or dictionaries defining parameter groups. 23 | lr (`float`, *optional*, defaults to 1e-3): 24 | The learning rate to use. 25 | betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)): 26 | Adam's betas parameters (b1, b2). 27 | eps (`float`, *optional*, defaults to 1e-6): 28 | Adam's epsilon for numerical stability. 29 | weight_decay (`float`, *optional*, defaults to 0): 30 | Decoupled weight decay to apply. 31 | correct_bias (`bool`, *optional*, defaults to `True`): 32 | Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). 33 | no_deprecation_warning (`bool`, *optional*, defaults to `False`): 34 | A flag used to disable the deprecation warning (set to `True` to disable the warning). 35 | """ 36 | 37 | def __init__( 38 | self, 39 | params: Iterable[nn.parameter.Parameter], 40 | lr: float = 1e-3, 41 | betas: Tuple[float, float] = (0.9, 0.999), 42 | eps: float = 1e-6, 43 | weight_decay: float = 0.0, 44 | correct_bias: bool = True, 45 | ): 46 | if lr < 0.0: 47 | raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") 48 | if not 0.0 <= betas[0] < 1.0: 49 | raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)") 50 | if not 0.0 <= betas[1] < 1.0: 51 | raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)") 52 | if not 0.0 <= eps: 53 | raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") 54 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) 55 | super().__init__(params, defaults) 56 | 57 | @staticmethod 58 | def _rms(tensor): 59 | return tensor.norm(2) / (tensor.numel() ** 0.5) 60 | 61 | def step(self, closure=None): 62 | """ 63 | Performs a single optimization step. 64 | Arguments: 65 | closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. 
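        Example (a minimal sketch of the intended call pattern; not tied to any particular model in this repo):
            optimizer = AdamWScale(model.parameters(), lr=1e-3)
            ...
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()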
66 | """ 67 | loss = None 68 | if closure is not None: 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | for p in group["params"]: 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 78 | 79 | state = self.state[p] 80 | beta1, beta2 = group["betas"] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state["step"] = 0 85 | # Exponential moving average of gradient values 86 | state["exp_avg"] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state["exp_avg_sq"] = torch.zeros_like(p.data) 89 | 90 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 91 | 92 | state["step"] += 1 93 | 94 | # Decay the first and second moment running average coefficient 95 | # In-place operations to update the averages at the same time 96 | exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) 97 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 98 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 99 | 100 | step_size = group["lr"] 101 | if group["correct_bias"]: # No bias correction for Bert 102 | bias_correction1 = 1.0 - beta1 ** state["step"] 103 | bias_correction2 = 1.0 - beta2 ** state["step"] 104 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 105 | 106 | # /Adapt Step from Adagrad 107 | step_size = step_size * max(1e-3, self._rms(p.data)) 108 | # /Adapt Step from Adagrad 109 | 110 | p.data.addcdiv_(exp_avg, denom, value=-step_size) 111 | 112 | # Just adding the square of the weights to the loss function is *not* 113 | # the correct way of using L2 regularization/weight decay with Adam, 114 | # since that will interact with the m and v parameters in strange ways. 115 | # 116 | # Instead we want to decay the weights in a manner that doesn't interact 117 | # with the m/v parameters. This is equivalent to adding the square 118 | # of the weights to the loss with plain (non-momentum) SGD. 
119 | # Add weight decay at the end (fixed version) 120 | if group["weight_decay"] > 0.0: 121 | p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"])) 122 | 123 | return loss 124 | 125 | 126 | def get_optimizer(model, args): 127 | no_decay = ["bias", "LayerNorm", "layernorm", "layer_norm", "ln"] 128 | 129 | optimizer_grouped_parameters = [ 130 | { 131 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 132 | "weight_decay": args.optim.weight_decay, 133 | }, 134 | { 135 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 136 | "weight_decay": 0.0, 137 | }, 138 | ] 139 | 140 | if args.optim.name == 'adamw': 141 | from transformers import AdamW 142 | optimizer = AdamW( 143 | optimizer_grouped_parameters, 144 | lr=args.optim.base_lr, 145 | ) 146 | elif args.optim.name == 'adamwscale': 147 | from .copied_utils import AdamWScale 148 | optimizer = AdamWScale( 149 | optimizer_grouped_parameters, 150 | lr=args.optim.base_lr, 151 | ) 152 | elif args.optim.name == 'adafactor': 153 | from transformers import Adafactor 154 | optimizer = Adafactor( 155 | optimizer_grouped_parameters, 156 | lr=args.optim.base_lr, 157 | relative_step=False, 158 | ) 159 | else: 160 | raise NotImplementedError 161 | 162 | return optimizer -------------------------------------------------------------------------------- /src/pretrain_t5/run_pretrain.sh: -------------------------------------------------------------------------------- 1 | 2 | export NSP_MODE=mlm_trelation # mode=RemeMo 3 | # export NSP_MODE=mlm # mode=LM 4 | 5 | ###################### Run ID ########################### 6 | export RUN_ID=debug_pretrain 7 | 8 | ###################### Corpora File Paths ########################### 9 | export DATA_DIR=data/pretrain 10 | export WIKI_PATH=${DATA_DIR}/enwiki-20221101.json 11 | export BOOKCORPUS_PATH=${DATA_DIR}/bookcorpus.json 12 | export WIKI_TEMPORAL_CONTINUOUS_PATH=${DATA_DIR}/enwiki-20221101_temporal-sentences_special-token-prefix.json 13 | 14 | ###################### Pretrain Dataset Configuration ###################### 15 | 16 | export TRAIN_FILES=${WIKI_PATH},${BOOKCORPUS_PATH},${WIKI_TEMPORAL_CONTINUOUS_PATH} 17 | export DATA_USED=wiki_books_twikiCon 18 | 19 | ###################### GPU & Batch-size Configuration ###################### 20 | ######## T5 ####### 21 | export BATCH_SIZE=2048 22 | export ADAM_EPSILON=1e-6 23 | 24 | export OPTIMIZER=adafactor 25 | # export OPTIMIZER=adamw_torch 26 | 27 | # export LR_SCHEDULER=linear 28 | export LR_SCHEDULER=cosine 29 | 30 | # export LEARNING_RATE=1e-3 # t5-small 31 | # export LEARNING_RATE=5e-4 # t5-base 32 | export LEARNING_RATE=3e-4 # t5-large 33 | 34 | # export BATCH_SIZE_PER_GPU=90 # t5_v1.1-small model, on A6000-48G, with --bf16, with MAX_NUM_CLS=10 35 | # export BATCH_SIZE_PER_GPU=128 # t5_small model, on A6000-48G, with --bf16, with MAX_NUM_CLS=10 36 | export BATCH_SIZE_PER_GPU=32 # t5_small model, debug 37 | # export BATCH_SIZE_PER_GPU=16 # t5_large model, on A6000-48G, with --bf16, with MAX_NUM_CLS=10 38 | 39 | export GPU_IDX="0" 40 | export NUM_OF_GPUS=2 41 | # export NUM_OF_GPUS=$(nvidia-smi --list-gpus | wc -l) 42 | export CUDA_VISIBLE_DEVICES=${GPU_IDX} 43 | 44 | export GRADIENT_ACCUMULATION_STEPS=$(($(( BATCH_SIZE / BATCH_SIZE_PER_GPU))/NUM_OF_GPUS)) 45 | 46 | ###################### Which Pretrained-LM ###################### 47 | # export PLM_DIR="../pretrained_models/" 48 | export PLM_DIR="google" 49 | export MODEL_NAME_OR_PATH="t5-v1_1-base" 50 | 51 | 
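# Sanity check on the batch arithmetic above (illustrative): with BATCH_SIZE=2048, BATCH_SIZE_PER_GPU=32
# and NUM_OF_GPUS=2, GRADIENT_ACCUMULATION_STEPS = (2048 / 32) / 2 = 32, so each optimizer step still
# covers 32 (per GPU) * 2 (GPUs) * 32 (accumulation steps) = 2048 sequences.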
###################### Logging Configuration ###################### 52 | export WANDB_PROJECT=RemeMo-pretrain 53 | 54 | # {'wandb', 'none'} 55 | # export REPORT_TO=none 56 | export REPORT_TO=wandb 57 | 58 | ###################### Start Training ###################### 59 | export RUN_NAME=${MODEL_NAME_OR_PATH}.${DATA_USED}.${NSP_MODE}.${OPTIMIZER}.${LR_SCHEDULER}.run-${RUN_ID} 60 | 61 | # python -m torch.distributed.launch --nproc_per_node ${NUM_OF_GPUS} run_seq2seq.py \ 62 | python run_seq2seq.py \ 63 | --train_files ${TRAIN_FILES} \ 64 | --model_name_or_path ${PLM_DIR}${MODEL_NAME_OR_PATH} \ 65 | --max_source_length 512 \ 66 | --per_device_train_batch_size ${BATCH_SIZE_PER_GPU} \ 67 | --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \ 68 | --max_steps 8000 \ 69 | --learning_rate ${LEARNING_RATE} \ 70 | --warmup_ratio 0.1 \ 71 | --weight_decay 0.01 \ 72 | --adam_beta1 0.9 \ 73 | --adam_beta2 0.98 \ 74 | --adam_epsilon ${ADAM_EPSILON} \ 75 | --report_to ${REPORT_TO} \ 76 | --run_name ${RUN_NAME} \ 77 | --nsp_mode ${NSP_MODE} \ 78 | --do_train \ 79 | --remove_unused_columns false \ 80 | --logging_steps 2 \ 81 | --logging_first_step true \ 82 | --output_dir log/${RUN_NAME} \ 83 | --save_strategy steps \ 84 | --save_steps 500 \ 85 | --max_grad_norm 1.0 \ 86 | --optim ${OPTIMIZER} \ 87 | --lr_scheduler_type ${LR_SCHEDULER} \ 88 | 2>&1 | tee log/${RUN_NAME}.txt 89 | 90 | # --bf16 -------------------------------------------------------------------------------- /src/pretrain_t5/sampler.py: -------------------------------------------------------------------------------- 1 | # modified based on https://github.com/catalyst-team/catalyst/raw/ea3fadbaa6034dabeefbbb53ab8c310186f6e5d0/catalyst/data/sampler.py 2 | 3 | from typing import Iterator, List, Optional, Union 4 | from collections import Counter 5 | import logging 6 | from operator import itemgetter 7 | from random import choices, sample 8 | 9 | import numpy as np 10 | 11 | import torch 12 | from torch.utils.data import DistributedSampler, Dataset 13 | from torch.utils.data.sampler import BatchSampler, Sampler 14 | 15 | # from catalyst.data.dataset.torch import DatasetFromSampler 16 | 17 | class DatasetFromSampler(Dataset): 18 | """Dataset to create indexes from `Sampler`. 19 | Args: 20 | sampler: PyTorch sampler 21 | """ 22 | 23 | def __init__(self, sampler: Sampler): 24 | """Initialisation for DatasetFromSampler.""" 25 | self.sampler = sampler 26 | self.sampler_list = None 27 | 28 | def __getitem__(self, index: int): 29 | """Gets element of the dataset. 30 | Args: 31 | index: index of the element in the dataset 32 | Returns: 33 | Single element by index 34 | """ 35 | if self.sampler_list is None: 36 | self.sampler_list = list(self.sampler) 37 | return self.sampler_list[index] 38 | 39 | def __len__(self) -> int: 40 | """ 41 | Returns: 42 | int: length of the dataset 43 | """ 44 | return len(self.sampler) 45 | 46 | class BalanceClassSampler(Sampler): 47 | """Allows you to create stratified sample on unbalanced classes. 48 | 49 | Args: 50 | labels: list of class label for each elem in the dataset 51 | mode: Strategy to balance classes. 
52 | Must be one of [downsampling, upsampling] 53 | """ 54 | 55 | def __init__( 56 | self, labels: List[int], mode: Union[str, int] = "downsampling" 57 | ): 58 | """Sampler initialisation.""" 59 | super().__init__(labels) 60 | 61 | labels = np.array(labels) 62 | samples_per_class = { 63 | label: (labels == label).sum() for label in set(labels) 64 | } 65 | 66 | self.lbl2idx = { 67 | label: np.arange(len(labels))[labels == label].tolist() 68 | for label in set(labels) 69 | } 70 | 71 | if isinstance(mode, str): 72 | assert mode in ["downsampling", "upsampling"] 73 | 74 | if isinstance(mode, int) or mode == "upsampling": 75 | samples_per_class = ( 76 | mode 77 | if isinstance(mode, int) 78 | else max(samples_per_class.values()) 79 | ) 80 | else: 81 | samples_per_class = min(samples_per_class.values()) 82 | 83 | self.labels = labels 84 | self.samples_per_class = samples_per_class 85 | self.length = self.samples_per_class * len(set(labels)) 86 | 87 | def __iter__(self) -> Iterator[int]: 88 | """ 89 | Yields: 90 | indices of stratified sample 91 | """ 92 | indices = [] 93 | for key in sorted(self.lbl2idx): 94 | replace_flag = self.samples_per_class > len(self.lbl2idx[key]) 95 | indices += np.random.choice( 96 | self.lbl2idx[key], self.samples_per_class, replace=replace_flag 97 | ).tolist() 98 | assert len(indices) == self.length 99 | np.random.shuffle(indices) 100 | 101 | return iter(indices) 102 | 103 | def __len__(self) -> int: 104 | """ 105 | Returns: 106 | length of result sample 107 | """ 108 | return self.length 109 | 110 | class DistributedSamplerWrapper(DistributedSampler): 111 | """ 112 | Wrapper over `Sampler` for distributed training. 113 | Allows you to use any sampler in distributed mode. 114 | 115 | It is especially useful in conjunction with 116 | `torch.nn.parallel.DistributedDataParallel`. In such case, each 117 | process can pass a DistributedSamplerWrapper instance as a DataLoader 118 | sampler, and load a subset of subsampled data of the original dataset 119 | that is exclusive to it. 120 | 121 | .. note:: 122 | Sampler is assumed to be of constant size. 123 | """ 124 | 125 | def __init__( 126 | self, 127 | sampler, 128 | num_replicas: Optional[int] = None, 129 | rank: Optional[int] = None, 130 | shuffle: bool = True, 131 | ): 132 | """ 133 | 134 | Args: 135 | sampler: Sampler used for subsampling 136 | num_replicas (int, optional): Number of processes participating in 137 | distributed training 138 | rank (int, optional): Rank of the current process 139 | within ``num_replicas`` 140 | shuffle (bool, optional): If true (default), 141 | sampler will shuffle the indices 142 | """ 143 | super(DistributedSamplerWrapper, self).__init__( 144 | DatasetFromSampler(sampler), 145 | num_replicas=num_replicas, 146 | rank=rank, 147 | shuffle=shuffle, 148 | ) 149 | self.sampler = sampler 150 | 151 | def __iter__(self): 152 | """@TODO: Docs. 
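        (Sketch of what the implementation below does: the wrapped sampler is materialised into a
        DatasetFromSampler, the parent DistributedSampler selects this rank's shard of positions over it,
        and those positions are mapped back to the wrapped sampler's original indices via itemgetter.)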
Contribution is welcome.""" 153 | self.dataset = DatasetFromSampler(self.sampler) 154 | indexes_of_indexes = super().__iter__() 155 | subsampler_indexes = self.dataset 156 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) -------------------------------------------------------------------------------- /src/pretrain_t5/trainer_temporal.py: -------------------------------------------------------------------------------- 1 | from transformers.trainer import * 2 | from transformers import Seq2SeqTrainer 3 | import torch 4 | from sampler import BalanceClassSampler, DistributedSamplerWrapper 5 | 6 | class Seq2SeqTrainerForTemporalPretraining(Seq2SeqTrainer): 7 | def init_temporal( 8 | self, 9 | losses_keys=None, 10 | downsample_non_temporal_data=False, 11 | temporal_labels=None, 12 | ): 13 | self.losses_keys = losses_keys 14 | 15 | if downsample_non_temporal_data and temporal_labels is None: 16 | raise ValueError("``downsample_non_temporal_data'' requires ``temporal_labels'' to be given. ") 17 | 18 | self.downsample_non_temporal_data = downsample_non_temporal_data 19 | self.temporal_labels = temporal_labels 20 | self.log_global_step = 0 21 | 22 | def compute_loss(self, model, inputs, return_outputs=False): 23 | """ 24 | How the loss is computed by Trainer. By default, all models return the loss in the first element. 25 | 26 | Subclass and override for custom behavior. 27 | """ 28 | if self.label_smoother is not None and "labels" in inputs: 29 | labels = inputs.pop("labels") 30 | else: 31 | labels = None 32 | outputs = model(**inputs) 33 | # Save past state if it exists 34 | # TODO: this needs to be fixed and made cleaner later. 35 | if self.args.past_index >= 0: 36 | self._past = outputs[self.args.past_index] 37 | 38 | if labels is not None: 39 | if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): 40 | loss = self.label_smoother(outputs, labels, shift_labels=True) 41 | else: 42 | loss = self.label_smoother(outputs, labels) 43 | else: 44 | if isinstance(outputs, dict) and "loss" not in outputs: 45 | raise ValueError( 46 | "The model did not return a loss from the inputs, only the following keys: " 47 | f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." 48 | ) 49 | # We don't use .loss here since the model may return tuples instead of ModelOutput. 
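            # Repo-specific addition relative to the stock Seq2SeqTrainer: every `logging_steps` steps (and at step 1),
            # the block below also logs each loss component listed in `self.losses_keys` (e.g. the denoising loss and
            # the time-relation-prediction loss), averaging non-zero entries only.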
50 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] 51 | 52 | if self.log_global_step != self.state.global_step and \ 53 | (self.state.global_step % self.args.logging_steps == 0 or self.state.global_step == 1): 54 | self.log_global_step = self.state.global_step 55 | # loss_dict = { 56 | # "global_step": self.state.global_step, 57 | # } 58 | loss_dict = dict() 59 | for key in self.losses_keys: 60 | if key in outputs: 61 | if outputs[key] is not None and (outputs[key] != 0).sum() > 0: 62 | if outputs[key].dim() == 0: 63 | loss_dict[key] = round(outputs[key].detach().mean().item(), 4) 64 | else: 65 | loss_dict[key] = round(outputs[key][outputs[key].nonzero()].view(-1).detach().mean().item(), 4) 66 | else: 67 | loss_dict[key] = round(0.0, 4) 68 | self.log(logs=loss_dict) 69 | 70 | return (loss, outputs) if return_outputs else loss 71 | 72 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 73 | if self.train_dataset is None or not has_length(self.train_dataset): 74 | return None 75 | 76 | generator = None 77 | if self.args.world_size <= 1: 78 | generator = torch.Generator() 79 | # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with 80 | # `args.seed`) if data_seed isn't provided. 81 | # Further on in this method, we default to `args.seed` instead. 82 | if self.args.data_seed is None: 83 | seed = int(torch.empty((), dtype=torch.int64).random_().item()) 84 | else: 85 | seed = self.args.data_seed 86 | generator.manual_seed(seed) 87 | 88 | seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed 89 | 90 | # Build the sampler. 91 | if self.args.group_by_length: 92 | raise NotImplementedError("We do not use group_by_length for temporal-modeling. ") 93 | if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): 94 | lengths = ( 95 | self.train_dataset[self.args.length_column_name] 96 | if self.args.length_column_name in self.train_dataset.column_names 97 | else None 98 | ) 99 | else: 100 | lengths = None 101 | model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None 102 | if self.args.world_size <= 1: 103 | return LengthGroupedSampler( 104 | self.args.train_batch_size * self.args.gradient_accumulation_steps, 105 | dataset=self.train_dataset, 106 | lengths=lengths, 107 | model_input_name=model_input_name, 108 | generator=generator, 109 | ) 110 | else: 111 | return DistributedLengthGroupedSampler( 112 | self.args.train_batch_size * self.args.gradient_accumulation_steps, 113 | dataset=self.train_dataset, 114 | num_replicas=self.args.world_size, 115 | rank=self.args.process_index, 116 | lengths=lengths, 117 | model_input_name=model_input_name, 118 | seed=seed, 119 | ) 120 | 121 | else: 122 | if self.args.world_size <= 1: 123 | if self.downsample_non_temporal_data: 124 | return BalanceClassSampler(self.temporal_labels, mode="downsampling") 125 | else: 126 | return RandomSampler(self.train_dataset, generator=generator) 127 | elif ( 128 | self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL] 129 | and not self.args.dataloader_drop_last 130 | ): 131 | raise NotImplementedError("This set of code (i.e., BalanceClassSampler) is not tested on TPU. ") 132 | # Use a loop for TPUs when drop_last is False to have all batches have the same size. 
133 | return DistributedSamplerWithLoop( 134 | self.train_dataset, 135 | batch_size=self.args.per_device_train_batch_size, 136 | num_replicas=self.args.world_size, 137 | rank=self.args.process_index, 138 | seed=seed, 139 | ) 140 | else: 141 | if self.downsample_non_temporal_data: 142 | return DistributedSamplerWrapper(BalanceClassSampler(self.temporal_labels, mode="downsampling")) 143 | else: 144 | return DistributedSampler( 145 | self.train_dataset, 146 | num_replicas=self.args.world_size, 147 | rank=self.args.process_index, 148 | seed=seed, 149 | shuffle=True, # add shuffle 150 | ) 151 | 152 | def create_optimizer(self): 153 | """ 154 | Setup the optimizer. 155 | 156 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 157 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 158 | """ 159 | opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model 160 | 161 | if self.optimizer is None: 162 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 163 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 164 | optimizer_grouped_parameters = [ 165 | { 166 | "params": [p for n, p in opt_model.named_parameters() if n in decay_parameters], 167 | "weight_decay": self.args.weight_decay, 168 | }, 169 | { 170 | "params": [p for n, p in opt_model.named_parameters() if n not in decay_parameters], 171 | "weight_decay": 0.0, 172 | }, 173 | ] 174 | 175 | optimizer_cls, optimizer_kwargs = self.new_get_optimizer_cls_and_kwargs(self.args) 176 | 177 | if self.sharded_ddp == ShardedDDPOption.SIMPLE: 178 | self.optimizer = OSS( 179 | params=optimizer_grouped_parameters, 180 | optim=optimizer_cls, 181 | **optimizer_kwargs, 182 | ) 183 | else: 184 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 185 | if optimizer_cls.__name__ == "Adam8bit": 186 | import bitsandbytes 187 | 188 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 189 | 190 | for module in opt_model.modules(): 191 | if isinstance(module, nn.Embedding): 192 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 193 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 194 | 195 | if is_sagemaker_mp_enabled(): 196 | self.optimizer = smp.DistributedOptimizer(self.optimizer) 197 | 198 | return self.optimizer 199 | 200 | def new_get_optimizer_cls_and_kwargs(self, args: TrainingArguments) -> Tuple[Any, Any]: 201 | """ 202 | Returns the optimizer class and optimizer parameters based on the training arguments. 203 | Args: 204 | args (`transformers.training_args.TrainingArguments`): 205 | The training arguments for the training session. 206 | """ 207 | 208 | # parse args.optim_args 209 | optim_args = {} 210 | if hasattr(args, "optim_args") and args.optim_args: 211 | for mapping in args.optim_args.replace(" ", "").split(","): 212 | key, value = mapping.split("=") 213 | optim_args[key] = value 214 | 215 | optimizer_kwargs = {"lr": args.learning_rate} 216 | 217 | adam_kwargs = { 218 | "betas": (args.adam_beta1, args.adam_beta2), 219 | "eps": args.adam_epsilon, 220 | } 221 | if args.optim == OptimizerNames.ADAFACTOR: 222 | optimizer_cls = Adafactor 223 | # optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) 224 | logger.info("Using adafactor with scale_parameter=True ! ") 225 | print("Using adafactor with scale_parameter=True ! 
") 226 | optimizer_kwargs.update({"scale_parameter": True, "relative_step": False}) 227 | elif args.optim == OptimizerNames.ADAMW_HF: 228 | from transformers.optimization import AdamW 229 | 230 | optimizer_cls = AdamW 231 | optimizer_kwargs.update(adam_kwargs) 232 | elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]: 233 | from torch.optim import AdamW 234 | 235 | optimizer_cls = AdamW 236 | optimizer_kwargs.update(adam_kwargs) 237 | if args.optim == OptimizerNames.ADAMW_TORCH_FUSED: 238 | optimizer_kwargs.update({"fused": True}) 239 | elif args.optim == OptimizerNames.ADAMW_TORCH_XLA: 240 | try: 241 | from torch_xla.amp.syncfree import AdamW 242 | 243 | optimizer_cls = AdamW 244 | optimizer_kwargs.update(adam_kwargs) 245 | except ImportError: 246 | raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") 247 | elif args.optim == OptimizerNames.ADAMW_APEX_FUSED: 248 | try: 249 | from apex.optimizers import FusedAdam 250 | 251 | optimizer_cls = FusedAdam 252 | optimizer_kwargs.update(adam_kwargs) 253 | except ImportError: 254 | raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") 255 | elif args.optim == OptimizerNames.ADAMW_BNB: 256 | try: 257 | from bitsandbytes.optim import Adam8bit 258 | 259 | optimizer_cls = Adam8bit 260 | optimizer_kwargs.update(adam_kwargs) 261 | except ImportError: 262 | raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!") 263 | elif args.optim == OptimizerNames.ADAMW_ANYPRECISION: 264 | try: 265 | from torchdistx.optimizers import AnyPrecisionAdamW 266 | 267 | optimizer_cls = AnyPrecisionAdamW 268 | optimizer_kwargs.update(adam_kwargs) 269 | 270 | # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx. 271 | optimizer_kwargs.update( 272 | { 273 | "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")), 274 | "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")), 275 | "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")), 276 | "compensation_buffer_dtype": getattr( 277 | torch, optim_args.get("compensation_buffer_dtype", "bfloat16") 278 | ), 279 | } 280 | ) 281 | except ImportError: 282 | raise ValueError("Please install https://github.com/pytorch/torchdistx") 283 | elif args.optim == OptimizerNames.SGD: 284 | optimizer_cls = torch.optim.SGD 285 | elif args.optim == OptimizerNames.ADAGRAD: 286 | optimizer_cls = torch.optim.Adagrad 287 | else: 288 | raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") 289 | return optimizer_cls, optimizer_kwargs -------------------------------------------------------------------------------- /src/pretrain_t5/utils.py: -------------------------------------------------------------------------------- 1 | TIME_SPECIAL_TOKENS = { 2 | "bos_token": "[TIME]", 3 | "eos_token": "[/TIME]", 4 | } -------------------------------------------------------------------------------- /src/time_expression/inference_time_identification.sh: -------------------------------------------------------------------------------- 1 | 2 | export ROOT_DIR="../.." 
3 | export CURR_INPUT="" 4 | export CURR_OUTPUT="" 5 | 6 | python token-classification/run_ner.py \ 7 | --model_name_or_path ${ROOT_DIR}/model_checkpoints/time_expression/roberta_for_time_identification \ 8 | --train_file ${ROOT_DIR}/data/time_expression/ner_task/train.json \ 9 | --validation_file ${ROOT_DIR}/data/time_expression/ner_task/val.json \ 10 | --text_column_name tokens \ 11 | --label_column_name ner_tags \ 12 | --per_device_eval_batch_size 128 \ 13 | --output_dir ${ROOT_DIR}/model_checkpoints/zzz/ \ 14 | --test_file ${CURR_INPUT} \ 15 | --predict_output ${CURR_OUTPUT} \ 16 | --report_to none \ 17 | --run_name predict-timeNER \ 18 | --max_seq_length 512 \ 19 | --do_predict 20 | -------------------------------------------------------------------------------- /src/time_expression/token-classification/README.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Token classification 18 | 19 | ## PyTorch version 20 | 21 | Fine-tuning the library models for token classification tasks such as Named Entity Recognition (NER), Parts-of-speech 22 | tagging (POS) or phrase extraction (CHUNKS). The main script `run_ner.py` leverages the 🤗 Datasets library and the Trainer API. You can easily 23 | customize it to your needs if you need extra processing on your datasets. 24 | 25 | It will run either on a dataset hosted on our [hub](https://huggingface.co/datasets) or with your own text files for 26 | training and validation; you might just need to add some tweaks in the data preprocessing. 27 | 28 | The following example fine-tunes BERT on CoNLL-2003: 29 | 30 | ```bash 31 | python run_ner.py \ 32 | --model_name_or_path bert-base-uncased \ 33 | --dataset_name conll2003 \ 34 | --output_dir /tmp/test-ner \ 35 | --do_train \ 36 | --do_eval 37 | ``` 38 | 39 | or you can just run the bash script `run.sh`. 40 | 41 | To run on your own training and validation files, use the following command: 42 | 43 | ```bash 44 | python run_ner.py \ 45 | --model_name_or_path bert-base-uncased \ 46 | --train_file path_to_train_file \ 47 | --validation_file path_to_validation_file \ 48 | --output_dir /tmp/test-ner \ 49 | --do_train \ 50 | --do_eval 51 | ``` 52 | 53 | **Note:** This script only works with models that have a fast tokenizer (backed by the 🤗 Tokenizers library) as it 54 | uses special features of those tokenizers. You can check if your favorite model has a fast tokenizer in 55 | [this table](https://huggingface.co/transformers/index.html#supported-frameworks); if it doesn't, you can still use the old version 56 | of the script. 57 | 58 | > If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it. 59 | 60 | ## Old version of the script 61 | 62 | You can find the old version of the PyTorch script [here](https://github.com/huggingface/transformers/blob/main/examples/legacy/token-classification/run_ner.py). 63 | 64 | ## PyTorch version, no Trainer 65 | 66 | Based on the script [run_ner_no_trainer.py](https://github.com/huggingface/transformers/blob/main/examples/pytorch/token-classification/run_ner_no_trainer.py). 67 | 68 | Like `run_ner.py`, this script allows you to fine-tune any of the models on the [hub](https://huggingface.co/models) on a 69 | token classification task (NER, POS or CHUNKS), either on a standard dataset or on your own data in a csv or a JSON file.
The main difference is that this 70 | script exposes the bare training loop, to allow you to quickly experiment and add any customization you would like. 71 | 72 | It offers fewer options than the script with `Trainer` (for instance, you can easily change the options for the optimizer 73 | or the dataloaders directly in the script), but it still runs in a distributed setup or on TPUs, and it supports mixed precision by 74 | means of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) library. You can use the script normally 75 | after installing it: 76 | 77 | ```bash 78 | pip install git+https://github.com/huggingface/accelerate 79 | ``` 80 | 81 | then 82 | 83 | ```bash 84 | export TASK_NAME=ner 85 | 86 | python run_ner_no_trainer.py \ 87 | --model_name_or_path bert-base-cased \ 88 | --dataset_name conll2003 \ 89 | --task_name $TASK_NAME \ 90 | --max_length 128 \ 91 | --per_device_train_batch_size 32 \ 92 | --learning_rate 2e-5 \ 93 | --num_train_epochs 3 \ 94 | --output_dir /tmp/$TASK_NAME/ 95 | ``` 96 | 97 | You can then use your usual launchers to run it in a distributed environment, but the easiest way is to run 98 | 99 | ```bash 100 | accelerate config 101 | ``` 102 | 103 | and reply to the questions asked. Then 104 | 105 | ```bash 106 | accelerate test 107 | ``` 108 | 109 | which will check that everything is ready for training. Finally, you can launch training with 110 | 111 | ```bash 112 | export TASK_NAME=ner 113 | 114 | accelerate launch run_ner_no_trainer.py \ 115 | --model_name_or_path bert-base-cased \ 116 | --dataset_name conll2003 \ 117 | --task_name $TASK_NAME \ 118 | --max_length 128 \ 119 | --per_device_train_batch_size 32 \ 120 | --learning_rate 2e-5 \ 121 | --num_train_epochs 3 \ 122 | --output_dir /tmp/$TASK_NAME/ 123 | ``` 124 | 125 | This command is the same and will work for: 126 | 127 | - a CPU-only setup 128 | - a setup with one GPU 129 | - distributed training with several GPUs (single or multi node) 130 | - training on TPUs 131 | 132 | Note that this library is in alpha release so your feedback is more than welcome if you encounter any problems using it. 133 | -------------------------------------------------------------------------------- /src/time_expression/token-classification/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | seqeval 3 | datasets >= 1.8.0 4 | torch >= 1.3 5 | evaluate -------------------------------------------------------------------------------- /src/time_expression/token-classification/run.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | python3 run_ner.py \ 16 | --model_name_or_path bert-base-uncased \ 17 | --dataset_name conll2003 \ 18 | --output_dir /tmp/test-ner \ 19 | --do_train \ 20 | --do_eval 21 | -------------------------------------------------------------------------------- /src/time_expression/token-classification/run_ner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Team All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for token classification. 18 | """ 19 | # You can also adapt this script on your own token classification task and datasets. Pointers for this are left as 20 | # comments. 21 | 22 | import logging 23 | import os 24 | import sys 25 | from dataclasses import dataclass, field 26 | from typing import Optional 27 | 28 | import datasets 29 | import numpy as np 30 | from datasets import ClassLabel, load_dataset 31 | 32 | import evaluate 33 | import transformers 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModelForTokenClassification, 37 | AutoTokenizer, 38 | DataCollatorForTokenClassification, 39 | HfArgumentParser, 40 | PretrainedConfig, 41 | PreTrainedTokenizerFast, 42 | Trainer, 43 | TrainingArguments, 44 | set_seed, 45 | ) 46 | from transformers.trainer_utils import get_last_checkpoint 47 | from transformers.utils import check_min_version, send_example_telemetry 48 | from transformers.utils.versions import require_version 49 | 50 | # import wandb 51 | 52 | 53 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 54 | # check_min_version("4.25.0.dev0") 55 | 56 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") 57 | 58 | logger = logging.getLogger(__name__) 59 | 60 | 61 | @dataclass 62 | class ModelArguments: 63 | """ 64 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
65 | """ 66 | 67 | model_name_or_path: str = field( 68 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 69 | ) 70 | config_name: Optional[str] = field( 71 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 72 | ) 73 | tokenizer_name: Optional[str] = field( 74 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 75 | ) 76 | cache_dir: Optional[str] = field( 77 | default=None, 78 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 79 | ) 80 | model_revision: str = field( 81 | default="main", 82 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 83 | ) 84 | use_auth_token: bool = field( 85 | default=False, 86 | metadata={ 87 | "help": ( 88 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 89 | "with private models)." 90 | ) 91 | }, 92 | ) 93 | ignore_mismatched_sizes: bool = field( 94 | default=False, 95 | metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, 96 | ) 97 | 98 | 99 | @dataclass 100 | class DataTrainingArguments: 101 | """ 102 | Arguments pertaining to what data we are going to input our model for training and eval. 103 | """ 104 | 105 | task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) 106 | dataset_name: Optional[str] = field( 107 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 108 | ) 109 | dataset_config_name: Optional[str] = field( 110 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 111 | ) 112 | train_file: Optional[str] = field( 113 | default=None, metadata={"help": "The input training data file (a csv or JSON file)."} 114 | ) 115 | validation_file: Optional[str] = field( 116 | default=None, 117 | metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."}, 118 | ) 119 | test_file: Optional[str] = field( 120 | default=None, 121 | metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, 122 | ) 123 | text_column_name: Optional[str] = field( 124 | default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."} 125 | ) 126 | label_column_name: Optional[str] = field( 127 | default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."} 128 | ) 129 | overwrite_cache: bool = field( 130 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 131 | ) 132 | preprocessing_num_workers: Optional[int] = field( 133 | default=None, 134 | metadata={"help": "The number of processes to use for the preprocessing."}, 135 | ) 136 | max_seq_length: int = field( 137 | default=None, 138 | metadata={ 139 | "help": ( 140 | "The maximum total input sequence length after tokenization. If set, sequences longer " 141 | "than this will be truncated, sequences shorter will be padded." 142 | ) 143 | }, 144 | ) 145 | pad_to_max_length: bool = field( 146 | default=False, 147 | metadata={ 148 | "help": ( 149 | "Whether to pad all samples to model maximum sentence length. " 150 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 151 | "efficient on GPU but very bad for TPU." 
152 | ) 153 | }, 154 | ) 155 | max_train_samples: Optional[int] = field( 156 | default=None, 157 | metadata={ 158 | "help": ( 159 | "For debugging purposes or quicker training, truncate the number of training examples to this " 160 | "value if set." 161 | ) 162 | }, 163 | ) 164 | max_eval_samples: Optional[int] = field( 165 | default=None, 166 | metadata={ 167 | "help": ( 168 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 169 | "value if set." 170 | ) 171 | }, 172 | ) 173 | max_predict_samples: Optional[int] = field( 174 | default=None, 175 | metadata={ 176 | "help": ( 177 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 178 | "value if set." 179 | ) 180 | }, 181 | ) 182 | label_all_tokens: bool = field( 183 | default=False, 184 | metadata={ 185 | "help": ( 186 | "Whether to put the label for one word on all tokens of generated by that word or just on the " 187 | "one (in which case the other tokens will have a padding index)." 188 | ) 189 | }, 190 | ) 191 | return_entity_level_metrics: bool = field( 192 | default=False, 193 | metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."}, 194 | ) 195 | ### Added by Sen Yang 196 | predict_output: Optional[str] = field( 197 | default="predictions.txt", metadata={"help": "The predict output file path."} 198 | ) 199 | 200 | def __post_init__(self): 201 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 202 | raise ValueError("Need either a dataset name or a training/validation file.") 203 | else: 204 | if self.train_file is not None: 205 | extension = self.train_file.split(".")[-1] 206 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 207 | if self.validation_file is not None: 208 | extension = self.validation_file.split(".")[-1] 209 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 210 | self.task_name = self.task_name.lower() 211 | 212 | 213 | def main(): 214 | # See all possible arguments in src/transformers/training_args.py 215 | # or by passing the --help flag to this script. 216 | # We now keep distinct sets of args, for a cleaner separation of concerns. 217 | 218 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 219 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 220 | # If we pass only one argument to the script and it's the path to a json file, 221 | # let's parse it to get our arguments. 222 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 223 | else: 224 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 225 | 226 | # config = dict ( 227 | # learning_rate = 0.01, 228 | # momentum = 0.2, 229 | # architecture = "CNN", 230 | # dataset_id = "peds-0192", 231 | # infra = "AWS", 232 | # ) 233 | 234 | # wandb.init( 235 | # project="pretrain_temporal", 236 | # notes="predict time_ner on seis17&18", 237 | # tags=["time_ner", "predict", "processing data"], 238 | # config=(model_args, data_args, training_args), 239 | # ) 240 | 241 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 242 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 
243 | send_example_telemetry("run_ner", model_args, data_args) 244 | 245 | # Setup logging 246 | logging.basicConfig( 247 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 248 | datefmt="%m/%d/%Y %H:%M:%S", 249 | handlers=[logging.StreamHandler(sys.stdout)], 250 | ) 251 | 252 | log_level = training_args.get_process_log_level() 253 | logger.setLevel(log_level) 254 | datasets.utils.logging.set_verbosity(log_level) 255 | transformers.utils.logging.set_verbosity(log_level) 256 | transformers.utils.logging.enable_default_handler() 257 | transformers.utils.logging.enable_explicit_format() 258 | 259 | # Log on each process the small summary: 260 | logger.warning( 261 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 262 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 263 | ) 264 | logger.info(f"Training/evaluation parameters {training_args}") 265 | 266 | # Detecting last checkpoint. 267 | last_checkpoint = None 268 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 269 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 270 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 271 | raise ValueError( 272 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 273 | "Use --overwrite_output_dir to overcome." 274 | ) 275 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 276 | logger.info( 277 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 278 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 279 | ) 280 | 281 | # Set seed before initializing model. 282 | set_seed(training_args.seed) 283 | 284 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 285 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 286 | # (the dataset will be downloaded automatically from the datasets Hub). 287 | # 288 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 289 | # 'text' is found. You can easily tweak this behavior (see below). 290 | # 291 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 292 | # download the dataset. 293 | if data_args.dataset_name is not None: 294 | # Downloading and loading a dataset from the hub. 295 | raw_datasets = load_dataset( 296 | data_args.dataset_name, 297 | data_args.dataset_config_name, 298 | cache_dir=model_args.cache_dir, 299 | use_auth_token=True if model_args.use_auth_token else None, 300 | ) 301 | else: 302 | data_files = {} 303 | if data_args.train_file is not None: 304 | data_files["train"] = data_args.train_file 305 | if data_args.validation_file is not None: 306 | data_files["validation"] = data_args.validation_file 307 | if data_args.test_file is not None: 308 | data_files["test"] = data_args.test_file 309 | extension = data_args.train_file.split(".")[-1] 310 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 311 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 312 | # https://huggingface.co/docs/datasets/loading_datasets.html. 
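# Note: in this repo the time-expression extractor is trained via src/time_expression/train_time_identification_model.sh,
# which points --train_file/--validation_file at data/time_expression/ner_task/{train,val}.json and passes
# --text_column_name tokens / --label_column_name ner_tags, so the column detection below ends up using the
# "tokens" and "ner_tags" fields of those JSON files.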
313 | 314 | if training_args.do_train: 315 | column_names = raw_datasets["train"].column_names 316 | features = raw_datasets["train"].features 317 | else: 318 | column_names = raw_datasets["validation"].column_names 319 | features = raw_datasets["validation"].features 320 | 321 | if data_args.text_column_name is not None: 322 | text_column_name = data_args.text_column_name 323 | elif "tokens" in column_names: 324 | text_column_name = "tokens" 325 | else: 326 | text_column_name = column_names[0] 327 | 328 | if data_args.label_column_name is not None: 329 | label_column_name = data_args.label_column_name 330 | elif f"{data_args.task_name}_tags" in column_names: 331 | label_column_name = f"{data_args.task_name}_tags" 332 | else: 333 | label_column_name = column_names[1] 334 | 335 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the 336 | # unique labels. 337 | def get_label_list(labels): 338 | unique_labels = set() 339 | for label in labels: 340 | unique_labels = unique_labels | set(label) 341 | label_list = list(unique_labels) 342 | label_list.sort() 343 | return label_list 344 | 345 | # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere. 346 | # Otherwise, we have to get the list of labels manually. 347 | labels_are_int = isinstance(features[label_column_name].feature, ClassLabel) 348 | if labels_are_int: 349 | label_list = features[label_column_name].feature.names 350 | label_to_id = {i: i for i in range(len(label_list))} 351 | else: 352 | label_list = get_label_list(raw_datasets["train"][label_column_name]) 353 | label_to_id = {l: i for i, l in enumerate(label_list)} 354 | 355 | num_labels = len(label_list) 356 | 357 | # Load pretrained model and tokenizer 358 | # 359 | # Distributed training: 360 | # The .from_pretrained methods guarantee that only one local process can concurrently 361 | # download model & vocab. 362 | config = AutoConfig.from_pretrained( 363 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 364 | num_labels=num_labels, 365 | finetuning_task=data_args.task_name, 366 | cache_dir=model_args.cache_dir, 367 | revision=model_args.model_revision, 368 | use_auth_token=True if model_args.use_auth_token else None, 369 | ) 370 | 371 | tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path 372 | if config.model_type in {"bloom", "gpt2", "roberta"}: 373 | tokenizer = AutoTokenizer.from_pretrained( 374 | tokenizer_name_or_path, 375 | cache_dir=model_args.cache_dir, 376 | use_fast=True, 377 | revision=model_args.model_revision, 378 | use_auth_token=True if model_args.use_auth_token else None, 379 | add_prefix_space=True, 380 | ) 381 | else: 382 | tokenizer = AutoTokenizer.from_pretrained( 383 | tokenizer_name_or_path, 384 | cache_dir=model_args.cache_dir, 385 | use_fast=True, 386 | revision=model_args.model_revision, 387 | use_auth_token=True if model_args.use_auth_token else None, 388 | ) 389 | 390 | model = AutoModelForTokenClassification.from_pretrained( 391 | model_args.model_name_or_path, 392 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 393 | config=config, 394 | cache_dir=model_args.cache_dir, 395 | revision=model_args.model_revision, 396 | use_auth_token=True if model_args.use_auth_token else None, 397 | ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, 398 | ) 399 | 400 | # Tokenizer check: this script requires a fast tokenizer. 
401 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 402 | raise ValueError( 403 | "This example script only works for models that have a fast tokenizer. Check out the big table of models at" 404 | " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet" 405 | " this requirement" 406 | ) 407 | 408 | # Model has labels -> use them. 409 | if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id: 410 | if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)): 411 | # Reorganize `label_list` to match the ordering of the model. 412 | if labels_are_int: 413 | label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)} 414 | label_list = [model.config.id2label[i] for i in range(num_labels)] 415 | else: 416 | label_list = [model.config.id2label[i] for i in range(num_labels)] 417 | label_to_id = {l: i for i, l in enumerate(label_list)} 418 | else: 419 | logger.warning( 420 | "Your model seems to have been trained with labels, but they don't match the dataset: " 421 | f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:" 422 | f" {list(sorted(label_list))}.\nIgnoring the model labels as a result." 423 | ) 424 | 425 | # Set the correspondences label/ID inside the model config 426 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 427 | model.config.id2label = {i: l for i, l in enumerate(label_list)} 428 | 429 | # Map that sends B-Xxx label to its I-Xxx counterpart 430 | b_to_i_label = [] 431 | for idx, label in enumerate(label_list): 432 | if label.startswith("B-") and label.replace("B-", "I-") in label_list: 433 | b_to_i_label.append(label_list.index(label.replace("B-", "I-"))) 434 | else: 435 | b_to_i_label.append(idx) 436 | 437 | # Preprocessing the dataset 438 | # Padding strategy 439 | padding = "max_length" if data_args.pad_to_max_length else False 440 | 441 | # Tokenize all texts and align the labels with them. 442 | def tokenize_and_align_labels(examples): 443 | tokenized_inputs = tokenizer( 444 | examples[text_column_name], 445 | padding=padding, 446 | truncation=True, 447 | max_length=data_args.max_seq_length, 448 | # We use this argument because the texts in our dataset are lists of words (with a label for each word). 449 | is_split_into_words=True, 450 | ) 451 | labels = [] 452 | for i, label in enumerate(examples[label_column_name]): 453 | word_ids = tokenized_inputs.word_ids(batch_index=i) 454 | previous_word_idx = None 455 | label_ids = [] 456 | for word_idx in word_ids: 457 | # Special tokens have a word id that is None. We set the label to -100 so they are automatically 458 | # ignored in the loss function. 459 | if word_idx is None: 460 | label_ids.append(-100) 461 | # We set the label for the first token of each word. 462 | elif word_idx != previous_word_idx: 463 | label_ids.append(label_to_id[label[word_idx]]) 464 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 465 | # the label_all_tokens flag.
466 | else: 467 | if data_args.label_all_tokens: 468 | label_ids.append(b_to_i_label[label_to_id[label[word_idx]]]) 469 | else: 470 | label_ids.append(-100) 471 | previous_word_idx = word_idx 472 | 473 | labels.append(label_ids) 474 | tokenized_inputs["labels"] = labels 475 | return tokenized_inputs 476 | 477 | if training_args.do_train: 478 | if "train" not in raw_datasets: 479 | raise ValueError("--do_train requires a train dataset") 480 | train_dataset = raw_datasets["train"] 481 | if data_args.max_train_samples is not None: 482 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 483 | train_dataset = train_dataset.select(range(max_train_samples)) 484 | with training_args.main_process_first(desc="train dataset map pre-processing"): 485 | train_dataset = train_dataset.map( 486 | tokenize_and_align_labels, 487 | batched=True, 488 | num_proc=data_args.preprocessing_num_workers, 489 | load_from_cache_file=not data_args.overwrite_cache, 490 | desc="Running tokenizer on train dataset", 491 | ) 492 | 493 | if training_args.do_eval: 494 | if "validation" not in raw_datasets: 495 | raise ValueError("--do_eval requires a validation dataset") 496 | eval_dataset = raw_datasets["validation"] 497 | if data_args.max_eval_samples is not None: 498 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 499 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 500 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 501 | eval_dataset = eval_dataset.map( 502 | tokenize_and_align_labels, 503 | batched=True, 504 | num_proc=data_args.preprocessing_num_workers, 505 | load_from_cache_file=not data_args.overwrite_cache, 506 | desc="Running tokenizer on validation dataset", 507 | ) 508 | 509 | if training_args.do_predict: 510 | if "test" not in raw_datasets: 511 | raise ValueError("--do_predict requires a test dataset") 512 | predict_dataset = raw_datasets["test"] 513 | if data_args.max_predict_samples is not None: 514 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 515 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 516 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 517 | predict_dataset = predict_dataset.map( 518 | tokenize_and_align_labels, 519 | batched=True, 520 | num_proc=data_args.preprocessing_num_workers, 521 | load_from_cache_file=not data_args.overwrite_cache, 522 | desc="Running tokenizer on prediction dataset", 523 | ) 524 | 525 | # Data collator 526 | data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) 527 | 528 | # Metrics 529 | metric = evaluate.load("seqeval") 530 | 531 | def compute_metrics(p): 532 | predictions, labels = p 533 | predictions = np.argmax(predictions, axis=2) 534 | 535 | # Remove ignored index (special tokens) 536 | true_predictions = [ 537 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 538 | for prediction, label in zip(predictions, labels) 539 | ] 540 | true_labels = [ 541 | [label_list[l] for (p, l) in zip(prediction, label) if l != -100] 542 | for prediction, label in zip(predictions, labels) 543 | ] 544 | 545 | results = metric.compute(predictions=true_predictions, references=true_labels) 546 | if data_args.return_entity_level_metrics: 547 | # Unpack nested dictionaries 548 | final_results = {} 549 | for key, value in results.items(): 550 | if isinstance(value, dict): 551 | for n, v in value.items(): 552 | 
final_results[f"{key}_{n}"] = v 553 | else: 554 | final_results[key] = value 555 | return final_results 556 | else: 557 | return { 558 | "precision": results["overall_precision"], 559 | "recall": results["overall_recall"], 560 | "f1": results["overall_f1"], 561 | "accuracy": results["overall_accuracy"], 562 | } 563 | 564 | # Initialize our Trainer 565 | trainer = Trainer( 566 | model=model, 567 | args=training_args, 568 | train_dataset=train_dataset if training_args.do_train else None, 569 | eval_dataset=eval_dataset if training_args.do_eval else None, 570 | tokenizer=tokenizer, 571 | data_collator=data_collator, 572 | compute_metrics=compute_metrics, 573 | ) 574 | 575 | # Training 576 | if training_args.do_train: 577 | checkpoint = None 578 | if training_args.resume_from_checkpoint is not None: 579 | checkpoint = training_args.resume_from_checkpoint 580 | elif last_checkpoint is not None: 581 | checkpoint = last_checkpoint 582 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 583 | metrics = train_result.metrics 584 | trainer.save_model() # Saves the tokenizer too for easy upload 585 | 586 | max_train_samples = ( 587 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 588 | ) 589 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 590 | 591 | trainer.log_metrics("train", metrics) 592 | trainer.save_metrics("train", metrics) 593 | trainer.save_state() 594 | 595 | # Evaluation 596 | if training_args.do_eval: 597 | logger.info("*** Evaluate ***") 598 | 599 | metrics = trainer.evaluate() 600 | 601 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 602 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 603 | 604 | trainer.log_metrics("eval", metrics) 605 | trainer.save_metrics("eval", metrics) 606 | 607 | # Predict 608 | if training_args.do_predict: 609 | logger.info("*** Predict ***") 610 | 611 | predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict") 612 | predictions = np.argmax(predictions, axis=2) 613 | 614 | # Remove ignored index (special tokens) 615 | true_predictions = [ 616 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 617 | for prediction, label in zip(predictions, labels) 618 | ] 619 | 620 | trainer.log_metrics("predict", metrics) 621 | trainer.save_metrics("predict", metrics) 622 | 623 | # Save predictions 624 | # output_predictions_file = os.path.join(training_args.output_dir, data_args.predict_output) 625 | output_predictions_file = data_args.predict_output 626 | if trainer.is_world_process_zero(): 627 | with open(output_predictions_file, "w") as writer: 628 | for prediction in true_predictions: 629 | writer.write(" ".join(prediction) + "\n") 630 | 631 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "token-classification"} 632 | if data_args.dataset_name is not None: 633 | kwargs["dataset_tags"] = data_args.dataset_name 634 | if data_args.dataset_config_name is not None: 635 | kwargs["dataset_args"] = data_args.dataset_config_name 636 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 637 | else: 638 | kwargs["dataset"] = data_args.dataset_name 639 | 640 | if training_args.push_to_hub: 641 | trainer.push_to_hub(**kwargs) 642 | else: 643 | trainer.create_model_card(**kwargs) 644 | 645 | 646 | def _mp_fn(index): 647 | # For xla_spawn (TPUs) 648 | main() 649 | 650 | 651 | if __name__ == "__main__":
652 | main() 653 | -------------------------------------------------------------------------------- /src/time_expression/token-classification/run_no_trainer.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | accelerate launch run_ner_no_trainer.py \ 16 | --model_name_or_path bert-base-uncased \ 17 | --dataset_name conll2003 \ 18 | --output_dir /tmp/test-ner \ 19 | --pad_to_max_length \ 20 | --task_name ner \ 21 | --return_entity_level_metrics 22 | -------------------------------------------------------------------------------- /src/time_expression/train_time_identification_model.sh: -------------------------------------------------------------------------------- 1 | python token-classification/run_ner.py \ 2 | --model_name_or_path roberta-large \ 3 | --train_file ../../data/time_expression/ner_task/train.json \ 4 | --validation_file ../../data/time_expression/ner_task/val.json \ 5 | --text_column_name tokens \ 6 | --label_column_name ner_tags \ 7 | --output_dir ../../model_checkpoints/my_own_roberta_for_time_identification \ 8 | --do_train \ 9 | --do_eval --------------------------------------------------------------------------------
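**Inference sketch.** The training script above writes its checkpoint to `../../model_checkpoints/my_own_roberta_for_time_identification`. The snippet below is a minimal, illustrative way to query such a checkpoint for time expressions with the `transformers` pipeline API; it is not the repo's own inference script, and the checkpoint path, example sentence, and printed fields are assumptions to adapt to your setup (the label names come from the `ner_tags` used during training).

```python
# Minimal sketch (assumed paths): load the token-classification checkpoint produced by
# train_time_identification_model.sh and extract time-expression spans from raw text.
from transformers import pipeline

time_tagger = pipeline(
    "token-classification",
    model="../../model_checkpoints/my_own_roberta_for_time_identification",  # the --output_dir from the training script; adjust the relative path
    aggregation_strategy="simple",  # merge B-/I- word pieces into whole spans
)

text = "The company was founded in March 1994 and went public two years later."
for span in time_tagger(text):
    # Each span is a dict with 'entity_group', 'word', 'score', 'start', 'end'.
    print(span["entity_group"], repr(span["word"]), round(float(span["score"]), 3))
```

For file-level prediction, `run_ner.py` itself also supports `--do_predict`, writing one space-separated tag sequence per example to the path given by the repo-specific `--predict_output` argument (see `DataTrainingArguments.predict_output` above).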