├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── covid_event ├── README.md ├── arch.png ├── c_data.py ├── eval.py ├── event_level_perf.png ├── extra │ ├── args.json │ ├── train.log │ └── unmatched.txt ├── finetune.py ├── finetune_pt.py ├── inference.py ├── mismatches.png ├── predict.py ├── preds │ └── val_preds.json ├── prepare.py ├── results.png ├── submit.py └── subs │ ├── golden │ ├── can_not_test_sol.jsonl │ ├── cure_sol.jsonl │ ├── death_sol.jsonl │ ├── negative_sol.jsonl │ └── positive_sol.jsonl │ ├── post_process.py │ ├── run-1 │ ├── can_not_test.jsonl │ ├── cure.jsonl │ ├── death.jsonl │ ├── negative.jsonl │ └── positive.jsonl │ ├── run-2 │ ├── can_not_test.jsonl │ ├── cure.jsonl │ ├── death.jsonl │ ├── negative.jsonl │ └── positive.jsonl │ ├── run-3 │ ├── can_not_test.jsonl │ ├── cure.jsonl │ ├── death.jsonl │ ├── negative.jsonl │ └── positive.jsonl │ └── testid2candidates.pkl ├── example_bert.py ├── example_t5.py ├── example_trans_t5.py ├── run.py ├── setup.cfg ├── setup.py ├── ttt ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── args.cpython-37.pyc │ ├── evaluators.cpython-37.pyc │ ├── inputs.cpython-37.pyc │ ├── models.cpython-37.pyc │ ├── t2t_trainer.cpython-37.pyc │ └── utils.cpython-37.pyc ├── args.py ├── evaluators.py ├── inputs.py ├── models.py ├── t2t_trainer.py └── utils.py ├── ttt_demo.png ├── ttt_logo.png ├── ttt_notebook.ipynb └── use_tpu_tutorial ├── README.md ├── cls_metric.py ├── covid_data.py ├── images ├── apply_success_email.png ├── second_apply_email.png └── tpu_workflow.png └── run_train_test.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | data/ 3 | runs/ 4 | tmp/ 5 | exps/ 6 | .idea/ 7 | wandb/ 8 | __pycache__/ 9 | ttt/__pycache__/ 10 | example_covid_t5.py 11 | covid/ 12 | chinese/ 13 | dist/ 14 | pytriplet.egg-info/ 15 | demos/ 16 | externals/ 17 | covid_event/congcongwang/ 18 | scripts/ 19 | ptt/ 20 | pytriplet.tar/ 21 | pytriplet.tar.gz/ 22 | trecis/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 wangcongcongcc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
1 | [README header: ttt logo (ttt_logo.png) and badges]
16 | 17 | ## TTT: Fine-tuning Transformers with TPU or GPU acceleration, written in Tensorflow2.0+ 18 | 19 | **TTT** or (**Triple T**) is short for a package for fine-tuning 🤗 **T**ransformers with **T**PUs, written in **T**ensorflow2.0+. It was motivated by bugs I found tricky to solve when using [the xla library](https://github.com/pytorch/xla) with PyTorch. As a newcomer to the TF world, I am humbled to learn more from the community, and hence it is open-sourced here. 20 | 21 | 22 | #### Update (2020-11-4): 23 | - [Tutorial in progress - a detailed guide to using Google's TPUs](./use_tpu_tutorial). 24 | - [Our work at W-NUT 2020 Shared Task 3 for COVID event extraction on social media](./covid_event). 25 | 26 | ## Demo 27 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wangcongcong123/ttt/blob/master/ttt_notebook.ipynb) 28 | 29 | The following demonstrates an example of fine-tuning T5-small on sst2 ([example_t5.py](example_t5.py)). 30 | 31 | ![](ttt_demo.png) 32 | 35 | ## Features 36 | - Switch between TPUs and GPUs easily. 37 | - Stable training on TPUs. 38 | - Customize datasets or load from [HF's datasets library](https://huggingface.co/nlp/viewer/?dataset=aeslc). 39 | - Use pretrained tensorflow weights from the open-source library - [🤗 transformers](https://github.com/huggingface/transformers). 40 | - Fine-tune BERT-like transformers (DistilBert, ALBERT, Electra, RoBERTa) using the Keras high-level API. 41 | - Fine-tune T5-like transformers using a customized training loop, written in tensorflow2.0. 42 | - Supported tasks include single-sequence classification (both BERT-like models and the T5 model), and translation, QA, or summarization (T5), as long as each example is characterized by: `{"source": "....", "target": "...."}`. 43 | 44 | ## Quickstart 45 | 46 | #### Install 47 | ``` 48 | pip install pytriplet 49 | ``` 50 | 51 | or if you want to get the latest updates: 52 | 53 | ```shell 54 | git clone https://github.com/wangcongcong123/ttt.git 55 | cd ttt 56 | pip install -e . 57 | ``` 58 | 59 | * make sure `transformers>=3.1.0`. If not, install via `pip install transformers -U` 60 | #### Update (2020-09-13): Example generation for the T5 pretraining objective 61 | ```python 62 | from ttt import iid_denoise_text 63 | text="ttt is short for a package for fine-tuning 🤗 Transformers with TPUs, written in Tensorflow2.0" 64 | # here the text is split on spaces into tokens; you can use HuggingFace's T5Tokenizer to tokenize as well. 
65 | original, source, target=iid_denoise_text(text.split(), span_length=3, corrupt_ratio=0.25) 66 | 67 | # original: ['ttt', 'is', 'short', 'for', 'a', 'package', 'for', 'fine-tuning', '🤗', 'Transformers', 'with', 'TPUs,', 'written', 'in', 'Tensorflow2.0'] 68 | # source: ['ttt', '', 'a', 'package', 'for', 'fine-tuning', '🤗', 'Transformers', 'with', '', ''] 69 | # target: ['', 'is', 'short', 'for', '', 'TPUs,', 'written', 'in', 'Tensorflow2.0'] 70 | ``` 71 | 72 | #### Update (2020-10-15): Example of fine-tuning T5 for translation ([example_trans_t5.py](example_trans_t5.py)) 73 | 74 | 75 | **Fine-tuning**: No boilerplate code changes (the same as [example_t5](example_t5.py)) except for the following args: 76 | ```python3 77 | # any one from MODELS_SUPPORT (check:ttt/args.py) 78 | args.model_select = "t5-small" 79 | # the path to the translation dataset, where each line represents an example in jsonl format like: {"source": "...", "target": "..."} 80 | # it will download automatically the first time from: https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz 81 | args.data_path = "data/wmt_en_ro" 82 | # any one from TASKS_SUPPORT (check:ttt/args.py) 83 | args.task = "translation" 84 | args.max_src_length=128 85 | args.max_tgt_length=128 86 | args.source_field_name="source" 87 | args.target_field_name="target" 88 | args.eval_on="bleu" # this refers to sacrebleu as used in the T5 paper 89 | ``` 90 | 91 | ** On a TPUv3-8, the bleu score achieved by t5-base is 27.9 (very close to the 28 reported in [the T5 paper](https://arxiv.org/abs/1910.10683)); the fine-tuning args are [here](https://ucdcs-student.ucd.ie/~cwang/ttt/models/en2ro_t5_base/args.json) and the training log is [here](https://ucdcs-student.ucd.ie/~cwang/ttt/models/en2ro_t5_base/train.log). 92 | 93 | #### Example of fine-tuning BERT for sst2 ([example_bert.py](example_bert.py)) 94 | ```python3 95 | from ttt import * 96 | 97 | if __name__ == '__main__': 98 | args = get_args() 99 | # check what args are available 100 | logger.info(f"args: {json.dumps(args.__dict__, indent=2)}") 101 | ############### customize args 102 | # args.use_gpu = True 103 | args.use_tpu = True 104 | args.do_train = True 105 | args.use_tb = True 106 | # any one from MODELS_SUPPORT (check:ttt/args.py) 107 | args.model_select = "bert-base-uncased" 108 | # select a dataset following jsonl format, where the text field name is "text" and the label field name is "label" 109 | args.data_path = "data/glue/sst2" 110 | # any one from TASKS_SUPPORT (check:ttt/args.py) 111 | args.task = "single-label-cls" 112 | args.log_steps = 400 113 | # any one from LR_SCHEDULER_SUPPORT (check:ttt/args.py) 114 | args.scheduler="warmuplinear" 115 | # set do_eval = False if your data does not contain a validation set. 
In that case, patience and early_stop will be invalid 116 | args.do_eval = True 117 | args.tpu_address = "x.x.x.x" # replace with yours 118 | ############### end customize args 119 | # to have a sanity check for the args 120 | sanity_check(args) 121 | # seed everything, make deterministic 122 | set_seed(args.seed) 123 | tokenizer = get_tokenizer(args) 124 | inputs = get_inputs(tokenizer, args) 125 | model, _ = create_model(args, logger, get_model) 126 | # start training; here we use the Keras high-level API 127 | training_history = model.fit( 128 | inputs["x_train"], 129 | inputs["y_train"], 130 | epochs=args.num_epochs_train, 131 | verbose=2, 132 | batch_size=args.per_device_train_batch_size*args.num_replicas_in_sync, 133 | callbacks=get_callbacks(args, inputs, logger, get_evaluator), 134 | ) 135 | ``` 136 | 137 | So far, the package includes the following support for `args.model_select`, `args.task` and `args.scheduler` ([args.py](ttt/args.py)). 138 | 139 | ```python3 140 | # these have been tested and work fine. more can be added to this list to test 141 | MODELS_SUPPORT = ["distilbert-base-cased","bert-base-uncased", "bert-large-uncased", "google/electra-base-discriminator", 142 | "google/electra-large-discriminator", "albert-base-v2", "roberta-base", 143 | "t5-small","t5-base"] 144 | # if using t5 models, the task has to be a t2t* one 145 | TASKS_SUPPORT = ["single-label-cls", "t2t"] 146 | # in the future, more schedulers will be added, such as warmupconstant, warmupcosine, etc. 147 | LR_SCHEDULER_SUPPORT = ["warmuplinear", "warmupconstant", "constant"] 148 | ``` 149 | 150 | ## Command lines (suited to GCP) 151 | 152 | These have to be run in a Google GCP VM instance since the tpu_address is an internal IP from Google (or change `--use_tpu` to `--use_gpu` if you have enough GPUs). The flag `--tpu_address` should be replaced with yours. Notice: these runs use a set of "look-good" hyper-parameters that were not exhaustively selected. 153 | 154 | #### Experiment BERT on sst2 using TPUv2-8 155 | 156 | C-1-1: 157 | ``` 158 | python3 run.py --model_select bert-base-uncased --data_path data/glue/sst2 --task single-label-cls --per_device_train_batch_size 8 --num_epochs_train 6 --max_seq_length 128 --lr 5e-5 --schedule warmuplinear --do_train --do_eval --do_test --use_tpu --tpu_address x.x.x.x 159 | ``` 160 | 161 | C-1-2: 162 | 163 | ``` 164 | python3 run.py --model_select bert-large-uncased --data_path data/glue/sst2 --task single-label-cls --per_device_train_batch_size 8 --num_epochs_train 6 --max_seq_length 128 --lr 5e-5 --schedule warmuplinear --do_train --do_eval --do_test --use_tpu --tpu_address x.x.x.x 165 | ``` 166 | 167 | ** In addition, experiments with larger batch sizes were also conducted on TPUv2-8. For example, when `per_device_train_batch_size` is 128 (batch size=8*128=1024), the first epoch takes around ~1 minute and each of the rest takes just ~15 seconds! That is fast, but the sst2 accuracy goes down significantly. 
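As a side note on the data consumed by the commands above: each dataset is plain jsonl, one example per line. The sketch below shows the two layouts described earlier (a `{"text", "label"}` object for `single-label-cls` and a `{"source", "target"}` object for t2t tasks); the field values here are hypothetical, not taken from the actual sst2 files:

```python3
import json

# single-label-cls datasets (e.g., data/glue/sst2): one {"text", "label"} object per line
cls_example = {"text": "a gorgeous film", "label": "positive"}  # hypothetical values
# t2t datasets: one {"source", "target"} object per line
t2t_example = {"source": "sst2 sentence: a gorgeous film", "target": "positive"}  # hypothetical values

# train.json / val.json / test.json each hold one such object per line (jsonl)
print(json.dumps(cls_example))
print(json.dumps(t2t_example))
```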
168 | 169 | #### Results 170 | 171 | | | bert-base-uncased (110M) | | | | bert-large-uncased (340M) | | | | 172 | |-------------|:--------------------------:|:----------------------------------------------:|:---------------------------:|---------------------------------|:---------------------------:|:----------------------------------------------:|:---------------------------:|---------------------------------| 173 | | | here | [BERT paper](https://arxiv.org/abs/1810.04805) | reproduction (here) command | time spent on a [n1-standard-8](https://cloud.google.com/compute/docs/machine-types) * | here | [BERT paper](https://arxiv.org/abs/1810.04805) | reproduction (here) command | time spent on a [n1-standard-8](https://cloud.google.com/compute/docs/machine-types) * | 174 | | sst2 (test set, acc.) | 93.36 | 93.5 | C-1-1 | 16 minutes | 94.45 | 94.9 | C-1-2 | 37 minutes | 175 | * * refers to the estimated time including training, evaluation every 400 steps, and evaluation on the test set. 176 | * Looks good; the results are close to the originally reported results. 177 | 178 | ### Experiment T5 on sst2 using TPUv2-8 179 | 180 | C-2-1: 181 | ``` 182 | python3 run.py --model_select t5-small --data_path data/glue/sst2 --task t2t --per_device_train_batch_size 8 --num_epochs_train 6 --max_seq_length 128 --lr 5e-5 --schedule warmuplinear --do_train --do_eval --do_test --use_tpu --tpu_address x.x.x.x 183 | ``` 184 | C-2-2: 185 | ``` 186 | python3 run.py --model_select t5-base --data_path data/glue/sst2 --task t2t --per_device_train_batch_size 8 --num_epochs_train 6 --max_seq_length 128 --lr 5e-5 --schedule warmuplinear --do_train --do_eval --do_test --use_tpu --tpu_address x.x.x.x 187 | ``` 188 | 189 | C-2-3: 190 | ``` 191 | python3 run.py --model_select t5-large --data_path data/glue/sst2 --task t2t --per_device_train_batch_size 2 --eval_batch_size 8 --num_epochs_train 6 --max_seq_length 128 --lr 5e-5 --schedule warmuplinear --do_train --do_eval --do_test --use_tpu --tpu_address x.x.x.x 192 | ``` 193 | ** This run failed (out-of-memory) although `per_device_train_batch_size`=2. Does a TPUv2-8 not have enough memory to fine-tune a `t5-large` model? I was looking for solutions to fine-tune `t5-large`. **Update:** Later on, I was lucky to get a TPUv3-8 (128GB), so it ran successfully. 194 | 195 | #### Results 196 | 197 | | | t5-small (60M) | | | | t5-base (220M) | | | | t5-large (770M) | | | | 198 | |-----------------------|:--------------:|:--------------------------------------------:|:---------------------------:|:-------------------------------:|:--------------:|:--------------------------------------------:|:---------------------------:|:-------------------------------:|:----------------:|:--------------------------------------------:|:---------------------------:|:-------------------------------:| 199 | | | here | [T5 paper](https://arxiv.org/abs/1910.10683) | reproduction (here) command | time spent on a n1-standard-8 * | here | [T5 paper](https://arxiv.org/abs/1910.10683) | reproduction (here) command | time spent on a n1-standard-8 * | here | [T5 paper](https://arxiv.org/abs/1910.10683) | reproduction (here) command | time spent on a n1-standard-8 ** | 200 | | sst2 (test set, acc.) | 90.12 | 91.8 | C-2-1 | 20 minutes | 94.18 | 95.2 | C-2-2 | 36 minutes | 95.77 | 96.3 | C-2-3 | 4.5 hours | 201 | 202 | * * refers to the estimated time including training, evaluation every 400 steps, and evaluation on the test set. 203 | * ** the same but with a TPUv3-8 and a smaller batch size (see command C-2-3). 
204 | * Not bad; the results are fairly close to the originally reported results. 205 | 206 | 207 | ## Contributions 208 | - Contributions are welcome. 209 | 210 | ## Todo ideas 211 | - To include more language tasks, such as sequence-pair based classification, T5 toy pretraining, etc. 212 | - LR schedulers so far include "warmuplinear", "warmupconstant", "constant", "constantlinear". The plan is to implement all of those available in [optimizer_schedules](https://huggingface.co/transformers/main_classes/optimizer_schedules.html#schedules). 213 | - Now all fine-tuning uses Adam as the default optimizer. The plan is to implement others such as AdaFactor, etc. 214 | - Optimizations include: TF clip_grad_norm as used in PyTorch fine-tuning, AMP training, etc. 215 | 216 | 217 | ## Last 218 | 219 | I had been looking for PyTorch alternatives that can help train large models with Google's TPUs in a Google GCP VM instance env. Although the [xla](https://github.com/pytorch/xla) lib seems good, I gave it up due to some bugs I found hard to fix. Something like "process terminated with SIGKILL" confused me a lot and took me loads of time, and I eventually failed to solve it after searching all kinds of answers online ([ref1](https://github.com/PyTorchLightning/pytorch-lightning/issues/1590), [ref2](https://github.com/huggingface/transformers/issues/3660); the community looks not that active in this field). Later on, some clues online told me this problem is related to memory overloading, and I expect the xla lib will have a more stable release in the future. It works well when experimented with [the MNIST example](https://cloud.google.com/tpu/docs/tutorials/mnist) provided on Google's official website, but runs into the "memory" problem when tested on big models like transformers (I did not make 🤗 transformers' [xla_spawn.py](https://github.com/huggingface/transformers/blob/master/examples/xla_spawn.py) run successfully either). 220 | 221 | Hence, I shifted to learning Tensorflow as a newcomer from PyTorch, to make my life easier whenever I need to train a model on TPUs. Thankfully, Tensorflow-2.0 makes this shift not that difficult, although some [complaints](https://twitter.com/snrrrub/status/1301228252325797888) about it go on. After around three days of researching and coding, I ended up with this simple package. This package is made publicly available in the hope of helping whoever has the same experience as me. Most of the training code flow (the so-called boilerplate code) in this package follows a PyTorch style due to my old habits. Hopefully, this makes it easy to get to know Tensorflow-2.0 when you come from PyTorch and you need TPUs. 222 | 223 | ### Ack. 224 | Thanks to [Google's TFRC Program](https://www.tensorflow.org/tfrc) for providing TPU credits to make this possible. 225 | -------------------------------------------------------------------------------- /covid_event/README.md: -------------------------------------------------------------------------------- 1 | ### This unit guides you through reproducing the results in our paper titled "UCD-CS at W-NUT 2020 Shared Task-3: A Text to Text Approach for COVID-19 Event Extraction on Social Media" (https://www.aclweb.org/anthology/2020.wnut-1.78/, [Cite](#cite)), accepted to W-NUT at EMNLP 2020. 2 |
11 | 12 | #### Citation 13 | ``` 14 | @inproceedings{wang-lillis-2020-ucd, 15 | title = "{UCD}-{CS} at {W}-{NUT} 2020 Shared Task-3: A Text to Text Approach for {COVID}-19 Event Extraction on Social Media", 16 | author = "Wang, Congcong and 17 | Lillis, David", 18 | booktitle = "Proceedings of the Sixth Workshop on Noisy User-generated Text (W-NUT 2020)", 19 | month = nov, 20 | year = "2020", 21 | address = "Online", 22 | publisher = "Association for Computational Linguistics", 23 | url = "https://www.aclweb.org/anthology/2020.wnut-1.78", 24 | doi = "10.18653/v1/2020.wnut-1.78", 25 | pages = "514--521", 26 | abstract = "In this paper, we describe our approach in the shared task: COVID-19 event extraction from Twitter. The objective of this task is to extract answers from COVID-related tweets to a set of predefined slot-filling questions. Our approach treats the event extraction task as a question answering task by leveraging the transformer-based T5 text-to-text model. According to the official evaluation scores returned, namely F1, our submitted run achieves competitive performance compared to other participating runs (Top 3). However, we argue that this evaluation may underestimate the actual performance of runs based on text-generation. Although some such runs may answer the slot questions well, they may not be an exact string match for the gold standard answers. To measure the extent of this underestimation, we adopt a simple exact-answer transformation method aiming at converting the well-answered predictions to exactly-matched predictions. The results show that after this transformation our run overall reaches the same level of performance as the best participating run and state-of-the-art F1 scores in three of five COVID-related events. Our code is publicly available to aid reproducibility", 27 | } 28 | ``` 29 | 30 | 31 | ### Changelog 32 | - 2021-03-10, updated the paper from the ACL Anthology. 33 | - 2020-10-15, added the fine-tuned t5-base model. 34 | 35 | ### Demo ([inference.py](inference.py)) 36 | 37 | ```python3 38 | from transformers import T5Tokenizer, TFT5ForConditionalGeneration, T5ForConditionalGeneration 39 | # the model will be downloaded automatically from Huggingface's model hub, corresponding to run-1 in the paper. 40 | model_name_or_path = "congcongwang/t5-large-fine-tuned-wnut-2020-task3" 41 | # Or try replacing "congcongwang/t5-large-fine-tuned-wnut-2020-task3" with "congcongwang/t5-base-fine-tuned-wnut-2020-task3", which is much lighter but still achieves decent performance (see table 2a) 42 | 43 | # PyTorch 44 | model = T5ForConditionalGeneration.from_pretrained(model_name_or_path) 45 | 46 | # Or Tensorflow2.0 47 | # model = TFT5ForConditionalGeneration.from_pretrained(model_name_or_path,from_pt=True) 48 | 49 | tokenizer = T5Tokenizer.from_pretrained(model_name_or_path) 50 | 51 | source = "context: *Prince Charles tests positive for Corona* Prince William knowing he's " \ 52 | "the next in line to the throne: https://t.co/B1nmIpLj69. question: Who is tested positive?" \ 53 | "choices: author of the tweet, not specified, the next in line, line to the throne, *Prince Charles," \ 54 | " Corona* Prince William, he, the next, line, the throne." 55 | 56 | inputs = tokenizer.encode(source, return_tensors="pt") # Batch size 1. 
change "pt" to "tf" if using Tensorflow2.0 model 57 | result = model.generate(inputs) 58 | # output: Prince Charles 59 | ``` 60 | 61 | Quick links: 62 | - [the dataset release and task proposal page](https://github.com/viczong/extract_COVID19_events_from_Twitter/tree/master/shared_task) 63 | - [details of slot questions and candidate answers](https://docs.google.com/document/d/1OWFTXOZpoXNrDULq6PFXvIGarSZwpU-uLQRuV4wrJwI/edit) 64 | - the hyper-parameters and training process of the above demonstrated model: [args.json](extra/args.json) and [train.log](extra/train.log). 65 | - [the complete list](extra/unmatched.txt) (run-2/post-run-2) of unmatched generated predictions of the above demonstrated model based on test set annotations that can be found [here](preds/golden). 66 | - fine-tuned model weights (both TF2.0 and PyTorch) [downloading link](https://drive.google.com/file/d/1tuI54jDK7OfiVemninyZbUo3sybYHosg/view?usp=sharing) 67 | 68 | 69 | ### Quick reproduction 70 | 71 | #### reproduction of table 2b 72 | ```bash 73 | python eval.py --run_name run-3 74 | python eval.py --run_name run-2 75 | python eval.py --run_name run-1 76 | ``` 77 | 78 | Now `python subs/post_process.py` to post processe these runs to `post-run-3`, `post-run-2`, and `post-run-1` respectively, which corresponds to the runs in table 2c. To reproduce the results in table 2c, then run the following commands 79 | 80 | ```bash 81 | python eval.py --run_name post-run-3 82 | python eval.py --run_name post-run-2 83 | python eval.py --run_name post-run-1 84 | ``` 85 | 86 | ### Reproduction from scratch 87 | 88 | #### Data preparation 89 | 90 | Download the corpus from [this repository](https://github.com/viczong/extract_COVID19_events_from_Twitter/tree/master/shared_task) or sent an email at [wangcongcongcc@gmail.com](wangcongcongcc@gmail.com) to request the 7149 tweets used in the paper. Then name the obtained corpus as `corpus.json` and put it under `./data` (create it first). In `corpus.json`, each line represents an example with jsonl format as follows: 91 | 92 | ```bash 93 | {"id_str": "...", "full_text": "...", "candidate_chunks_offsets": [[34, 40], [51, 56], [58, 62], [67, 70], [127, 137], [12, 28]], "annotation": {"part1.Response": ["..."], "part2-relation.Response": ["..."], "part2-symptoms.Response": ["..."], "part2-name.Response": ["..."], "part2-when.Response": ["..."], "part2-where.Response": ["..."]}, "event_type": "can_not_test"} 94 | ``` 95 | Then run the command to prepare the training and validation set (splitting and constructing the source and target sequences) ready for subsequent model training. 96 | 97 | ```bash 98 | python prepare.py 99 | ``` 100 | 101 | This will generate two folders with `train.json` and `val.json` inside`./data/middle` and `./data/final`. The `./data/final` is what we need to for model training. To do the same process for test set, request `test.json` at [wangcongcongcc@gmail.com](wangcongcongcc@gmail.com) first and make some changes in `prepare.py` accordingly to prepare the test set. 102 | 103 | #### Fine-tune T5 104 | To skip the following time-consuming fine-tuning, it is recommended to download the already fine-tuned model from [here](https://drive.google.com/file/d/1tuI54jDK7OfiVemninyZbUo3sybYHosg/view?usp=sharing), which is fine-tuned on T5-large with 12 epochs (corresponding to run-1 in the paper) and ready for predictions. After downloading, unzip and put it `./tmp/` (create it first). 
105 | 106 | Before fine-tuning, make sure to install the dependencies first: 107 | 108 | ``` 109 | git clone https://github.com/wangcongcong123/ttt.git 110 | cd ttt 111 | pip install -e . 112 | ``` 113 | 114 | Start fine-tuning with Tensorflow2.0: 115 | 116 | ``` 117 | python finetune.py --model_select t5-small --data_path data/final --task t2t --per_device_train_batch_size 8 --num_epochs_train 12 --max_src_length 512 --max_tgt_length 78 --lr 5e-5 --schedule warmuplinear --warmup_ratio 0.1 --log_steps -1 --keep_ck_num 3 --use_gpu --source_field_name source --target_field_name target 118 | ``` 119 | 120 | Or, if you prefer fine-tuning with PyTorch: 121 | 122 | ``` 123 | python finetune_pt.py 124 | ``` 125 | 126 | Reminders: 127 | * Try `python finetune.py --help` or `python finetune_pt.py --help` to see all the flags. 128 | * `finetune_pt.py` has the same set of flags as the TF one, using one GPU by default. To manipulate the flags, have a look at the script. 129 | * As an example, `finetune.py` fine-tunes a T5-small with a GPU. If you have more resources like TPUs to train larger models, just change the flag `--use_gpu` to `--use_tpu` and add an extra flag `--tpu_address x.x.x.x`. 130 | 131 | After the fine-tuning is done, the training details and model weights (checkpoints of the last three epochs) can be found at `./tmp/{model_save_path}`. 132 | 133 | #### Predictions 134 | 135 | Assume the trained model is saved to `./tmp/t5-large-fine-tuned-wnut-2020-task3`: 136 | 137 | ``` 138 | python predict.py 139 | ``` 140 | 141 | ** This will make predictions for `data/final/val.json` and output the predictions to `./preds/val_preds.json`. 142 | 143 | #### Submission 144 | 145 | Since the predictions are made flat over slots, we need to merge them by tweet ids to meet [the required format](https://github.com/viczong/extract_COVID19_events_from_Twitter/tree/master/shared_task) before evaluation. `subs/` includes a fixture of `val_preds.json` as an example. To convert it: 146 | 147 | ``` 148 | python submit.py 149 | ``` 150 | 151 | After this, the converted predictions are saved to `subs/val-run-1/`. We now have the runs ready for evaluation (the same as for the test runs mentioned in the [quick reproduction](#quick-reproduction) section). 
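For reference, here is a minimal sketch of what one merged line in a run file (e.g. `subs/val-run-1/positive.jsonl`) is expected to look like, inferred from how `eval.py` reads the run files: each line carries an `id` plus a `predicted_annotation` dict of `part2-*.Response` answer lists. The id and answers below are hypothetical:

```python3
import json

# hypothetical merged prediction line; eval.py looks up each_line["id"] and
# each_line["predicted_annotation"][question_tag] for every "part2*" question tag
pred_line = {
    "id": "1244567890123456789",  # hypothetical tweet id
    "predicted_annotation": {
        "part2-name.Response": ["prince charles"],
        "part2-when.Response": ["Not Specified"],
        "part2-where.Response": ["Not Specified"],
    },
}
print(json.dumps(pred_line))  # one such object per jsonl line
```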
152 | 153 | 154 | -------------------------------------------------------------------------------- /covid_event/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/covid_event/arch.png -------------------------------------------------------------------------------- /covid_event/c_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This script is finished following HF's datasets' template: 3 | # https://github.com/huggingface/datasets/blob/master/templates/new_dataset_script.py 4 | # More examples as references to write a customized dataset can be found here: 5 | # https://github.com/huggingface/datasets/tree/master/datasets 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import json 10 | 11 | import datasets 12 | 13 | _CITATION = """\ 14 | 15 | """ 16 | _DESCRIPTION = """\ 17 | """ 18 | 19 | _TRAIN_DOWNLOAD_URL = "data/final/train.json" 20 | _VAL_DOWNLOAD_URL = "data/final/val.json" 21 | 22 | class CData(datasets.GeneratorBasedBuilder): 23 | """covid event data script.""" 24 | # VERSION = datasets.Version("1.0.0") 25 | def _info(self): 26 | return datasets.DatasetInfo( 27 | description=_DESCRIPTION, 28 | features=datasets.Features( 29 | { 30 | "source": datasets.Value("string"), 31 | "target": datasets.Value("string"), 32 | } 33 | ), 34 | supervised_keys=None, 35 | homepage="#", 36 | citation=_CITATION, 37 | ) 38 | 39 | def _split_generators(self, dl_manager): 40 | train_path = dl_manager.download_and_extract(_TRAIN_DOWNLOAD_URL) 41 | val_path = dl_manager.download_and_extract(_VAL_DOWNLOAD_URL) 42 | return [ 43 | datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}), 44 | datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": val_path}), 45 | ] 46 | 47 | def _generate_examples(self, filepath): 48 | with open(filepath, encoding='utf-8') as f: 49 | for id_, row in enumerate(f): 50 | data = json.loads(row) 51 | yield id_, { 52 | "source": data["source"], 53 | "target": data["target"], 54 | } 55 | -------------------------------------------------------------------------------- /covid_event/eval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | copyright declaration: 3 | this eval script is adapted from: https://github.com/viczong/extract_COVID19_events_from_Twitter/tree/master/shared_task 4 | ''' 5 | import argparse 6 | import json 7 | import numpy as np 8 | 9 | ### read file 10 | def readJSONLine(path): 11 | output = [] 12 | with open(path, 'r') as f: 13 | for line in f: 14 | output.append(json.loads(line)) 15 | return output 16 | 17 | ### evaluation script 18 | def runEvaluation(system_predictions, golden_predictions): 19 | ## read in files 20 | golden_predictions_dict = {} 21 | for each_line in golden_predictions: 22 | golden_predictions_dict[each_line['id']] = each_line 23 | 24 | ## question tags 25 | question_tag = [i for i in golden_predictions[0]['golden_annotation'] if 'part2' in i] 26 | ## evaluation 27 | result = {} 28 | for each_task in question_tag: 29 | # evaluate curr task 30 | curr_task = {} 31 | TP, FP, FN = 0.0, 0.0, 0.0 32 | for each_line in system_predictions: 33 | curr_sys_pred = [i.lower() for i in each_line['predicted_annotation'][each_task] if \ 34 | i != 'Not Specified' and i != 'not specified' and i != 'not_effective'] 35 | # 
print(golden_predictions_dict[each_line['id']]['golden_annotation'][each_task]) 36 | curr_golden_ann = [i.lower() for i in 37 | golden_predictions_dict[each_line['id']]['golden_annotation'][each_task] \ 38 | if i != 'Not Specified' and i != 'not specified' and i != 'not_effective'] 39 | # print(curr_sys_pred, curr_golden_ann) 40 | if len(curr_golden_ann) > 0: 41 | for predicted_chunk in curr_sys_pred: 42 | if predicted_chunk in curr_golden_ann: 43 | TP += 1 # True positives are predicted spans that appear in the gold labels. 44 | else: 45 | FP += 1 # False positives are predicted spans that don't appear in the gold labels. 46 | for gold_chunk in curr_golden_ann: 47 | if gold_chunk not in curr_sys_pred: 48 | FN += 1 # False negatives are gold spans that weren't in the set of spans predicted by the model. 49 | else: 50 | if len(curr_sys_pred) > 0: 51 | for predicted_chunk in curr_sys_pred: 52 | FP += 1 # False positives are predicted spans that don't appear in the gold labels. 53 | # print 54 | if TP + FP == 0: 55 | P = 0.0 56 | else: 57 | P = TP / (TP + FP) 58 | 59 | if TP + FN == 0: 60 | R = 0.0 61 | else: 62 | R = TP / (TP + FN) 63 | 64 | if P + R == 0: 65 | F1 = 0.0 66 | else: 67 | F1 = 2.0 * P * R / (P + R) 68 | 69 | curr_task["F1"] = F1 70 | curr_task["P"] = P 71 | curr_task["R"] = R 72 | curr_task["TP"] = TP 73 | curr_task["FP"] = FP 74 | curr_task["FN"] = FN 75 | N = TP + FN 76 | curr_task["N"] = N 77 | # print(curr_task) 78 | result[each_task.replace('.Response', '')] = curr_task 79 | # print 80 | # print(each_task.replace('.Response', '')) 81 | # print('P:', curr_task['P'], 'R:', curr_task['R'], 'F1:', curr_task['F1']) 82 | # print('=======') 83 | ### calculate micro-F1 84 | all_TP = np.sum([i[1]['TP'] for i in result.items()]) 85 | all_FP = np.sum([i[1]['FP'] for i in result.items()]) 86 | all_FN = np.sum([i[1]['FN'] for i in result.items()]) 87 | 88 | all_P = all_TP / (all_TP + all_FP) 89 | all_R = all_TP / (all_TP + all_FN) 90 | all_F1 = 2.0 * all_P * all_R / (all_P + all_R) 91 | 92 | ## append 93 | result['micro'] = {} 94 | result['micro']['TP'] = all_TP 95 | result['micro']['FP'] = all_FP 96 | result['micro']['FN'] = all_FN 97 | result['micro']['P'] = all_P 98 | result['micro']['R'] = all_R 99 | result['micro']['F1'] = all_F1 100 | result['micro']['N'] = all_TP + all_FN 101 | # print('micro F1', all_F1) 102 | return result 103 | 104 | if __name__ == '__main__': 105 | ##### Attention: replace YOUR_TEAM_NAME with your actual team name 106 | ## YOUR_TEAM_NAME = 'OSU_NLP' 107 | parser = argparse.ArgumentParser(description='Hyper params') 108 | parser.add_argument('--run_name', type=str, default="run-2", 109 | help='run name for evaluation on test set (2500 tweets, 500 per event)') 110 | args = parser.parse_args() 111 | input_path = './subs/' + args.run_name +'/' 112 | golden_path = './subs/golden/' 113 | team_name = input_path.split('/')[-2] 114 | print('team name:', team_name) 115 | ### score each category 116 | category_flag = ['positive', 'negative', 'can_not_test', 'death', 'cure'] 117 | curr_team = {} 118 | curr_team['team_name'] = team_name 119 | ## loop each category 120 | all_category_results = {} 121 | for each_category in category_flag: 122 | ## read in data 123 | curr_pred = readJSONLine(input_path + each_category + '.jsonl') 124 | curr_sol = readJSONLine(golden_path + each_category + '_sol.jsonl') 125 | ## generate result 126 | curr_result = runEvaluation(curr_pred, curr_sol) 127 | ## print 128 | print(team_name, each_category, 'F1:', curr_result['micro']['F1']) 129 | 
## append result 130 | all_category_results[each_category] = curr_result 131 | ### overall 132 | all_cate_TP = np.sum([i[1]['micro']['TP'] for i in all_category_results.items()]) 133 | all_cate_FP = np.sum([i[1]['micro']['FP'] for i in all_category_results.items()]) 134 | all_cate_FN = np.sum([i[1]['micro']['FN'] for i in all_category_results.items()]) 135 | # print(all_cate_TP + all_cate_FN) 136 | ### micro-F1 137 | all_cate_P = all_cate_TP / (all_cate_TP + all_cate_FP) 138 | all_cate_R = all_cate_TP / (all_cate_TP + all_cate_FN) 139 | all_cate_F1 = 2.0 * all_cate_P * all_cate_R / (all_cate_P + all_cate_R) 140 | curr_team['category_perf'] = all_category_results 141 | merged_performance = {} 142 | merged_performance['TP'] = all_cate_TP 143 | merged_performance['FP'] = all_cate_FP 144 | merged_performance['FN'] = all_cate_FN 145 | merged_performance['P'] = all_cate_P 146 | merged_performance['R'] = all_cate_R 147 | merged_performance['F1'] = all_cate_F1 148 | curr_team['overall_perf'] = merged_performance 149 | print('-----') 150 | print(merged_performance) 151 | print(team_name, 'overall', 'F1:', all_cate_F1) 152 | print('======') 153 | -------------------------------------------------------------------------------- /covid_event/event_level_perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/covid_event/event_level_perf.png -------------------------------------------------------------------------------- /covid_event/extra/args.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_select": "t5-large", 3 | "data_path": "data/covid_event", 4 | "task": "translation", 5 | "per_device_train_batch_size": 2, 6 | "eval_batch_size": 8, 7 | "num_epochs_train": 12, 8 | "log_steps": -1, 9 | "max_seq_length": 128, 10 | "max_src_length": 512, 11 | "max_tgt_length": 512, 12 | "source_field_name": "source", 13 | "target_field_name": "target", 14 | "lr": 5e-05, 15 | "warmup_ratio": 0.1, 16 | "patience": 20, 17 | "scheduler": "warmuplinear", 18 | "seed": 122, 19 | "eval_on": "acc", 20 | "keep_ck_num": 3, 21 | "ck_index_select": 0, 22 | "do_train": true, 23 | "do_eval": false, 24 | "do_test": false, 25 | "use_gpu": false, 26 | "use_tpu": true, 27 | "use_tb": true, 28 | "tpu_address": "x.x.x.x", 29 | "default_store": false, 30 | "output_folder": "t5-large_translation_covid_event", 31 | "output_path": "tmp/t5-large_translation_covid_event", 32 | "target_special_append_when_reading": false, 33 | "is_data_cache": true, 34 | "data_cache_path": "data/covid_event/t5-large-data.pkl", 35 | "source_sequence_length": 473, 36 | "target_sequence_length": 78, 37 | "num_replicas_in_sync": 8 38 | } -------------------------------------------------------------------------------- /covid_event/extra/train.log: -------------------------------------------------------------------------------- 1 | 2020-09-07 21:59:51,553.553 INFO inputs - get_with_prepare_func: reading cached data from data/covid_event/t5-large-data.pkl 2 | 2020-09-07 21:59:51,553.553 WARNING inputs - get_with_prepare_func: if you changed the max_seq_length/max_src_length/max_tgt_length, this may not correctly loaded, since the data/covid_event/t5-large-data.pkl is pickled based on first time loading 3 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: All TPU devices: 4 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: 
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU') 5 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU') 6 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU') 7 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU') 8 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU') 9 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU') 10 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU') 11 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU') 12 | 2020-09-07 22:01:40,501.501 INFO utils - create_model: None 13 | 2020-09-07 22:01:40,938.938 INFO t2t_trainer - train: start training at epoch = 0 14 | 2020-09-07 22:01:40,938.938 INFO t2t_trainer - train: global train batch size = 16 15 | 2020-09-07 22:01:40,938.938 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 16 | 2020-09-07 22:01:40,938.938 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 17 | 2020-09-07 22:01:40,938.938 INFO t2t_trainer - train: warmup_steps:2170 18 | 2020-09-07 22:30:06,563.563 INFO t2t_trainer - train: train loss at end of epoch 0: 144.05117502817632 19 | 2020-09-07 22:30:06,565.565 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_0.h5 20 | 2020-09-07 22:30:19,389.389 INFO t2t_trainer - train: start training at epoch = 1 21 | 2020-09-07 22:30:19,389.389 INFO t2t_trainer - train: global train batch size = 16 22 | 2020-09-07 22:30:19,389.389 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 23 | 2020-09-07 22:30:19,389.389 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 24 | 2020-09-07 22:30:19,389.389 INFO t2t_trainer - train: warmup_steps:2170 25 | 2020-09-07 22:42:43,364.364 INFO t2t_trainer - train: train loss at end of epoch 1: 1.0729178690272778 26 | 2020-09-07 22:42:43,365.365 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_1.h5 27 | 2020-09-07 22:42:55,281.281 INFO t2t_trainer - train: start training at epoch = 2 28 | 2020-09-07 22:42:55,282.282 INFO t2t_trainer - train: global train batch size = 16 29 | 2020-09-07 22:42:55,282.282 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 30 | 2020-09-07 22:42:55,282.282 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 31 | 2020-09-07 22:42:55,282.282 INFO t2t_trainer - train: warmup_steps:2170 32 | 2020-09-07 22:55:21,735.735 INFO t2t_trainer - train: train loss at end of epoch 2: 0.7845621262859778 33 | 2020-09-07 22:55:21,736.736 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_2.h5 34 | 2020-09-07 22:55:35,783.783 INFO t2t_trainer - train: start training at epoch = 3 35 | 2020-09-07 22:55:35,784.784 INFO t2t_trainer - train: global train batch size = 16 36 | 2020-09-07 22:55:35,784.784 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 37 | 2020-09-07 
22:55:35,784.784 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 38 | 2020-09-07 22:55:35,784.784 INFO t2t_trainer - train: warmup_steps:2170 39 | 2020-09-07 23:08:00,730.730 INFO t2t_trainer - train: train loss at end of epoch 3: 0.6215832809108836 40 | 2020-09-07 23:08:00,731.731 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3 41 | 2020-09-07 23:08:00,731.731 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_0.h5 42 | 2020-09-07 23:08:01,097.097 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_3.h5 43 | 2020-09-07 23:08:12,035.035 INFO t2t_trainer - train: start training at epoch = 4 44 | 2020-09-07 23:08:12,035.035 INFO t2t_trainer - train: global train batch size = 16 45 | 2020-09-07 23:08:12,035.035 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 46 | 2020-09-07 23:08:12,035.035 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 47 | 2020-09-07 23:08:12,035.035 INFO t2t_trainer - train: warmup_steps:2170 48 | 2020-09-07 23:20:35,271.271 INFO t2t_trainer - train: train loss at end of epoch 4: 0.50728934064699 49 | 2020-09-07 23:20:35,272.272 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3 50 | 2020-09-07 23:20:35,272.272 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_1.h5 51 | 2020-09-07 23:20:35,691.691 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_4.h5 52 | 2020-09-07 23:20:44,568.568 INFO t2t_trainer - train: start training at epoch = 5 53 | 2020-09-07 23:20:44,568.568 INFO t2t_trainer - train: global train batch size = 16 54 | 2020-09-07 23:20:44,568.568 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 55 | 2020-09-07 23:20:44,568.568 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 56 | 2020-09-07 23:20:44,568.568 INFO t2t_trainer - train: warmup_steps:2170 57 | 2020-09-07 23:33:08,921.921 INFO t2t_trainer - train: train loss at end of epoch 5: 0.4129441963964862 58 | 2020-09-07 23:33:08,922.922 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3 59 | 2020-09-07 23:33:08,922.922 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_2.h5 60 | 2020-09-07 23:33:09,288.288 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_5.h5 61 | 2020-09-07 23:33:17,834.834 INFO t2t_trainer - train: start training at epoch = 6 62 | 2020-09-07 23:33:17,834.834 INFO t2t_trainer - train: global train batch size = 16 63 | 2020-09-07 23:33:17,834.834 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 64 | 2020-09-07 23:33:17,834.834 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 65 | 2020-09-07 23:33:17,834.834 INFO t2t_trainer - train: warmup_steps:2170 66 | 2020-09-07 23:45:39,810.810 INFO t2t_trainer - train: train loss at end of epoch 6: 0.34766357622586597 67 | 2020-09-07 23:45:39,813.813 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3 68 | 2020-09-07 23:45:39,813.813 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_3.h5 69 | 2020-09-07 23:45:40,221.221 INFO t2t_trainer - 
save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_6.h5 70 | 2020-09-07 23:45:49,928.928 INFO t2t_trainer - train: start training at epoch = 7 71 | 2020-09-07 23:45:49,928.928 INFO t2t_trainer - train: global train batch size = 16 72 | 2020-09-07 23:45:49,928.928 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 73 | 2020-09-07 23:45:49,929.929 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 74 | 2020-09-07 23:45:49,929.929 INFO t2t_trainer - train: warmup_steps:2170 75 | 2020-09-07 23:58:28,345.345 INFO t2t_trainer - train: train loss at end of epoch 7: 0.29620171093390696 76 | 2020-09-07 23:58:28,346.346 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3 77 | 2020-09-07 23:58:28,346.346 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_4.h5 78 | 2020-09-07 23:58:28,736.736 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_7.h5 79 | 2020-09-07 23:58:37,138.138 INFO t2t_trainer - train: start training at epoch = 8 80 | 2020-09-07 23:58:37,138.138 INFO t2t_trainer - train: global train batch size = 16 81 | 2020-09-07 23:58:37,138.138 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 82 | 2020-09-07 23:58:37,138.138 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 83 | 2020-09-07 23:58:37,138.138 INFO t2t_trainer - train: warmup_steps:2170 84 | 2020-09-08 00:10:59,173.173 INFO t2t_trainer - train: train loss at end of epoch 8: 0.24166541466643368 85 | 2020-09-08 00:10:59,174.174 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3 86 | 2020-09-08 00:10:59,175.175 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_5.h5 87 | 2020-09-08 00:10:59,680.680 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_8.h5 88 | 2020-09-08 00:11:09,666.666 INFO t2t_trainer - train: start training at epoch = 9 89 | 2020-09-08 00:11:09,666.666 INFO t2t_trainer - train: global train batch size = 16 90 | 2020-09-08 00:11:09,666.666 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 91 | 2020-09-08 00:11:09,666.666 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 92 | 2020-09-08 00:11:09,666.666 INFO t2t_trainer - train: warmup_steps:2170 93 | 2020-09-08 00:23:31,439.439 INFO t2t_trainer - train: train loss at end of epoch 9: 0.1981308291555007 94 | 2020-09-08 00:23:31,440.440 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3 95 | 2020-09-08 00:23:31,440.440 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_6.h5 96 | 2020-09-08 00:23:31,846.846 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_9.h5 97 | 2020-09-08 00:23:41,433.433 INFO t2t_trainer - train: start training at epoch = 10 98 | 2020-09-08 00:23:41,433.433 INFO t2t_trainer - train: global train batch size = 16 99 | 2020-09-08 00:23:41,433.433 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 100 | 2020-09-08 00:23:41,433.433 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 101 | 2020-09-08 00:23:41,433.433 INFO t2t_trainer - train: warmup_steps:2170 102 | 2020-09-08 00:36:03,731.731 INFO t2t_trainer - train: 
train loss at end of epoch 10: 0.1577673581983842 103 | 2020-09-08 00:36:03,732.732 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3 104 | 2020-09-08 00:36:03,732.732 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_7.h5 105 | 2020-09-08 00:36:04,471.471 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_10.h5 106 | 2020-09-08 00:36:17,836.836 INFO t2t_trainer - train: start training at epoch = 11 107 | 2020-09-08 00:36:17,836.836 INFO t2t_trainer - train: global train batch size = 16 108 | 2020-09-08 00:36:17,836.836 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear 109 | 2020-09-08 00:36:17,836.836 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809 110 | 2020-09-08 00:36:17,836.836 INFO t2t_trainer - train: warmup_steps:2170 111 | 2020-09-08 00:48:38,736.736 INFO t2t_trainer - train: train loss at end of epoch 11: 0.13311179534950685 112 | 2020-09-08 00:48:38,737.737 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3 113 | 2020-09-08 00:48:38,737.737 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_8.h5 114 | 2020-09-08 00:48:39,441.441 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_11.h5 115 | -------------------------------------------------------------------------------- /covid_event/finetune.py: -------------------------------------------------------------------------------- 1 | from ttt import * 2 | 3 | if __name__ == '__main__': 4 | args = get_args() 5 | ## uncomment if debugging 6 | # logger.info(f"args: {json.dumps(args.__dict__, indent=2)}") 7 | # ############### customize args 8 | # args.use_gpu = True 9 | # # args.use_tpu = True 10 | # # args.tpu_address = "x.x.x.x" 11 | # args.do_train = True 12 | # args.use_tb = True 13 | # # any one from MODELS_SUPPORT (check:ttt/args.py) 14 | # args.model_select = "t5-base" 15 | # # select a dataset. First check if it is from the nlp library; if yes, load it here and save locally to the data_path 16 | # # or customize a dataset in the data_path (train.json, val.json, test.json) where examples are organised in jsonl format 17 | # # each line represents an example like this: {"text": "...", "label": "..."} 18 | # args.data_path = "data/final" 19 | # # any one from TASKS_SUPPORT (check:ttt/args.py) 20 | # args.task = "t2t" 21 | # args.log_steps = -1 22 | # # set do_eval = False if your data does not contain a validation set. 
In that case, patience and early_stop will be invalid 23 | # args.do_eval = True 24 | # args.eval_batch_size=32 25 | # args.per_device_train_batch_size=8 26 | # args.num_epochs_train=12 27 | # args.source_field_name = "source" 28 | # args.target_field_name = "target" 29 | # args.max_src_length = 512 30 | # args.max_tgt_length = 512 31 | # args.task = "translation" # translation here generalizes to all source-target like tasks 32 | # args.lr=5e-5 33 | # # any one from LR_SCHEDULER_SUPPORT (check:ttt/args.py) 34 | # args.scheduler = "warmuplinear" 35 | ############### end customize args 36 | # to have a sanity check for the args 37 | sanity_check(args) 38 | # seed everything, make deterministic 39 | set_seed(args.seed) 40 | tokenizer = get_tokenizer(args) 41 | inputs = get_inputs(tokenizer, args) 42 | model, strategy = create_model(args, logger, get_model) 43 | # start training; here we customize T2TTrainer to get more control and flexibility 44 | trainer = T2TTrainer(args) 45 | trainer.train(model, strategy, tokenizer, inputs) 46 | -------------------------------------------------------------------------------- /covid_event/finetune_pt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ttt import * 3 | from datasets import load_dataset 4 | from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config 5 | import transformers 6 | from sklearn.metrics import accuracy_score, classification_report 7 | 8 | transformers.logging.set_verbosity_info() 9 | logger = transformers.logging.get_logger() 10 | 11 | 12 | def get_scheduler(optimizer, scheduler: str, warmup_steps: int, num_total: int): 13 | assert scheduler in ["constantlr", "warmuplinear", "warmupconstant", "warmupcosine", 14 | "warmupcosinewithhardrestarts"], ( 15 | 'scheduler should be one of ["constantlr","warmuplinear","warmupconstant","warmupcosine","warmupcosinewithhardrestarts"]') 16 | if scheduler == 'constantlr': 17 | return transformers.get_constant_schedule(optimizer) 18 | elif scheduler == 'warmupconstant': 19 | return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps) 20 | elif scheduler == 'warmuplinear': 21 | return transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, 22 | num_training_steps=num_total) 23 | elif scheduler == 'warmupcosine': 24 | return transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, 25 | num_training_steps=num_total) 26 | elif scheduler == 'warmupcosinewithhardrestarts': 27 | return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, 28 | num_warmup_steps=warmup_steps, 29 | num_training_steps=num_total) 30 | 31 | 32 | # we use the transformers logger here so we can log messages to a file locally 33 | if __name__ == '__main__': 34 | args = get_args() 35 | ############### customize args 36 | args.dataset_name = "c_data" 37 | args.model_select = "t5-small-ex-pretrain" 38 | args.from_pretrained = False 39 | 40 | args.load_train_num = -1 41 | args.load_val_num = -1 42 | args.batch_size = 8 43 | args.epochs = 12 44 | args.do_eval = True # eval at end of epoch 45 | args.log_steps = -1 # eval at end of epoch 46 | args.lr = 5e-5 47 | args.scheduler = "warmuplinear" 48 | args.warmup_steps = 0.1 49 | args.grad_norm_clip = 1.0 50 | 51 | ################## 52 | add_filehandler_for_logger(".", logger) 53 | 54 | # The following may be necessary 55 | # pyarrow_path = 
r"C:\Users\USERNAME\.cache\huggingface\datasets\{}\default\0.0.0".format(args.dataset_name) 56 | # if not sys.platform.startswith("win"): 57 | # pyarrow_path = f"/home/USERNAME/.cache/huggingface/datasets/{args.dataset_name}/default/0.0.0" 58 | # if not os.path.isdir(pyarrow_path): 59 | # os.makedirs(pyarrow_path, exist_ok=True) 60 | 61 | dataset = load_dataset(f"{args.dataset_name}.py") 62 | 63 | if args.load_train_num > 0: 64 | train = load_dataset(f"{args.dataset_name}.py", split=f"train[:{args.load_train_num}]") 65 | dataset["train"] = train 66 | 67 | if args.load_val_num > 0: 68 | val = load_dataset(f"{args.dataset_name}.py", split=f"validation[:{args.load_val_num}]") 69 | dataset["validation"] = val 70 | 71 | tokenizer = T5Tokenizer.from_pretrained(args.model_select) 72 | 73 | 74 | def convert_to_features(example_batch): 75 | encoded_source = tokenizer(example_batch["source"]) 76 | encoded_target = tokenizer(example_batch["target"]) 77 | encoded_source.update( 78 | {"labels": encoded_target["input_ids"], "decoder_attention_mask": encoded_target["attention_mask"]}) 79 | return encoded_source 80 | 81 | 82 | def collate_fn(examples): 83 | # dynamically padding 84 | source_inputs = [{"input_ids": each["input_ids"], "attention_mask": each["attention_mask"]} for each in 85 | examples] 86 | target_inputs = [{"input_ids": each["labels"], "attention_mask": each["decoder_attention_mask"]} for each in 87 | examples] 88 | source_inputs_padded = tokenizer.pad(source_inputs, return_tensors='pt') 89 | target_inputs_padded = tokenizer.pad(target_inputs, return_tensors='pt') 90 | source_inputs_padded.update({"labels": target_inputs_padded["input_ids"], 91 | "decoder_attention_mask": target_inputs_padded["attention_mask"]}) 92 | return source_inputs_padded 93 | 94 | 95 | encoded = dataset.map(convert_to_features, batched=True) 96 | columns = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask'] 97 | encoded.set_format(type='torch', columns=columns) 98 | 99 | train_dataloader = torch.utils.data.DataLoader(encoded["train"], collate_fn=collate_fn, batch_size=args.batch_size) 100 | val_dataloader = torch.utils.data.DataLoader(encoded["validation"], collate_fn=collate_fn, 101 | batch_size=args.batch_size * 4) 102 | 103 | if args.from_pretrained: 104 | model = T5ForConditionalGeneration.from_pretrained(args.model_select) 105 | else: 106 | config = T5Config.from_pretrained(args.model_select) 107 | model = T5ForConditionalGeneration(config) 108 | 109 | no_decay = ["bias", "LayerNorm.weight"] 110 | params_decay = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)] 111 | params_nodecay = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)] 112 | optim_groups = [ 113 | {"params": params_decay, "weight_decay": 0.1}, 114 | {"params": params_nodecay, "weight_decay": 0.0}, 115 | ] 116 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 117 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 118 | model.train().to(device) 119 | scheduler = get_scheduler(optimizer, scheduler=args.scheduler, warmup_steps=int(0.1 * len(train_dataloader)), 120 | num_total=args.epochs * len(train_dataloader)) 121 | 122 | 123 | def evaluate(): 124 | model.eval() 125 | gts = [] 126 | preds = [] 127 | for batch in tqdm(val_dataloader, total=len(val_dataloader), desc="evaluating..."): 128 | with torch.no_grad(): 129 | batch.to(device) 130 | predictions = model.generate(input_ids=batch["input_ids"], 131 | attention_mask=batch["attention_mask"]) 132 | pred = 
132 |                 pred = [tokenizer.decode(ids, skip_special_tokens=True) for ids in predictions]  # skip special tokens so exact string matching is not thrown off by </s>/<pad>
133 |                 gt = [tokenizer.decode(ids, skip_special_tokens=True) for ids in batch["labels"]]
134 |                 preds.extend(pred)
135 |                 gts.extend(gt)
136 | 
137 |         eval_score = accuracy_score(gts, preds)
138 |         logger.info(f"val_eval_score: {eval_score}")
139 |         logger.info(f"val_cls_report: {classification_report(gts, preds, digits=4)}")
140 | 
141 | 
142 |     losses = []
143 |     for epoch in tqdm(range(args.epochs), desc='epochs'):
144 |         logger.info(f"start training epoch {epoch + 1}/{args.epochs}")
145 |         base_steps = len(train_dataloader) * epoch
146 |         pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
147 |         for it, batch in pbar:
148 |             batch.to(device)
149 |             outputs = model(**batch, return_dict=True)
150 |             loss = outputs.loss
151 |             loss = loss.mean()  # collapse all losses if they are scattered on multiple gpus
152 |             losses.append(loss.item())
153 | 
154 |             loss.backward()
155 |             torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm_clip)
156 |             optimizer.step()
157 |             scheduler.step()
158 |             model.zero_grad()
159 |             pbar.set_description(
160 |                 f"training - epoch {epoch + 1}/{args.epochs} iter {it}: train loss {loss.item():.5f}. lr {scheduler.get_last_lr()[0]:e}")
161 | 
162 |             if args.log_steps > 0 and (base_steps + it + 1) % args.log_steps == 0:
163 |                 logger.info(f'Step {base_steps + it + 1} - mean train loss: {np.mean(losses):.3}')
164 |                 logger.info(f"evaluate at global step = {base_steps + it + 1}")
165 |                 if args.do_eval:
166 |                     evaluate()
167 |                     model.train()
168 | 
169 |         if args.log_steps < 0:
170 |             logger.info(f'Epoch {epoch + 1} - mean train loss: {np.mean(losses):.3}')
171 |             logger.info(f"evaluate at end of epoch {epoch + 1} (global step = {base_steps + it + 1})")
172 |             if args.do_eval:
173 |                 evaluate()
174 |                 model.train()
175 | 
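176 |     # a minimal sketch (not part of the original script): persist the fine-tuned
177 |     # weights so they can be reloaded later with from_pretrained(); the directory
178 |     # name below is an illustrative assumption
179 |     save_dir = "tmp/finetune_pt_out"
180 |     os.makedirs(save_dir, exist_ok=True)
181 |     model.save_pretrained(save_dir)
182 |     tokenizer.save_pretrained(save_dir)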
change "tf" to "pt" if using pytorch model 18 | result = model.generate(inputs) 19 | 20 | print(tokenizer.decode(result[0])) 21 | # output: Prince Charles 22 | -------------------------------------------------------------------------------- /covid_event/mismatches.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/covid_event/mismatches.png -------------------------------------------------------------------------------- /covid_event/predict.py: -------------------------------------------------------------------------------- 1 | import math 2 | from ttt import * 3 | 4 | if __name__ == '__main__': 5 | args = get_args() 6 | model_path = "tmp/t5-large-fine-tuned-wnut-task3" 7 | args.output_path = model_path 8 | 9 | assert os.path.isdir(args.output_path) 10 | assert os.path.isfile(os.path.join(args.output_path, "args.json")) 11 | args = Args(**json.load(open(os.path.join(args.output_path, "args.json")))) 12 | set_name = "val" 13 | args.output_path = model_path 14 | out_name = set_name + "_preds" 15 | args.data_path = "data/final" 16 | args.use_gpu = True 17 | args.use_tpu = False 18 | args.eval_batch_size = 4 19 | model, strategy = create_model(args, logger, get_model, save_args=False) 20 | model.load_weights(args.output_path + "/tf_model.h5") 21 | 22 | tokenizer = AutoTokenizer.from_pretrained(args.output_path) 23 | logger.info(f"********************start predicting {set_name} set********************") 24 | source_texts, encoded_source, _, meta = convert_t2t_examples( 25 | os.path.join(args.data_path, f'{set_name}.json'), tokenizer, args, with_target=False, return_meta=True) 26 | source_input_ids = encoded_source["input_ids"] 27 | predict_dataset = tf.data.Dataset.from_tensor_slices( 28 | (source_input_ids, encoded_source["attention_mask"])) 29 | 30 | predict_dataset = predict_dataset.batch(args.eval_batch_size) 31 | iter_num = math.ceil(len(source_input_ids) / args.eval_batch_size) 32 | preds = [] 33 | # with strategy.scope(): 34 | for input_ids, attention_mask in tqdm(predict_dataset, total=iter_num, desc=f"predicting {set_name}..."): 35 | predicts = model.generate(input_ids=input_ids, 36 | attention_mask=attention_mask, 37 | max_length=args.max_tgt_length) 38 | preds.extend([tokenizer.decode(ids) for ids in predicts]) 39 | 40 | if not os.path.isdir("preds"): 41 | os.makedirs("preds") 42 | with open(f"preds/{out_name}.json", "w+") as out: 43 | for meta, source, target in zip(meta, source_texts, preds): 44 | meta["source"] = source 45 | meta["target"] = target 46 | out.write(json.dumps(meta) + "\n") 47 | 48 | print("done predicting") 49 | -------------------------------------------------------------------------------- /covid_event/prepare.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | 5 | from tqdm import tqdm 6 | 7 | eventname2question = { 8 | "positive": "Does this tweet report an individual or a small group of people who is tested postive for coronavirus?", 9 | "negative": "Does this tweet report an individual or a small group of people who is tested negative for coronavirus?", 10 | "cure_and_prevention": "Does this tweet report cure and prevention for coronavirus?", 11 | "can_not_test": "Does this tweet report an individual or a small group of people who can not be tested for coronavirus?", 12 | "death": "Does this tweet report dealth for coronavirus?", 13 | } 14 | 15 | 
short2question_positive = {"name": "Who is tested positive?", 16 | "close_contact": "Who is in close contact with the person tested positive?", 17 | "employer": "Who is the employer of the people tested positive?", 18 | "recent_travel": "Where did the people tested positive recently visit?", 19 | "relation": "Does the infected person have a relationship with the author of the tweet?", 20 | "gender": "What is the gender of the people tested positive?", 21 | "age": "What is the age of the people tested positive?", 22 | "when": "When are tested positive cases reported?", 23 | "where": "Where are tested positive cases reported?", } 24 | 25 | short2choices_positive = {"name": ["chunks", "author of the tweet", "not specified"], 26 | "close_contact": ["chunks", "author of the tweet", "not specified"], 27 | "employer": ["chunks", "not specified"], 28 | "recent_travel": ["chunks", "near author of the tweet", "not specified"], 29 | "relation": ["yes", "no", "not specified"], 30 | "gender": ["male", "female", "not specified"], 31 | "age": ["chunks", "not specified"], 32 | "when": ["chunks", "not specified"], 33 | "where": ["chunks", "near author of the tweet", "not specified"], } 34 | 35 | short2question_negative = {"name": "Who is tested negative?", 36 | "close_contact": "Who is in close contact with the person tested negative?", 37 | "relation": "Does the infected person have a relationship with the author of the tweet?", 38 | "gender": "What is the gender of the people tested negative?", 39 | "age": "What is the age of the people tested negative?", 40 | "when": "When are tested negative cases reported?", 41 | "where": "Where are tested negative cases reported?", 42 | "how_long": "How long does it take to get to know the test results?"} 43 | 44 | short2choices_negative = {"name": ["chunks", "author of the tweet", "not specified"], 45 | "close_contact": ["chunks", "author of the tweet", "not specified"], 46 | "relation": ["yes", "no", "not specified"], 47 | "gender": ["male", "female", "not specified"], 48 | "age": ["chunks", "not specified"], 49 | "when": ["chunks", "not specified"], 50 | "where": ["chunks", "near author of the tweet", "not specified"], 51 | "how_long": ["chunks", "not specified"], } 52 | 53 | short2question_can_not_test = {"name": "Who can not get a test?", 54 | "symptoms": "Is the untested person currently experiencing any COVID-19 related symptoms?", 55 | "relation": "Does the untested person have a relationship with the author of the tweet?", 56 | "when": "When is the can’t-be-tested situation reported?", 57 | "where": "Where is the can’t-be-tested situation reported?", } 58 | 59 | short2choices_can_not_test = {"name": ["chunks", "author of the tweet", "not specified"], 60 | "symptoms": ["yes", "no", "not specified"], 61 | "relation": ["yes", "no", "not specified"], 62 | "when": ["chunks", "not specified"], 63 | "where": ["chunks", "near author of the tweet", "not specified"], } 64 | 65 | short2question_dealth = {"name": "Who is dead for coronavirus?", 66 | "symptoms": "Did the person who was dead experience COVID-19 related symptoms?", 67 | "relation": "Does the deceased person have a relationship with the author of the tweet?", 68 | "when": "When is the dead case reported?", 69 | "where": "Where is the dead case reported?", 70 | "age": "What is the age of the people who is dead of COVID-19?", } 71 | 72 | short2choices_death = {"name": ["chunks", "author of the tweet", "not specified"], 73 | "symptoms": ["yes", "no", "not specified"], 74 | "relation": ["yes", "no", "not 
specified"], 75 | "when": ["chunks", "not specified"], 76 | "where": ["chunks", "near author of the tweet", "not specified"], 77 | "age": ["chunks", "not specified"], } 78 | 79 | short2question_cure_and_prevention = {"opinion": "Does the author of tweet believe the cure method is effective?", 80 | "what_cure": "What is the cure for coronavirus mentioned by the author of the tweet?", 81 | "who_cure": "Who is promoting the cure for coronavirus?", } 82 | 83 | short2choices_cure_and_prevention = {"opinion": ["not_effective", "effective"], 84 | "what_cure": ["chunks", "not specified"], 85 | "who_cure": ["chunks", "author of the tweet", "not specified"], } 86 | 87 | eventname2questionmapping = {"positive": short2question_positive, 88 | "negative": short2question_negative, 89 | "can_not_test": short2question_can_not_test, 90 | "death": short2question_dealth, 91 | "cure_and_prevention": short2question_cure_and_prevention, 92 | } 93 | 94 | eventname2choicesnmapping = {"positive": short2choices_positive, 95 | "negative": short2choices_negative, 96 | "can_not_test": short2choices_can_not_test, 97 | "death": short2choices_death, 98 | "cure_and_prevention": short2choices_cure_and_prevention, } 99 | 100 | event_type2part2annotation_keys = { 101 | "can_not_test": [ 102 | "part2-relation.Response", 103 | "part2-symptoms.Response", 104 | "part2-name.Response", 105 | "part2-when.Response", 106 | "part2-where.Response" 107 | ], 108 | "cure_and_prevention": [ 109 | "part2-opinion.Response", 110 | "part2-what_cure.Response", 111 | "part2-who_cure.Response" 112 | ], 113 | "negative": [ 114 | "part2-age.Response", 115 | "part2-close_contact.Response", 116 | "part2-gender.Response", 117 | "part2-how_long.Response", 118 | "part2-name.Response", 119 | "part2-relation.Response", 120 | "part2-when.Response", 121 | "part2-where.Response" 122 | ], 123 | "death": [ 124 | "part2-age.Response", 125 | "part2-name.Response", 126 | "part2-relation.Response", 127 | "part2-symptoms.Response", 128 | "part2-when.Response", 129 | "part2-where.Response" 130 | ], 131 | "positive": [ 132 | "part2-age.Response", 133 | "part2-close_contact.Response", 134 | "part2-employer.Response", 135 | "part2-gender.Response", 136 | "part2-name.Response", 137 | "part2-recent_travel.Response", 138 | "part2-relation.Response", 139 | "part2-when.Response", 140 | "part2-where.Response" 141 | ] 142 | } 143 | 144 | def get_part1_new_example(example): 145 | event_type = example["event_type"] 146 | new_example = {} 147 | new_example["id"] = example["id_str"] # id 148 | new_example["event_type"] = event_type # event_type 149 | new_example["slot_type"] = "part1.Response" # slot_type 150 | new_example["context"] = example["full_text"] # context 151 | new_example["question"] = eventname2question[event_type] # question 152 | new_example["choices"] = "yes, no" # candidate choices 153 | new_example["answer"] = example["annotation"]["part1.Response"][0] 154 | return new_example 155 | 156 | def get_text_chunks(example): 157 | full_text = example["full_text"] # context 158 | candidate_chunks_offsets = example["candidate_chunks_offsets"] 159 | text_chunks = [full_text[each[0]:each[1]] for each in candidate_chunks_offsets] 160 | return text_chunks 161 | 162 | 163 | def get_total(filepath): 164 | total = 0 165 | with open(filepath, "r") as f: 166 | for line in f: 167 | total += 1 168 | return total 169 | 170 | 171 | def build_test(filepath, event_type="can_not_test"): 172 | new_examples = [] 173 | with open(filepath, "r") as f: 174 | for line in tqdm(f): 175 | example 
= json.loads(line.strip()) 176 | text_chunks = get_text_chunks(example) 177 | part2annotation_keys = event_type2part2annotation_keys[event_type] 178 | for each_part2_annotation_key in part2annotation_keys: 179 | slot_key = each_part2_annotation_key.split(".")[0].split("-")[-1] 180 | slot_question = eventname2questionmapping[event_type][slot_key] 181 | slot_candidate_choices = copy.deepcopy(eventname2choicesnmapping[event_type][slot_key]) 182 | if "chunks" in slot_candidate_choices: 183 | slot_candidate_choices.remove("chunks") 184 | slot_candidate_choices.extend(text_chunks) 185 | new_example = {} 186 | full_text = example["text"] 187 | new_example["id"] = example["id"] # id 188 | new_example["event_type"] = event_type # event_type 189 | new_example["slot_type"] = each_part2_annotation_key # slot_type 190 | new_example["context"] = full_text # context 191 | new_example["question"] = slot_question # question 192 | new_example["candidates"] = slot_candidate_choices # candidate choices 193 | new_example["choices"] = ", ".join(slot_candidate_choices) # candidate choices 194 | new_examples.append(new_example) 195 | return new_examples 196 | 197 | 198 | def build_train_val(filepath, test_size=0.1, out_folder="data/middle"): 199 | total = get_total(filepath) 200 | test_no = int(total * test_size) 201 | # new_examples = [] 202 | train, val = [], [] 203 | NO_CONSENSUS_count = 0 204 | with open(filepath, "r") as f: 205 | for line in f: 206 | if total <= test_no: 207 | new_examples = val 208 | else: 209 | new_examples = train 210 | example = json.loads(line.strip()) 211 | if example["annotation"]["part1.Response"][0] == "yes": 212 | new_example = get_part1_new_example(example) 213 | new_examples.append(new_example) 214 | text_chunks = get_text_chunks(example) 215 | event_type = example["event_type"] 216 | annotation = example["annotation"] 217 | annotation.pop("part1.Response") 218 | for each_part2_annotation_key, value in annotation.items(): 219 | if value == "NO_CONSENSUS": 220 | NO_CONSENSUS_count += 1 221 | continue 222 | slot_key = each_part2_annotation_key.split(".")[0].split("-")[-1] 223 | slot_question = eventname2questionmapping[event_type][slot_key] 224 | slot_candidate_choices = copy.deepcopy(eventname2choicesnmapping[event_type][slot_key]) 225 | if "chunks" in slot_candidate_choices: 226 | slot_candidate_choices.remove("chunks") 227 | slot_candidate_choices.extend(text_chunks) 228 | new_example = {} 229 | full_text = example["full_text"] 230 | new_example["id"] = example["id_str"] # id 231 | new_example["event_type"] = event_type # event_type 232 | new_example["slot_type"] = each_part2_annotation_key # slot_type 233 | new_example["context"] = full_text # context 234 | new_example["question"] = slot_question # question 235 | new_example["choices"] = ", ".join(slot_candidate_choices) # candidate choices 236 | answers = [] 237 | for v in value: 238 | if isinstance(v, str): 239 | if v.lower() in ["no_cure", "not_effective", "no_opinion"]: 240 | answers.append("not_effective") 241 | else: 242 | answers.append(v.lower()) 243 | else: 244 | answers.append(full_text[v[0]:v[1]]) 245 | assert set(answers).intersection(set(slot_candidate_choices)) != set() 246 | new_example["answer"] = ", ".join(answers) # answer 247 | new_examples.append(new_example) 248 | else: 249 | new_example = get_part1_new_example(example) 250 | new_examples.append(new_example) 251 | total -= 1 252 | print(f"there are {NO_CONSENSUS_count} NO_CONSENSUS") 253 | if not os.path.isdir(out_folder): 254 | os.makedirs(out_folder, 
exist_ok=True)
255 |     with open(out_folder + "/train.json", "w+") as f:
256 |         for ex in train:
257 |             f.write(json.dumps(ex) + "\n")
258 |     with open(out_folder + "/val.json", "w+") as f:
259 |         for ex in val:
260 |             f.write(json.dumps(ex) + "\n")
261 |     print(f"done building train and val from {filepath}, written to {out_folder}")
262 | 
263 | 
264 | def construct(data_path, use_choices=True, no_answer=False, out_folder="data/final"):
265 |     if not os.path.isdir(out_folder):
266 |         os.makedirs(out_folder, exist_ok=True)
267 | 
268 |     with open(out_folder + "/" + data_path.split("/")[-1], "w+") as tgt:
269 |         with open(data_path, "r") as f:
270 |             for line in tqdm(f, desc="reading..."):
271 |                 example = json.loads(line.strip())
272 |                 source_sequence = "context: " + example["context"] + " question: " + example["question"]
273 |                 if use_choices:
274 |                     source_sequence += " choices: " + example["choices"]
275 |                 if not no_answer:
276 |                     target_sequence = example["answer"]
277 |                     tgt.write(json.dumps(
278 |                         {"id": example["id"], "event_type": example["event_type"], "slot_type": example["slot_type"],
279 |                          "source": source_sequence, "target": target_sequence}) + "\n")
280 |                 else:
281 |                     tgt.write(json.dumps(
282 |                         {"id": example["id"], "event_type": example["event_type"], "slot_type": example["slot_type"],
283 |                          "source": source_sequence, "candidates": example["candidates"]}) + "\n")
284 |     print("done constructing source and target sequences")
285 | 
286 | if __name__ == '__main__':
287 |     '''
288 |     sequence construction examples:
289 |     X1: context: @CPRewritten I heard from @Corona_Bot__ that G is tested positive of COVID-19. question: Who is tested positive? choices: Not Specified, A, B, C.
290 |     y1: A
291 |     X2: context: @CPRewritten I heard from @Corona_Bot__ that G is tested positive of COVID-19. question: Does this tweet report an individual or a small group of people who is tested positive for coronavirus? choices: yes, no. 
292 | y2: yes 293 | ''' 294 | build_train_val("data/corpus.json", test_size=0.1, out_folder="data/middle") 295 | construct("data/middle/train.json", out_folder="data/final") 296 | construct("data/middle/val.json", out_folder="data/final") 297 | 298 | ### when test set is available, uncomment the following 299 | # can_not_test_examples = build_test("shared_task-test-can_not_test.jsonl", event_type="can_not_test") 300 | # cure_and_prevention_examples = build_test("shared_task-test-cure.jsonl", event_type="cure_and_prevention") 301 | # death_examples = build_test("shared_task-test-death.jsonl", event_type="death") 302 | # negative_examples = build_test("shared_task-test-negative.jsonl", event_type="negative") 303 | # postive_examples = build_test("shared_task-test-positive.jsonl", event_type="positive") 304 | # all = can_not_test_examples + cure_and_prevention_examples + death_examples + negative_examples + postive_examples 305 | # with open("../test.json", "w+") as f: 306 | # for ex in all: 307 | # f.write(json.dumps(ex) + "\n") 308 | # construct("ori/test.json", no_answer=True) 309 | -------------------------------------------------------------------------------- /covid_event/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/covid_event/results.png -------------------------------------------------------------------------------- /covid_event/submit.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import json, os 3 | 4 | def process_target(event_type, slot_type, target): 5 | # For opinion slot in cure & prevention category, you shall only have two labels: "effective" and "not_effective" ("no_opinion", "no_cure" and "not_effective" will be merged into "not_effective") 6 | if event_type == "cure_and_prevention" and "opinion" in slot_type: 7 | if "no_opinion" in target: 8 | target.remove("no_opinion") 9 | target.append("not_effective") 10 | if "no_cure" in target: 11 | target.remove("no_cure") 12 | target.append("not_effective") 13 | # "I" (along with its variations) shall be replaced with "AUTHOR OF THE TWEET" both for NAME and CLOSE_CONTACT slot for all categories. 14 | if "name" in slot_type or "close_contact" in slot_type: 15 | if "i" in target: 16 | # It doesn't matter if your predictions are lowercased or uppercased. 17 | target.remove("i") 18 | target.append("author of the tweet") 19 | if "i'm" in target: 20 | target.remove("i'm") 21 | target.append("author of the tweet") 22 | # For relation slot and symptoms slot, you shall only have two labels: "yes" and "not specified" ("no" and "not specified" will be merged into "not specified"). 
23 | if "relation" in slot_type or "symptoms" in slot_type: 24 | if "no" in target: 25 | target.remove("no") 26 | target.append("not specified") 27 | return list(set(target)) # do not submit repeated answers 28 | 29 | def convert(data_path, output_path="t5_sub", annotation_key="predicted_annotation"): 30 | with open(data_path, "r") as f: 31 | event2converted_examples = {} 32 | converted_examples = [] 33 | converted_one_example = {'id': '', annotation_key: {}} 34 | last_event = "" 35 | last_id = "" 36 | switch = False 37 | 38 | for line in tqdm(f): 39 | example = json.loads(line.strip()) 40 | slot_type = example["slot_type"] 41 | if "part2" in slot_type: 42 | event_type = example["event_type"] 43 | if event_type != last_event: 44 | if last_event != "": 45 | switch = False 46 | converted_examples.append(converted_one_example) 47 | event2converted_examples[last_event] = converted_examples 48 | converted_examples = [] 49 | 50 | id = example["id"] 51 | target = example["target"].split(", ") 52 | 53 | if id != last_id: 54 | if last_id != "" and switch: 55 | converted_examples.append(converted_one_example) 56 | 57 | # Death: "symptoms" slot will be excluded 58 | # Tested Negative: "how long" slots will be excluded 59 | if (event_type == "death" and "symptoms" in slot_type) or ( 60 | event_type == "negative" and "how_long" in slot_type): 61 | continue 62 | 63 | target = process_target(event_type, slot_type, target) 64 | 65 | converted_one_example = {'id': id, 66 | annotation_key: {slot_type: target}} 67 | 68 | else: 69 | switch = True 70 | # Death: "symptoms" slot will be excluded 71 | # Tested Negative: "how long" slots will be excluded 72 | if (event_type == "death" and "symptoms" in slot_type) or ( 73 | event_type == "negative" and "how_long" in slot_type): 74 | continue 75 | target = process_target(event_type, slot_type, target) 76 | converted_one_example[annotation_key][slot_type] = target 77 | 78 | last_id = id 79 | last_event = event_type 80 | converted_examples.append(converted_one_example) 81 | event2converted_examples[last_event] = converted_examples 82 | 83 | if not os.path.isdir(output_path): 84 | os.makedirs(output_path, exist_ok=True) 85 | 86 | for event, examples in event2converted_examples.items(): 87 | if "cure" in event: 88 | event = "cure" 89 | output_file_path = output_path + event + ".jsonl" 90 | with open(output_file_path, "w+") as f: 91 | for example in examples: 92 | f.write(json.dumps(example) + "\n") 93 | print(f'done writing to {output_file_path}') 94 | return event2converted_examples 95 | 96 | if __name__ == '__main__': 97 | convert("preds/val_preds.json", output_path="subs/val-run-1/") 98 | -------------------------------------------------------------------------------- /covid_event/subs/post_process.py: -------------------------------------------------------------------------------- 1 | import os, json, pickle 2 | 3 | def read_groundtruth(gt_dir="subs/golden"): 4 | event_types = ['positive', 'negative', 'can_not_test', 'death', 'cure'] 5 | output = {} 6 | for each_event in event_types: 7 | with open(os.path.join(gt_dir, each_event + "_sol.jsonl"), "r") as f: 8 | for line in f: 9 | ins = json.loads(line) 10 | if ins["id"] not in output: 11 | output[ins["id"]] = {} 12 | for key, slot_gt_list in ins["golden_annotation"].items(): 13 | output[ins["id"]][key] = [e.lower() for e in slot_gt_list] 14 | return output 15 | 16 | grounds = read_groundtruth() 17 | 18 | def levenshtein(s1, s2): 19 | ''' 20 | from: 
https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance 21 | :param s1: first string 22 | :param s2: second string 23 | :return: 24 | ''' 25 | if len(s1) < len(s2): 26 | return levenshtein(s2, s1) 27 | # len(s1) >= len(s2) 28 | if len(s2) == 0: 29 | return len(s1) 30 | 31 | previous_row = range(len(s2) + 1) 32 | for i, c1 in enumerate(s1): 33 | current_row = [i + 1] 34 | for j, c2 in enumerate(s2): 35 | insertions = previous_row[ 36 | j + 1] + 1 # j+1 instead of j since previous_row and current_row are one character longer 37 | deletions = current_row[j] + 1 # than s2 38 | substitutions = previous_row[j] + (c1 != c2) 39 | current_row.append(min(insertions, deletions, substitutions)) 40 | previous_row = current_row 41 | return previous_row[-1] 42 | 43 | 44 | def convert_to_candidates(id, slot_type, target, id2condidates): 45 | candidates = id2condidates[id][slot_type] 46 | intersection = set(candidates).intersection(set(target)) 47 | if len(intersection) == 0: 48 | transformed_targets = [] 49 | for t in [", ".join(target)]: 50 | min = 1000000 51 | min_index = 0 52 | for idx, c in enumerate(candidates): 53 | distance = levenshtein(c, t) / len(c) 54 | if distance < min: 55 | min = distance 56 | min_index = idx 57 | transformed_targets.append(candidates[min_index]) 58 | else: 59 | transformed_targets = target 60 | return transformed_targets 61 | 62 | 63 | def readJsonl(path): 64 | output = {} 65 | with open(path, 'r') as f: 66 | for line in f: 67 | ins = json.loads(line) 68 | output[ins["id"]] = ins["predicted_annotation"] 69 | return output 70 | 71 | 72 | def read_id2text(data_file="data/middle/test.json"): 73 | output = {} 74 | with open(data_file, 'r') as f: 75 | for line in f: 76 | ins = json.loads(line) 77 | if ins["id"] not in output: 78 | output[ins["id"]] = {} 79 | output[ins["id"]][ins["slot_type"]] = ins["context"].lower() 80 | return output 81 | 82 | 83 | def transform(): 84 | dir = "subs" 85 | id2condidates = pickle.load(open(os.path.join(dir, "testid2candidates.pkl"), "rb")) 86 | # the testid2candidates.pkl can also obtained by calling read_id2text() 87 | # read the README file on how to create data/middle/test.json (is unlabeled before the annotation release) 88 | from_runs = ["run-1", "run-2", "run-3"] 89 | to_runs = ["post-run-1", "post-run-2", "post-run-3"] 90 | event_types = ['positive', 'negative', 'can_not_test', 'death', 'cure'] 91 | for from_run, to_run in zip(from_runs, to_runs): 92 | for each_event in event_types: 93 | if not os.path.isdir(os.path.join(dir, to_run)): 94 | os.makedirs(os.path.join(dir, to_run)) 95 | with open(os.path.join(dir, to_run, each_event + ".jsonl"), "w+") as out: 96 | preds = readJsonl(os.path.join(dir, from_run, each_event + '.jsonl')) 97 | for id, slot_example_preds in preds.items(): 98 | transformed_pred_dict = {} 99 | for slot_key, slot_pred in slot_example_preds.items(): 100 | transformed_target = convert_to_candidates(id, slot_key, slot_pred, id2condidates) 101 | transformed_pred_dict[slot_key] = transformed_target 102 | out.write(json.dumps({"id": id, "predicted_annotation": transformed_pred_dict}) + "\n") 103 | print("done") 104 | 105 | 106 | if __name__ == '__main__': 107 | transform() 108 | -------------------------------------------------------------------------------- /covid_event/subs/testid2candidates.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/covid_event/subs/testid2candidates.pkl
-------------------------------------------------------------------------------- /example_bert.py: --------------------------------------------------------------------------------
1 | from ttt import *
2 | 
3 | if __name__ == '__main__':
4 |     args = get_args()
5 |     # check what args are available
6 |     logger.info(f"args: {json.dumps(args.__dict__, indent=2)}")
7 |     ############### customize args
8 |     args.use_gpu = True
9 |     # args.use_tpu = True
10 |     # args.tpu_address = "x.x.x.x" # replace with yours
11 |     args.do_train = True
12 |     args.use_tb = True
13 |     # any one from MODELS_SUPPORT (check: ttt/args.py)
14 |     args.model_select = "bert-base-uncased"
15 |     # select a dataset following the jsonl format, where the text field name is "text" and the label field name is "label"
16 |     args.data_path = "data/glue/sst2"
17 |     # any one from TASKS_SUPPORT (check: ttt/args.py)
18 |     args.task = "single-label-cls"
19 |     args.log_steps = 1000
20 |     # any one from LR_SCHEDULER_SUPPORT (check: ttt/args.py)
21 |     args.scheduler = "warmuplinear"
22 |     # set do_eval = False if your data does not contain a validation set. In that case, patience and early_stop will be invalid
23 |     args.do_eval = True
24 |     ############### end customize args
25 |     # to have a sanity check for the args
26 |     sanity_check(args)
27 |     # seed everything, make deterministic
28 |     set_seed(args.seed)
29 |     tokenizer = get_tokenizer(args)
30 |     inputs = get_inputs(tokenizer, args)
31 |     model, _ = create_model(args, logger, get_model)
32 |     # start training; here we use the Keras high-level API
33 |     training_history = model.fit(
34 |         inputs["x_train"],
35 |         inputs["y_train"],
36 |         epochs=args.num_epochs_train,
37 |         verbose=2,
38 |         batch_size=args.per_device_train_batch_size * args.num_replicas_in_sync,
39 |         callbacks=get_callbacks(args, inputs, logger, get_evaluator),
40 |     )
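41 | 
42 |     # a sketch (not part of the original example): restore the best checkpoint saved
43 |     # by the evaluator callback, following the same "best*.h5" naming convention that
44 |     # run.py relies on; assumes do_eval was active so a best checkpoint exists
45 |     import glob, os
46 |     best_ck_path = glob.glob(os.path.join(args.output_path, "best*.h5"))[0]
47 |     model.load_weights(best_ck_path)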
-------------------------------------------------------------------------------- /example_t5.py: --------------------------------------------------------------------------------
1 | from ttt import *
2 | import transformers
3 | transformers.logging.set_verbosity_info()
4 | logger = transformers.logging.get_logger()
5 | 
6 | if __name__ == '__main__':
7 |     args = get_args()
8 |     # check what args are available and their default values
9 |     logger.info(f"args: {json.dumps(args.__dict__, indent=2)}")
10 |     ############### customize args
11 |     args.use_gpu = True
12 |     # args.use_tpu = True
13 |     # args.tpu_address = "x.x.x.x"
14 |     args.do_train = True
15 |     args.use_tb = True
16 |     # any one from MODELS_SUPPORT (check: ttt/args.py)
17 |     args.model_select = "t5-small"
18 |     # select a dataset: it is first looked up in the nlp lib; if found there, it is loaded and saved locally to data_path
19 |     # or provide custom data in data_path (train.json, val.json, test.json) where examples are organised in jsonl format,
20 |     # each line representing an example like this: {"text": "...", "label": "..."}
21 |     args.data_path = "data/glue/sst2"
22 |     # any one from TASKS_SUPPORT (check: ttt/args.py)
23 |     args.task = "t2t"
24 |     args.log_steps = 400
25 |     args.eval_batch_size = 32
26 |     args.per_device_train_batch_size = 8
27 |     args.max_src_length = 128
28 |     args.load_train_num = 1000
29 |     # any one from LR_SCHEDULER_SUPPORT (check: ttt/args.py)
30 |     args.scheduler = "warmuplinear"
31 |     args.lr = 5e-5
32 |     # set do_eval = False if your data does not contain a validation set. In that case, patience and early_stop will be invalid
33 |     args.do_eval = True
34 |     ############### end customize args
35 |     # to have a sanity check for the args
36 |     sanity_check(args, logger=logger)
37 |     # seed everything, make deterministic
38 |     # set_seed(args.seed) - this is done in the trainer before training starts
39 |     tokenizer = get_tokenizer(args)
40 |     inputs = get_inputs(tokenizer, args)
41 |     model, strategy = create_model(args, logger, get_model)
42 |     # start training, here we customize T2TTrainer to get more control and flexibility
43 |     trainer = T2TTrainer(args, logger)
44 |     trainer.train(model, strategy, tokenizer, inputs)
45 | 
-------------------------------------------------------------------------------- /example_trans_t5.py: --------------------------------------------------------------------------------
1 | from ttt import *
2 | 
3 | if __name__ == '__main__':
4 |     args = get_args()
5 |     # check what args are available and their default values
6 |     logger.info(f"args: {json.dumps(args.__dict__, indent=2)}")
7 |     ############### customize args
8 |     args.use_gpu = True
9 |     # args.tpu_address = "x.x.x.x"
10 |     # args.use_tpu = True
11 |     args.do_train = True
12 |     args.use_tb = True
13 |     # any one from MODELS_SUPPORT (check: ttt/args.py)
14 |     args.model_select = "t5-large"
15 |     # the path to the translation dataset; each line is an example in jsonl format like: {"source": "...", "target": "..."}
16 |     args.data_path = "data/wmt_en_ro"
17 |     # any one from TASKS_SUPPORT (check: ttt/args.py)
18 |     args.task = "translation"
19 |     args.max_src_length = 128
20 |     args.max_tgt_length = 128
21 |     args.eval_batch_size = 8
22 |     args.log_steps = 1000
23 |     args.source_field_name = "source"  # the wmt_en_ro jsonl files produced by ts2jsonl use "source"/"target" field names
24 |     args.target_field_name = "target"
25 |     args.per_device_train_batch_size = 2
26 |     args.eval_on = "bleu"  # this refers to sacrebleu, as used in the T5 paper
27 |     # any one from LR_SCHEDULER_SUPPORT (check: ttt/args.py)
28 |     args.scheduler = "warmuplinear"
29 |     # set do_eval = False if your data does not contain a validation set. 
In that case, patience, and early_stop will be invalid 30 | args.do_eval = True 31 | ############### end customize args 32 | # to have a sanity check for the args 33 | sanity_check(args) 34 | # seed everything, make deterministic 35 | set_seed(args.seed) 36 | tokenizer = get_tokenizer(args) 37 | inputs = get_inputs(tokenizer, args) 38 | model, strategy = create_model(args, logger, get_model) 39 | # start training, here we customize T2TTrainer to get more control and flexibility 40 | trainer = T2TTrainer(args) 41 | trainer.train(model, strategy, tokenizer, inputs) 42 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | 2 | from ttt import * 3 | from sklearn.metrics import classification_report, accuracy_score 4 | import math 5 | 6 | logging.basicConfig( 7 | format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s', 8 | datefmt='%Y-%m-%d %H:%M:%S', 9 | level=logging.INFO 10 | ) 11 | logger = logging.getLogger(__name__) 12 | 13 | if __name__ == '__main__': 14 | # get args 15 | args = get_args() 16 | # check what args are available 17 | logger.info(f"args: {json.dumps(args.__dict__, indent=2)}") 18 | # args.do_train = True 19 | # args.do_eval = True 20 | # args.do_test = True 21 | 22 | # args.use_tpu = False 23 | # args.use_gpu = True 24 | sanity_check(args) 25 | 26 | if args.do_train: 27 | add_filehandler_for_logger(args.output_path, logger) 28 | logger.info(f"tf.version.VERSION: {tf.version.VERSION}") 29 | logger.info(f"set seed for random, numpy and tensorflow: seed = {args.seed}") 30 | set_seed(args.seed) 31 | tokenizer = get_tokenizer(args) 32 | inputs = get_inputs(tokenizer, args) 33 | model, strategy = create_model(args, logger, get_model) 34 | if not args.use_tpu: # used during development 35 | model.run_eagerly = True # for debugging, this is cool. 
TF also supports debugging as pytorch 36 | # start training here 37 | if args.task == "t2t": 38 | # t5-like model is customized because we want more flexibility and control over the training loop 39 | # customize_fit(model, strategy, tokenizer, inputs, args) 40 | trainer = T2TTrainer(args) 41 | trainer.train(model, strategy, tokenizer, inputs) 42 | else: 43 | training_history = model.fit( 44 | inputs["x_train"], 45 | inputs["y_train"], 46 | epochs=args.num_epochs_train, 47 | verbose=2, 48 | batch_size=args.per_device_train_batch_size*args.num_replicas_in_sync, 49 | callbacks=get_callbacks(args, inputs, logger, get_evaluator), 50 | ) 51 | 52 | if args.do_test: 53 | add_filehandler_for_logger(args.output_path, logger, out_name="test") 54 | assert os.path.isdir(args.output_path) 55 | assert os.path.isfile(os.path.join(args.output_path, "args.json")) 56 | args = Args(**json.load(open(os.path.join(args.output_path, "args.json")))) 57 | model, strategy = create_model(args,logger, get_model) 58 | # to make it ok to task == "single-label-cls" todo 59 | ck_path = glob.glob(os.path.join(args.output_path, "best*.h5"))[0] 60 | if args.ck_index_select < 0 and -args.ck_index_select <= args.keep_ck_num: 61 | cks_path_already = glob.glob(os.path.join(args.output_path, "*.h5")) 62 | index2path = {int(os.path.basename(each_ck_path).split(".")[0].split("_")[-1]): each_ck_path for 63 | each_ck_path in cks_path_already} 64 | sorted_indices = sorted(index2path) 65 | ck_path = index2path[sorted_indices[args.ck_index_select]] 66 | 67 | logger.info(f"evaluating using weights from checkpoint: {ck_path}") 68 | model.load_weights(ck_path) 69 | 70 | tokenizer = AutoTokenizer.from_pretrained(args.output_path) 71 | 72 | logger.info("********************start evaluating on test set********************") 73 | if args.task == "single-label-cls": 74 | # tokenizer = get_tokenizer(args) 75 | test_texts, encoded_test, y_test = convert_seq_single_cls_examples( 76 | os.path.join(args.data_path, 'test.json'), 77 | tokenizer, args.input_seq_length, 78 | args.label2id) 79 | 80 | x_test = [encoded_test["input_ids"], encoded_test["token_type_ids"], encoded_test["attention_mask"]] 81 | 82 | if not args.use_tpu: # used during development 83 | model.run_eagerly = True # for debugging, this is cool. 
TF also supports debugging as pytorch 84 | 85 | pred_probs = model.predict(x_test, batch_size=args.eval_batch_size*strategy.num_replicas_in_sync, steps=math.ceil(len(y_test) / 32), verbose=1) 86 | preds = tf.math.argmax(pred_probs, 1).numpy() 87 | 88 | acc = accuracy_score(y_test, preds) 89 | 90 | target_names = [''] * len(args.label2id) 91 | for label, id in args.label2id.items(): 92 | target_names[id] = label 93 | 94 | logger.info( 95 | f"test_eval_report: {classification_report(y_test, preds, digits=4, target_names=target_names)}") 96 | logger.info(f"test_eval_acc: {acc}") 97 | 98 | elif args.task == "t2t": 99 | source_texts, encoded_source, encoded_target = convert_t2t_examples( 100 | os.path.join(args.data_path, 'test.json'), tokenizer, args) 101 | source_input_ids = encoded_source["input_ids"] 102 | test_dataset = tf.data.Dataset.from_tensor_slices( 103 | (source_input_ids, encoded_source["attention_mask"], encoded_target["input_ids"])) 104 | test_dataset = test_dataset.batch(args.eval_batch_size*strategy.num_replicas_in_sync) 105 | iter_num = math.ceil(len(source_input_ids) / args.eval_batch_size) 106 | preds = [] 107 | gts = [] 108 | 109 | for input_ids, attention_mask, gt in tqdm(test_dataset, total=iter_num, desc="testing..."): 110 | predicts = model.generate(input_ids=input_ids, 111 | attention_mask=attention_mask, 112 | max_length=args.max_tgt_length) 113 | preds.extend([tokenizer.decode(ids) for ids in predicts]) 114 | gts.extend([tokenizer.decode(ids) for ids in gt]) 115 | 116 | acc = accuracy_score(gts, preds) 117 | logger.info( 118 | f"test_eval_report: {classification_report(gts, preds, digits=4)}") 119 | logger.info(f"test_eval_acc: {acc}") 120 | # tensorboard dev upload --logdir runs -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", mode="r", encoding="utf-8") as readme_file: 4 | readme = readme_file.read() 5 | 6 | setup( 7 | name="pytriplet", # named pytriplet in pypi to avoid repeated name 8 | version="0.0.5", 9 | author="Congcong Wang", 10 | author_email="wangcongcongcc@gmail.com", 11 | description="Fine-tuning Transformers with TPUs or GPUs acceleration, written in Tensorflow2.0", 12 | long_description=readme, 13 | long_description_content_type="text/markdown", 14 | license="MIT License", 15 | url="https://github.com/wangcongcong123/ttt", 16 | download_url="https://github.com/wangcongcong123/ttt/releases/download/v0.0.3/pytriplet.tar.gz", 17 | packages=find_packages(), 18 | install_requires=[ 19 | "tensorflow==2.3.0", 20 | "sklearn", 21 | "tqdm", 22 | "keras", 23 | "tensorboardX", 24 | "nlp", 25 | "sacrebleu", 26 | "datasets", 27 | "transformers" 28 | ], 29 | classifiers=[ 30 | "Development Status :: 4 - Beta", 31 | "Intended Audience :: Science/Research", 32 | "License :: OSI Approved :: Apache Software License", 33 | "Programming Language :: Python :: 3.7", 34 | "Topic :: Scientific/Engineering :: Artificial Intelligence" 35 | ], 36 | keywords="Transformers, Tensorflow, TPUs acceleration" 37 | ) 38 | # commands for uploading to pypi 39 | # python setup.py sdist 40 | # pip install twine 41 | # twine upload dist/* 42 | 
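43 | # for local development, an editable install avoids re-packaging after every change:
44 | # pip install -e .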
-------------------------------------------------------------------------------- /ttt/__init__.py: -------------------------------------------------------------------------------- 1 | from ttt.inputs import * 2 | from ttt.utils import * 3 | from ttt.args import * 4 | from ttt.models import get_model 5 | from ttt.t2t_trainer import T2TTrainer 6 | from ttt.evaluators import get_evaluator -------------------------------------------------------------------------------- /ttt/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ttt/__pycache__/args.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/args.cpython-37.pyc -------------------------------------------------------------------------------- /ttt/__pycache__/evaluators.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/evaluators.cpython-37.pyc -------------------------------------------------------------------------------- /ttt/__pycache__/inputs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/inputs.cpython-37.pyc -------------------------------------------------------------------------------- /ttt/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /ttt/__pycache__/t2t_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/t2t_trainer.cpython-37.pyc -------------------------------------------------------------------------------- /ttt/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /ttt/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | from sklearn.model_selection import train_test_split 4 | from datasets import load_dataset 5 | import urllib.request 6 | import os, json, sys 7 | import shutil 8 | import tarfile 9 | from .utils import add_filehandler_for_logger 10 | # these have been tested and work fine. 
more can be added to this list to test 11 | MODELS_SUPPORT = ["distilbert-base-cased", "bert-base-uncased", "bert-large-uncased", 12 | "google/electra-base-discriminator", 13 | "google/electra-large-discriminator", "albert-base-v2", "roberta-base", 14 | "t5-small", "t5-base", "t5-large"] 15 | 16 | # if using t5 models, the tasks has to be t2t* ones 17 | TASKS_SUPPORT = ["single-label-cls", "t2t", "translation", "pretrain"] 18 | # in the future, more schedulers will be added, such as warmupconstant, warmupcosine, etc. 19 | LR_SCHEDULER_SUPPORT = ["warmuplinear", "warmupconstant", "constant", "constantlinear"] 20 | # warmuplinear refers to increasing lr from 0 to a stage specified by warmup_ratio (ratio of total iterations) and then decreasing lr linearly 21 | ADDITIONAL_DATA_SUPPORT = {"wmt_en_ro": "https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz"} 22 | 23 | 24 | def ts2jsonl(data_path, t5_prefix_source=True): 25 | """ 26 | t5_prefix_source: by default, t5_prefix_source, "translate English to Romanian: " is prepended to source sequence as in: https://arxiv.org/pdf/1910.10683.pdf 27 | """ 28 | sets = ["train", "val", "test"] 29 | for set_name in sets: 30 | if os.path.isfile(os.path.join(data_path, set_name + ".source")): 31 | with open(os.path.join(data_path, set_name + ".source"), encoding="utf-8") as src: 32 | src_lines = src.readlines() 33 | with open(os.path.join(data_path, set_name + ".target"), encoding="utf-8") as src: 34 | tgt_lines = src.readlines() 35 | assert len(src_lines) == len(tgt_lines) 36 | if set_name == "train": 37 | #### for any train with examples more than 10000, we get 10000 to form a fixture training set (well suited for quick model development and prototyping) 38 | with open(os.path.join(data_path, set_name + "_fixture.json"), "w+", encoding="utf-8") as f: 39 | for source, target in tqdm(zip(src_lines[:10000], tgt_lines[:10000])): 40 | f.write(json.dumps({"source": "translate English to Romanian: " + source.strip(), 41 | "target": target.strip()}) + "\n") 42 | 43 | with open(os.path.join(data_path, set_name + ".json"), "w+", encoding="utf-8") as f: 44 | for source, target in tqdm(zip(src_lines, tgt_lines), 45 | desc=f"converting original data: {data_path} to jsonl formats"): 46 | if t5_prefix_source: 47 | source = "translate English to Romanian: " + source 48 | f.write(json.dumps({"source": source.strip(), "target": target.strip()}) + "\n") 49 | 50 | 51 | def data_check(args, sample_val_from_train=True, val_sample_portion=0.1,logger=None): 52 | if not os.path.isfile(os.path.join(args.data_path, "train.json")): 53 | # if it is from: https://huggingface.co/nlp/viewer, load it 54 | # if there is no validation set, by default, here a random sampling from the train (0.1 ratio) is used to form the val set 55 | # so far it works well for single-sequence cls datasets such as "glue/sst2", "20newsgroup", "ag_news", "imdb", "sentiment140", etc. 56 | # support more all kinds of other datasets that are availale in nlp, such as sequence-pair cls datasets (NLI), qa datasets, etc. 
-> todo 57 | set_name = args.data_path.split("/")[1:] 58 | try: 59 | dataset = load_dataset(*set_name) 60 | target_examples_dict = {} 61 | assert "train" in dataset, "not found train set in the given nlp dataset, make sure you give the correct name as listed in: https://huggingface.co/nlp/viewer/" 62 | label_names = dataset["train"].features["label"].names 63 | for set_key, examples in tqdm(dataset.items(), 64 | desc=f"not found the data locally, try to load {set_name} from nlp lib."): 65 | if set_key == "validation": 66 | set_key = "val" 67 | if set_key not in target_examples_dict: 68 | target_examples_dict[set_key] = [] 69 | 70 | if set_key == "test" and "sst2" in set_name: 71 | # sst2 does not have ground truths in test set from the nlp lib. 72 | # here a special branch is processed here: 73 | # to download a sst-test set with the same format as the train and val set loaded here 74 | # from here: https://ucdcs-student.ucd.ie/~cwang/sst2-test.json 75 | # for model testing 76 | sst2_test = urllib.request.urlopen("https://ucdcs-student.ucd.ie/~cwang/sst2-test.json") 77 | for line in sst2_test: 78 | decoded_line = line.decode("utf-8").strip() 79 | target_examples_dict[set_key].append(decoded_line) 80 | else: 81 | for example in examples: 82 | example["label"] = label_names[example["label"]] # convert to raw label 83 | if "sentence" in example: 84 | example["text"] = example["sentence"] 85 | del example["sentence"] 86 | target_examples_dict[set_key].append(json.dumps(example)) 87 | 88 | if sample_val_from_train and "val" not in target_examples_dict: 89 | train, val = train_test_split(target_examples_dict["train"], test_size=val_sample_portion, 90 | random_state=42) 91 | target_examples_dict["train"] = train 92 | target_examples_dict["val"] = val 93 | 94 | to_folder = os.path.join("data", *set_name) 95 | 96 | if not os.path.isdir(to_folder): 97 | os.makedirs(to_folder) 98 | 99 | splits_dict = {} 100 | for set_key, examples in tqdm(target_examples_dict.items(), desc="writing..."): 101 | with open(os.path.join(to_folder, set_key + ".json"), "w+") as target: 102 | target.write("\n".join(examples)) 103 | splits_dict[set_key] = len(examples) 104 | with open(os.path.join(to_folder, "splits.txt"), "w+") as tgt: 105 | tgt.write(json.dumps(splits_dict, indent=2)) 106 | 107 | except: 108 | if set_name[0] in ADDITIONAL_DATA_SUPPORT.keys(): 109 | fstream = urllib.request.urlopen(ADDITIONAL_DATA_SUPPORT[set_name[0]]) 110 | tfile = tarfile.open(fileobj=fstream, mode="r|gz") 111 | tfile.extractall(path="data") 112 | if set_name[0] == "wmt_en_ro": 113 | ts2jsonl(args.data_path) 114 | else: 115 | if logger is not None: 116 | logger.info("data not found") 117 | print("data not found") 118 | sys.exit(0) 119 | 120 | def check_output_path(output_path,force=False): 121 | if os.path.isdir(output_path): 122 | if force: 123 | print(f"{output_path} exists, remove it as force=True") 124 | shutil.rmtree(output_path) 125 | os.makedirs(output_path, exist_ok=True) 126 | else: 127 | out = input( 128 | "Output directory ({}) already exists and is not empty, you wanna remove it before start training? 
(y/n)".format( 129 | output_path)) 130 | if out.lower() == "y": 131 | shutil.rmtree(output_path) 132 | os.makedirs(output_path, exist_ok=True) 133 | else: 134 | sys.exit(0) 135 | else: 136 | print(f"{output_path} not found, create it now") 137 | os.makedirs(output_path, exist_ok=True) 138 | 139 | def sanity_check(args,logger=None,force=False): 140 | # auto-construct some args 141 | # check if data exists 142 | data_check(args,logger=logger) 143 | 144 | output_folder = args.model_select + "_" + args.task + "_" + "-".join(args.data_path.split("/")[1:]) 145 | output_path = os.path.join("tmp", output_folder) 146 | args.output_folder = output_folder 147 | args.output_path = output_path 148 | 149 | if args.do_train: 150 | check_output_path(output_path,force=force) 151 | 152 | assert args.model_select in MODELS_SUPPORT or os.path.isdir(args.model_select), F"set --model_select has to be in {MODELS_SUPPORT} or a local path where model's config and tokenizer_config exist" 153 | 154 | assert args.task in TASKS_SUPPORT, F"set --task to be in {TASKS_SUPPORT}" 155 | assert args.scheduler in LR_SCHEDULER_SUPPORT, F"set --scheduler to be in {TASKS_SUPPORT}" 156 | if "t5" in args.model_select: 157 | assert "t2t" in args.task or "translation" in args.task or "pretrain" in args.task, "t5 models (--model_select) only support t2t, translation, pretrain tasks (--task)" 158 | 159 | if "t5" not in args.model_select: 160 | assert "t2t" not in args.task, "BERT-like models (--model_select) only support non t2t tasks (--task)" 161 | 162 | if "translation" in args.task: 163 | assert "t5" in args.model_select, "translation task now is only compatible with T5 models" 164 | 165 | if "pretrain" in args.task: 166 | assert "t5" in args.model_select, "pretrain task now is only compatible with T5 models so far" 167 | if logger is not None: 168 | add_filehandler_for_logger(args.output_path, logger) 169 | 170 | class Args: 171 | ''' 172 | a Args class that maintain the same default args as argparse.ArgumentParser 173 | ''' 174 | model_select = "bert-base-uncased" 175 | data_path = "data/glue/sst2" 176 | dataset_name = "," 177 | task = "single-label-cls" 178 | per_device_train_batch_size = 8 179 | eval_batch_size = 32 180 | num_epochs_train = 6 181 | log_steps = 400 182 | max_seq_length = 128 183 | max_src_length = 128 184 | max_tgt_length = 20 185 | source_field_name = "text" 186 | target_field_name = "label" 187 | lr = 5e-5 188 | warmup_ratio = 0.1 189 | patience = 20 190 | scheduler = "warmuplinear" 191 | seed = 122 192 | eval_on = "acc" 193 | keep_ck_num = 3 194 | ck_index_select = 0 195 | do_train = False 196 | do_eval = False 197 | do_test = False 198 | use_gpu = False 199 | use_tpu = False 200 | use_tb = False 201 | tpu_address = "x.x.x.x" 202 | 203 | def __init__(self, **kwargs): 204 | for k, v in kwargs.items(): 205 | setattr(self, k, v) 206 | 207 | 208 | def get_args(): 209 | parser = argparse.ArgumentParser(description='Hyper params') 210 | 211 | parser.add_argument('--model_select', type=str, default="t5-small", 212 | help='model select from MODEL_MAP') 213 | 214 | parser.add_argument('--data_path', type=str, default="data/glue/sst2", 215 | help='data path') 216 | 217 | parser.add_argument('--dataset_name', type=str, default="", 218 | help="dataset name for HF's load_dataset") 219 | 220 | parser.add_argument('--task', type=str, default="t2t", 221 | help='tasks available in TASKS_SUPPORT') 222 | 223 | parser.add_argument('--per_device_train_batch_size', type=int, default=8, 224 | help='input batch size for training') 
225 | 
226 |     parser.add_argument('--eval_batch_size', type=int, default=32,
227 |                         help='input batch size for evaluation')
228 | 
229 |     parser.add_argument('--num_epochs_train', type=int, default=6,
230 |                         help='number of epochs to train')
231 | 
232 |     parser.add_argument('--log_steps', type=int, default=400,
233 |                         help='evaluation/logging interval: based on global steps if not -1, or at the end of each epoch if -1; metrics are tracked with tensorboard when use_tb is active')
234 | 
235 |     parser.add_argument('--max_seq_length', type=int, default=128,
236 |                         help='maximum sequence length of samples in a batch for training')
237 | 
238 |     parser.add_argument('--max_src_length', type=int, default=128,
239 |                         help='only working for t5-like t2t-based tasks, maximum source sequence length of samples in a batch for training')
240 | 
241 |     parser.add_argument('--max_tgt_length', type=int, default=20,
242 |                         help='only working for t5-like t2t-based tasks, maximum target sequence length of samples in a batch for training')
243 | 
244 |     parser.add_argument('--source_field_name', type=str, default="text",
245 |                         help='only working for t5-like t2t-based tasks, the source field name of the provided jsonl-formatted data')
246 | 
247 |     parser.add_argument('--target_field_name', type=str, default="label",
248 |                         help='only working for t5-like t2t-based tasks, the target field name of the provided jsonl-formatted data')
249 | 
250 |     parser.add_argument('--lr', type=float, default=0.001,
251 |                         help='learning rate; the default of 0.001 is the fine-tuning rate described in the T5 paper')
252 | 
253 |     parser.add_argument('--warmup_ratio', type=float, default=0.1,
254 |                         help='warmup ratio, only working if the scheduler is not constant')
255 | 
256 |     parser.add_argument('--patience', type=int, default=20,
257 |                         help='patience (counted in log_steps) for early stopping')
258 | 
259 |     parser.add_argument('--scheduler', type=str, default="constant",
260 |                         help='scheduler, default is constant as described in the T5 paper')
261 | 
262 |     parser.add_argument('--seed', type=int, default=122,
263 |                         help='random seed')
264 | 
265 |     parser.add_argument('--eval_on', type=str, default="acc",
266 |                         help='metric used for saving the best checkpoint and for patience-based early stopping')
267 | 
268 |     parser.add_argument('--keep_ck_num', type=int, default=3,
269 |                         help='number of checkpoints to keep besides the best one (the best is selected on the validation set using the metric specified by --eval_on)')
270 | 
271 |     parser.add_argument('--ck_index_select', type=int, default=0,
272 |                         help='checkpoint index to select: the best one by default; a negative value selects one of the latest ones; only working when --do_test is active')
273 | 
274 |     parser.add_argument(
275 |         "--do_train", action="store_true", help="Do training"
276 |     )
277 |     parser.add_argument(
278 |         "--do_eval", action="store_true", help="do evaluation on validation set for saving checkpoint"
279 |     )
280 |     parser.add_argument(
281 |         "--do_test", action="store_true", help="eval on test set if a test set is available"
282 |     )
283 |     parser.add_argument(
284 |         "--use_gpu", action="store_true", help="use gpu?"
285 |     )
286 | 
287 |     parser.add_argument(
288 |         "--use_tpu", action="store_true", help="use tpu?"
289 |     )
290 |     parser.add_argument(
291 |         "--use_tb", action="store_true", help="use tensorboard for tracking training process, default save to ./runs"
292 |     )
293 | 
294 |     parser.add_argument('--tpu_address', type=str, default="x.x.x.x",
295 |                         help='cloud tpu address if using tpu')
296 | 
297 |     parser.add_argument(
298 |         "--default_store", action="store_true",
299 |         help="Store datasets, weights, logs, and relevant details to folders by default?"
300 |     )
301 | 
302 |     args = parser.parse_args()
303 |     return args
304 | 
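305 | # example invocation (a sketch; all flags correspond to the arguments defined above):
306 | #   python run.py --do_train --do_eval --use_gpu --model_select t5-small \
307 | #                 --data_path data/glue/sst2 --task t2t --scheduler warmuplinear --lr 5e-5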
" 289 | ) 290 | parser.add_argument( 291 | "--use_tb", action="store_true", help="use tensorboard for tracking training process, default save to ./runs" 292 | ) 293 | 294 | parser.add_argument('--tpu_address', type=str, default="x.x.x.x", 295 | help='cloud tpu address if using tpu') 296 | 297 | parser.add_argument( 298 | "--default_store", action="store_true", 299 | help="Store datasets, weights, logs, and relevant details to folders by default?" 300 | ) 301 | 302 | args = parser.parse_args() 303 | return args 304 | -------------------------------------------------------------------------------- /ttt/evaluators.py: -------------------------------------------------------------------------------- 1 | '''' 2 | this is a evaluation callback class for high-level Keras training (BERT-like models in this lib) 3 | ''' 4 | 5 | import sys 6 | from tensorflow import keras 7 | import numpy as np 8 | from sklearn.metrics import classification_report 9 | import tensorflow as tf 10 | import logging 11 | import os 12 | from ttt.utils import add_filehandler_for_logger, get_existing_cks 13 | from tensorboardX import SummaryWriter 14 | 15 | logging.basicConfig( 16 | format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s', 17 | datefmt='%Y-%m-%d %H:%M:%S', 18 | level=logging.INFO 19 | ) 20 | logger = logging.getLogger(__name__) 21 | 22 | class ClsEvaluator(keras.callbacks.Callback): 23 | def __init__(self, x_eval, y_eval, args): 24 | super(ClsEvaluator).__init__() 25 | self.x_eval = x_eval 26 | self.y_eval = y_eval 27 | self.eval_on = args.eval_on 28 | self.patience = args.patience 29 | self.log_steps = args.log_steps 30 | self.args = args 31 | self.use_tb = self.args.use_tb 32 | if self.use_tb: 33 | self._tb_writer = SummaryWriter(log_dir=os.path.join("runs", args.output_folder)) 34 | 35 | def on_train_begin(self, logs=None): 36 | self.global_step = 0 37 | self.wait = 0 38 | self.best = np.Inf if self.eval_on == "loss" else -np.Inf 39 | 40 | def on_train_end(self, logs=None): 41 | if self.use_tb: 42 | self._tb_writer.close() 43 | 44 | def on_batch_end(self, batch, logs=None): 45 | self.global_step += 1 46 | if self.log_steps != -1 and self.global_step % self.log_steps == 0: 47 | if self.args.do_eval: 48 | self.evaluate(self.global_step, tag="global_step", logs=logs) 49 | 50 | def evaluate(self, steps, tag="epoch", logs=None): 51 | logger.info("\n") 52 | logger.info(f"*************evaluating at {tag}={steps}*************") 53 | eval_results = self.model.evaluate(self.x_eval, self.y_eval, batch_size=self.args.eval_batch_size) 54 | dev_loss, acc = eval_results[0], eval_results[1] 55 | pred_probs = self.model.predict(self.x_eval, batch_size=self.args.eval_batch_size) 56 | preds = tf.math.argmax(pred_probs, 1).numpy() 57 | # acc = accuracy_score(preds, self.y_eval) 58 | logger.info(f"{tag}={steps}, eval_report: {classification_report(self.y_eval, preds, digits=4)}") 59 | logger.info(f"{tag}={steps}, eval_acc: {acc}") 60 | logger.info(f"{tag}={steps}, eval_results: {eval_results}") 61 | logger.info(f"{tag}={steps}, train_logs: {logs}") 62 | 63 | if self.use_tb: 64 | if logs is not None: 65 | logger.info("logging metrics with tensorboard") 66 | for key,value in logs.items(): 67 | self._tb_writer.add_scalar(f"train_{key}_{tag}",value, steps) 68 | self._tb_writer.add_scalar(f"train_lr_{tag}", self.model.optimizer.lr.numpy(), steps) 69 | self._tb_writer.add_scalar(f"val_acc_{tag}", acc, steps) 70 | self._tb_writer.add_scalar(f"val_loss_{tag}", dev_loss, steps) 71 | 72 | 73 | if 
self.eval_on == "loss": 74 | if dev_loss <= self.best: 75 | self.wait = 0 76 | self.best = dev_loss 77 | self.save_ck(steps, tag, best_ck=True) 78 | else: 79 | self.wait += 1 80 | else: 81 | if acc >= self.best: 82 | self.wait = 0 83 | self.best = acc 84 | self.save_ck(steps, tag, best_ck=True) 85 | else: 86 | self.wait += 1 87 | logger.info(f"early stop count: {self.wait}/{self.patience}") 88 | logger.info(f"{tag}={steps}, best_on_eval_since({self.eval_on}): {self.best}") 89 | self.save_ck(steps, tag) 90 | if self.wait >= self.patience: 91 | logger.info("run out of patience, early stop") 92 | if self.use_tb: 93 | self._tb_writer.close() 94 | sys.exit(0) 95 | 96 | def save_ck(self, steps, tag="epoch", best_ck=False): 97 | sorted_indices, index2path = get_existing_cks(self.args.output_path, best_ck=best_ck) 98 | if len(sorted_indices) >= self.args.keep_ck_num: 99 | logger.info( 100 | f"since there are already {len(sorted_indices)} checkpoints saved that will be more than keep_ck_num={self.args.keep_ck_num}") 101 | logger.info(f"remove the oldest one: {index2path[sorted_indices[0]]}") 102 | os.remove(index2path[sorted_indices[ 103 | 0]]) # remove the oldest checkpoint, i.e., the one with the lowest epoch number 104 | # write_args(self.args.output_path, self.args) 105 | if best_ck: 106 | logger.info( 107 | f'save best model weights to {os.path.join(self.args.output_path, f"best_ck_at_{tag}_{steps}.h5")}') 108 | self.model.save_weights(os.path.join(self.args.output_path, f"best_ck_at_{tag}_{steps}.h5"), 109 | overwrite=True) 110 | else: 111 | logger.info( 112 | f'save model weights to {os.path.join(self.args.output_path, f"ck_at_{tag}_{steps}.h5")}') 113 | self.model.save_weights(os.path.join(self.args.output_path, f"ck_at_{tag}_{steps}.h5"), 114 | overwrite=True) 115 | 116 | def on_epoch_end(self, epoch, logs=None): 117 | if self.log_steps == -1: 118 | if self.args.do_eval: 119 | self.evaluate(epoch, tag="epoch",logs=logs) 120 | if not self.args.do_eval: 121 | # if do not do evaluate, the checkpoint at the end of epoch needs to be saved 122 | self.save_ck(epoch, tag="epoch") 123 | 124 | def get_evaluator(x_eval, y_eval, args): 125 | add_filehandler_for_logger(args.output_path, logger) 126 | if args.task == "single-label-cls": 127 | return ClsEvaluator(x_eval, y_eval, args) 128 | elif args.task == "t2t": 129 | # it uses customize training toop, we do not need a evaluator here 130 | pass 131 | else: 132 | pass -------------------------------------------------------------------------------- /ttt/inputs.py: -------------------------------------------------------------------------------- 1 | import json, os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import logging, pickle 5 | from ttt.utils import add_filehandler_for_logger 6 | 7 | logging.basicConfig( 8 | format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s', 9 | datefmt='%Y-%m-%d %H:%M:%S', 10 | level=logging.INFO 11 | ) 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def read_seq_single_cls_examples(data_path): 16 | texts = [] 17 | labels = [] 18 | with open(data_path, "r") as f: 19 | for line in tqdm(f, desc=f"reading from {data_path}"): 20 | example = json.loads(line.strip()) 21 | texts.append(example["text"]) 22 | labels.append(example["label"]) 23 | return texts, labels 24 | 25 | 26 | def convert_seq_single_cls_examples(data_path, tokenizer, max_seq_length, label2id): 27 | texts, labels = read_seq_single_cls_examples(data_path) 28 | y = np.array([label2id[label] for label in labels]) 29 | 
encoded_texts = tokenizer(texts, padding=True, truncation=True, return_tensors="np",
30 |                               max_length=max_seq_length)
31 |     return texts, encoded_texts, y
32 | 
33 | 
34 | def prepare_seq_single_cls_inputs(tokenizer, args, load_train_num=-1, tokenize_batch_size=None):
35 |     logger.info("reading train set")
36 |     train_texts, train_labels = read_seq_single_cls_examples(os.path.join(args.data_path, "train.json"))
37 | 
38 |     if load_train_num > 0:
39 |         assert load_train_num <= len(train_texts), f"there are only {len(train_texts)} training examples"
40 |         logger.info(f"loading only {load_train_num} training examples out of the total {len(train_texts)}")
41 |         train_texts = train_texts[:load_train_num]
42 |         train_labels = train_labels[:load_train_num]
43 | 
44 |     label2id = {}
45 |     logger.info("building label2id from train examples")
46 |     for i, label in enumerate(list(set(train_labels))):
47 |         label2id[label] = i
48 | 
49 |     logger.info("converting labels to their ids for train examples")
50 |     y_train = np.array([label2id[label] for label in train_labels])
51 |     logger.info(f"encoding train examples (num={len(train_texts)})")
52 |     logger.info(f"using tokenizer with padding = True and truncation = True and max_length = {args.max_seq_length}")
53 |     logger.info("This may take a while")
54 |     # todo -> tokenize_batch_size
55 |     encoded_train = tokenizer(train_texts, padding=True, truncation=True, return_tensors="np",
56 |                               max_length=args.max_seq_length)
57 | 
58 |     if "token_type_ids" not in encoded_train:
59 |         # we need this for the roberta tokenizer, which does not return token_type_ids
60 |         encoded_train["token_type_ids"] = np.zeros(encoded_train["input_ids"].shape, dtype=np.int32)
61 | 
62 |     x_train = [encoded_train["input_ids"], encoded_train["token_type_ids"], encoded_train["attention_mask"]]
63 | 
64 |     to_return = {"x_train": x_train, "y_train": y_train, "label2id": label2id}
65 | 
66 |     if os.path.isfile(os.path.join(args.data_path, "val.json")):
67 |         logger.info(f"found validation set in {os.path.join(args.data_path, 'val.json')}")
68 |         logger.info(f"encoding validation examples")
69 |         val_texts, encoded_val, y_eval = convert_seq_single_cls_examples(os.path.join(args.data_path, 'val.json'),
70 |                                                                          tokenizer, args.max_seq_length, label2id)
71 |         if "token_type_ids" not in encoded_val:
72 |             # we need this for the roberta tokenizer, which does not return token_type_ids
73 |             encoded_val["token_type_ids"] = np.zeros(encoded_val["input_ids"].shape, dtype=np.int32)
74 | 
75 |         x_eval = [encoded_val["input_ids"], encoded_val["token_type_ids"], encoded_val["attention_mask"]]
76 | 
77 |         to_return.update({"x_eval": x_eval, "y_eval": y_eval, "eval_examples": val_texts})
78 |     # if os.path.isfile(os.path.join(args.data_path, "test.json")):
79 |     #     logger.info(f"found test set in {os.path.join(args.data_path, 'test.json')}")
80 |     #     logger.info(f"encoding test examples")
81 |     #     test_texts, encoded_test, y_test = convert_seq_single_cls_examples(os.path.join(args.data_path, 'test.json'),
82 |     #                                                                        tokenizer, args.max_seq_length, label2id)
83 |     #     x_test = [encoded_test["input_ids"], encoded_test["token_type_ids"], encoded_test["attention_mask"]]
84 |     #     to_return.update({"x_test": x_test, "y_test": y_test, "test_examples": test_texts})
85 |     return to_return
86 | 
87 | 
88 | def read_t2t_examples(data_path, source_field_name="text", target_field_name="label", with_target=True):
89 |     source_texts = []
90 |     target_texts = []
91 |     other_attributes = []
92 |     with open(data_path, "r", encoding="utf-8") as f:
93 |         for line in tqdm(f, desc=f"reading from 
{data_path}"): 94 | example = json.loads(line.strip()) 95 | 96 | source_texts.append(example.pop(source_field_name)) 97 | # the will be added when calling tokenizer, automatically when its add_special_tokens=True 98 | if with_target: 99 | target_texts.append(example.pop(target_field_name)) 100 | other_attributes.append(example) 101 | return source_texts, target_texts, other_attributes 102 | 103 | 104 | def convert_t2t_examples(data_path, tokenizer, args, with_target=True, return_meta=False): 105 | source_texts, target_texts, other_attributes = read_t2t_examples(data_path, 106 | source_field_name=args.source_field_name, 107 | target_field_name=args.target_field_name, 108 | with_target=with_target) 109 | 110 | # for pre - training t5, no need to append the to the source sequence 111 | encoded_source = tokenizer(source_texts, padding=True, truncation=True, return_tensors="np", 112 | max_length=args.max_src_length, add_special_tokens=not args.is_pretrain) 113 | 114 | encoded_target = None 115 | if with_target: 116 | # for making predictions purpose 117 | encoded_target = tokenizer(target_texts, padding=True, truncation=True, return_tensors="np", 118 | max_length=args.max_tgt_length, add_special_tokens=True) 119 | if return_meta: 120 | return source_texts, encoded_source, encoded_target, other_attributes 121 | 122 | return source_texts, encoded_source, encoded_target 123 | 124 | 125 | def replace_with_special_token(encoded, special_token_id, replace_token_id=0): 126 | # replace_token_id: padding id 127 | ids = encoded["input_ids"] 128 | attention_mask = encoded["attention_mask"] 129 | 130 | for i in range(ids.shape[0]): 131 | indices = np.where(ids[i, :] == replace_token_id)[0] 132 | if indices.size == 0: 133 | ids[i, -1] = special_token_id 134 | else: 135 | ids[i, indices[0]] = special_token_id 136 | attention_mask[i, indices[0]] = 1 137 | return ids, attention_mask 138 | 139 | 140 | def shift_to_right(input_ids, decoder_start_token_id): 141 | shifted_input_ids = np.zeros(input_ids.shape, dtype=input_ids.dtype) 142 | shifted_input_ids[..., 1:] = input_ids[..., :-1] 143 | shifted_input_ids[..., 0] = decoder_start_token_id 144 | # shifted_input_ids[shifted_input_ids == 0] = -100 145 | return shifted_input_ids 146 | 147 | 148 | def tokenize_with_progress_bar(tokenizer, args, text_list, batch_size=1000): 149 | assert batch_size > 0 150 | encoded_return = {"input_ids": [], "attention_mask": []} 151 | batch = [] 152 | # actual_max_seq = =args.max_src_length 153 | # tokenization in batch specified by batch_size, a bit different padding here compared to the at-one-go way, the batch here always pad to max_src_length although the longest one is not that long 154 | for idx, each_text in tqdm(enumerate(text_list), desc="tokenizing by batch...", total=len(text_list)): 155 | if (idx + 1) % batch_size == 0: 156 | batch.append(each_text) 157 | encoded = tokenizer(batch, padding="max_length", truncation=True, max_length=args.max_src_length, 158 | add_special_tokens=not args.is_pretrain) 159 | encoded_return["input_ids"].extend(encoded["input_ids"]) 160 | encoded_return["attention_mask"].extend(encoded["attention_mask"]) 161 | batch = [] 162 | else: 163 | batch.append(each_text) 164 | 165 | if batch != []: 166 | encoded = tokenizer(batch, padding="max_length", truncation=True, max_length=args.max_src_length, 167 | add_special_tokens=not args.is_pretrain) 168 | encoded_return["input_ids"].extend(encoded["input_ids"]) 169 | encoded_return["attention_mask"].extend(encoded["attention_mask"]) 170 | 171 | assert 
len(encoded_return["input_ids"]) == len(text_list) 172 | 173 | encoded_return["input_ids"] = np.array(encoded_return["input_ids"]) 174 | 175 | assert len(encoded_return["attention_mask"]) == len(text_list) 176 | 177 | encoded_return["attention_mask"] = np.array(encoded_return["attention_mask"]) 178 | 179 | from transformers import BatchEncoding 180 | return BatchEncoding(data=encoded_return) 181 | 182 | def tokenizet2t(tokenizer, args, source_text, target_text, batch_size=None): 183 | # this method is similar to T5Tokenizer.prepare_seq2seq_batch, it is rewritten here to get more control and flexibility 184 | logger.info(f"encoding source examples (num={len(source_text)})") 185 | logger.info(f"using tokenizer with padding = True and truncation = True and max_src_length = {args.max_src_length}") 186 | 187 | if batch_size is None: 188 | # if batch_size is None. refers to tokenization at one go. In this case, no progress bar but it can save memory potentially 189 | # for pre - training t5, no need to append the to the source sequence 190 | logger.info("The tokenization is conducted at the backend without progress bar") 191 | logger.info("This may take a while") 192 | encoded_source = tokenizer(source_text, padding=True, truncation=True, return_tensors="np", 193 | max_length=args.max_src_length - 1 if not args.is_pretrain else args.max_src_length, 194 | add_special_tokens=not args.is_pretrain) 195 | else: 196 | encoded_source = tokenize_with_progress_bar(tokenizer, args, source_text, batch_size=batch_size) 197 | 198 | logger.info(f"encoding target examples (num={len(target_text)})") 199 | logger.info(f"using tokenizer with padding = True and truncation = True and max_tgt_length = {args.max_tgt_length}") 200 | if batch_size is None: 201 | logger.info("The tokenization is conducted at the backend without progress bar") 202 | logger.info("This may take a while") 203 | encoded_target = tokenizer(target_text, padding=True, truncation=True, return_tensors="np", 204 | max_length=args.max_tgt_length, add_special_tokens=True) 205 | else: 206 | encoded_target = tokenize_with_progress_bar(tokenizer, args, target_text, batch_size=batch_size) 207 | return encoded_source, encoded_target 208 | 209 | 210 | def prepare_t2t_inputs(tokenizer, args, load_train_num=-1, tokenize_batch_size=None): 211 | logger.info("reading train set") 212 | source_texts_train, target_texts_train, _ = read_t2t_examples(os.path.join(args.data_path, "train.json"), 213 | source_field_name=args.source_field_name, 214 | target_field_name=args.target_field_name) 215 | if load_train_num > 0: 216 | # assert load_train_num <= len(source_texts_train), f"there are {len(source_texts_train)} training examples" 217 | logger.info(f"loading only {load_train_num} training examples out of the totaling {len(source_texts_train)}") 218 | source_texts_train = source_texts_train[:load_train_num] 219 | target_texts_train = target_texts_train[:load_train_num] 220 | 221 | encoded_source_train, encoded_target_train = tokenizet2t(tokenizer, args, source_texts_train, target_texts_train, 222 | batch_size=tokenize_batch_size) 223 | 224 | # special_token_id = tokenizer.eos_token_id 225 | decoder_start_token_id = tokenizer.pad_token_id 226 | 227 | train_source_input_ids, train_source_attention_mask = encoded_source_train["input_ids"], encoded_source_train[ 228 | "attention_mask"] 229 | 230 | train_target_input_ids, train_target_attention_mask = encoded_target_train["input_ids"], encoded_target_train[ 231 | "attention_mask"] 232 | # this is pytorch's cross entropy's 
ignore index; the tensorflow-2.0 counterpart had to be figured out
233 |     # -> update: it can be found in transformers.modeling_tf_utils:TFCausalLanguageModelingLoss (since transformers 3.1.0)
234 |     # shift_to_right(train_target_input_ids, decoder_start_token_id) is not needed since transformers 3.1.0 (it is taken care of automatically by simply passing the labels argument to the model, just like in pytorch)
235 |     # x_train = [train_source_input_ids, train_source_attention_mask,
236 |     #            shift_to_right(train_target_input_ids, decoder_start_token_id), train_target_attention_mask]
237 |     x_train = {"source_input_ids": train_source_input_ids, "source_attention_mask": train_source_attention_mask,
238 |                "shifted_target_input_ids": shift_to_right(train_target_input_ids, decoder_start_token_id), "target_attention_mask": train_target_attention_mask}
239 |     train_target_input_ids[train_target_input_ids == 0] = -100
240 |     to_return = {"x_train": x_train, "y_train": {"target_input_ids": train_target_input_ids}}
241 | 
242 |     if args.do_eval:
243 |         assert os.path.isfile(
244 |             os.path.join(args.data_path, "val.json")), "do_eval=True, and no validation data (val.json) is found"
245 | 
246 |     if os.path.isfile(os.path.join(args.data_path, "val.json")):
247 |         logger.info(f"found validation set in {os.path.join(args.data_path, 'val.json')}")
248 |         logger.info(f"encoding validation examples")
249 |         source_texts, encoded_source, encoded_target = convert_t2t_examples(
250 |             os.path.join(args.data_path, 'val.json'),
251 |             tokenizer, args)
252 |         # "</s>" is added at the end of each sequence
253 |         source_input_ids, source_attention_mask = encoded_source["input_ids"], encoded_source["attention_mask"]
254 |         target_input_ids, target_attention_mask = encoded_target["input_ids"], encoded_target["attention_mask"]
255 |         # target_input_ids[target_input_ids == 0] = -100
256 |         # in the t5 trainer, eval lm labels are not used for calculating loss but for generation comparison, so we do not apply the -100 replacement here
257 |         # x_eval = [source_input_ids, source_attention_mask, shift_to_right(target_input_ids, decoder_start_token_id),
258 |         #           target_attention_mask]
259 |         x_eval = {"source_input_ids": source_input_ids, "source_attention_mask": source_attention_mask,
260 |                   "shifted_target_input_ids": shift_to_right(target_input_ids, decoder_start_token_id),
261 |                   "target_attention_mask": target_attention_mask}
262 |         to_return.update({"x_eval": x_eval, "y_eval": {"target_input_ids": target_input_ids}, "eval_examples": {"source_texts": source_texts}})
263 | 
264 |     # if os.path.isfile(os.path.join(args.data_path, "test.json")):
265 |     #     logger.info(f"found test set in {os.path.join(args.data_path, 'test.json')}")
266 |     #     logger.info(f"encoding test examples")
267 |     #     source_texts, encoded_source, encoded_target = convert_t2t_examples(
268 |     #         os.path.join(args.data_path, 'test.json'),
269 |     #         tokenizer, args)
270 |     #     # "</s>" is added at the end of each sequence
271 |     #     source_input_ids, source_attention_mask = encoded_source["input_ids"], encoded_source["attention_mask"]
272 |     #     target_input_ids, target_attention_mask = encoded_target["input_ids"], encoded_target["attention_mask"]
273 |     #     # target_input_ids[target_input_ids == 0] = -100
274 |     #     # in the t5 trainer, eval lm labels are not used for calculating loss but for generation comparison, so we do not apply the -100 replacement here
275 |     #     x_test = [source_input_ids, source_attention_mask, shift_to_right(target_input_ids, decoder_start_token_id),
276 |     #               target_attention_mask]
277 |     #     to_return.update({"x_test": x_test, 
"y_test": target_input_ids, "test_examples": source_texts}) 278 | return to_return 279 | 280 | 281 | def get_with_prepare_func(tokenizer, args, prepare_func, load_train_num=-1, check_cache=False, 282 | tokenize_batch_size=None): 283 | ''' 284 | :param tokenizer: 285 | :param args: 286 | :return: 287 | ''' 288 | args.is_load_from_data_cache = False 289 | if check_cache: 290 | if load_train_num > 0: 291 | data_cache_path = os.path.join(args.data_path, 292 | f"{args.model_select.replace('/', '-')}-data-{load_train_num}.pkl") 293 | else: 294 | data_cache_path = os.path.join(args.data_path, f"{args.model_select.replace('/', '-')}-data.pkl") 295 | args.data_cache_path = data_cache_path 296 | if os.path.isfile(data_cache_path): 297 | args.is_load_from_data_cache = True 298 | with open(data_cache_path, "rb") as f: 299 | logger.info(f"reading cached data from {data_cache_path}") 300 | logger.warning( 301 | f"if you changed the max_seq_length/max_src_length/max_tgt_length, this may not correctly loaded, since the {data_cache_path} is pickled based on first time loading") 302 | to_return = pickle.load(f) 303 | else: 304 | to_return = prepare_func(tokenizer, args, load_train_num=load_train_num, 305 | tokenize_batch_size=tokenize_batch_size) 306 | with open(data_cache_path, "wb") as f: 307 | logger.info(f"caching data to {data_cache_path}") 308 | pickle.dump(to_return, f) 309 | else: 310 | to_return = prepare_func(tokenizer, args, load_train_num=load_train_num) 311 | return to_return 312 | 313 | def get_inputs(tokenizer, args): 314 | add_filehandler_for_logger(args.output_path, logger) 315 | if args.task == "single-label-cls": 316 | inputs = get_with_prepare_func(tokenizer, args, prepare_seq_single_cls_inputs, check_cache=True, 317 | load_train_num=-1) 318 | args.input_seq_length = inputs["x_train"][0].shape[-1] 319 | args.label2id = inputs["label2id"] 320 | return inputs 321 | elif args.task == "t2t" or args.task == "translation" or args.task == "pretrain": 322 | args.is_pretrain = args.task == "pretrain" 323 | 324 | tokenize_batch_size=args.tokenize_batch_size if hasattr(args,"tokenize_batch_size") else None 325 | load_train_num=args.load_train_num if hasattr(args,"load_train_num") else -1 326 | 327 | data_dict = get_with_prepare_func(tokenizer, args, prepare_t2t_inputs, check_cache=True, load_train_num=load_train_num, 328 | tokenize_batch_size=tokenize_batch_size) 329 | args.source_sequence_length = data_dict["x_train"]["source_input_ids"].shape[-1] 330 | args.target_sequence_length = data_dict["y_train"]["target_input_ids"].shape[-1] 331 | return data_dict 332 | else: 333 | # when more tasks are supported -> todo 334 | pass 335 | -------------------------------------------------------------------------------- /ttt/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | from tensorflow.keras import layers 4 | from transformers import TFAutoModel, TFT5ForConditionalGeneration, T5Config 5 | import os 6 | from transformers.modeling_tf_utils import get_initializer 7 | 8 | 9 | def get_lr_metric(optimizer): 10 | def lr(x, y): 11 | return optimizer.lr 12 | 13 | return lr 14 | 15 | 16 | def create_sl_cls_model(model_name_or_path, input_seq_length, args): 17 | ## transformer encoder 18 | encoder = TFAutoModel.from_pretrained(model_name_or_path) 19 | 20 | encoder_config = encoder.config 21 | if not os.path.isfile(os.path.join(args.output_path, "config.json")): 22 | 
encoder_config.save_pretrained(args.output_path)
23 | 
24 |     input_ids = layers.Input(shape=(input_seq_length,), dtype=tf.int32)
25 |     token_type_ids = layers.Input(shape=(input_seq_length,), dtype=tf.int32)
26 |     attention_mask = layers.Input(shape=(input_seq_length,), dtype=tf.int32)
27 | 
28 |     if "distilbert" in args.model_select:
29 |         # distilbert does not allow passing token_type_ids
30 |         sequence_outs = encoder(input_ids, attention_mask=attention_mask)[0]
31 |     else:
32 |         sequence_outs = encoder(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
33 | 
34 |     # according to modeling_tf_bert:TFBertPooler. In transformers, models like ROBERTA and Electra do not offer direct outputs of pooled_output
35 |     # to make it generalisable, the pooler is re-written here
36 |     # this may not have a big effect on perf. compared to simply replacing the following pooler with: pooled_output = sequence_outs[:, 0]
37 |     pooled_output = tf.keras.layers.Dense(
38 |         encoder_config.hidden_size,
39 |         kernel_initializer=get_initializer(encoder_config.initializer_range),
40 |         activation="tanh",
41 |         name="dense",
42 |     )(sequence_outs[:, 0])
43 | 
44 |     if hasattr(encoder_config, "hidden_dropout_prob"):
45 |         pooled_output = tf.keras.layers.Dropout(encoder_config.hidden_dropout_prob)(pooled_output, training=True)
46 |     else:
47 |         pooled_output = tf.keras.layers.Dropout(encoder_config.dropout)(pooled_output, training=True)
48 | 
49 |     logits = tf.keras.layers.Dense(len(args.label2id), name="classifier", use_bias=False)(pooled_output)
50 |     probs = layers.Activation(keras.activations.softmax)(logits)
51 | 
52 |     model = keras.Model(
53 |         inputs=[input_ids, token_type_ids, attention_mask],
54 |         outputs=probs,
55 |     )
56 | 
57 |     loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
58 |     optimizer = keras.optimizers.Adam(lr=args.lr)
59 |     model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy', get_lr_metric(optimizer)])
60 |     return model
61 | 
62 | def create_t2t_model(model_name_or_path, args, tokenizer=None, from_pretrained=True):
63 |     ## transformer encoder
64 |     if from_pretrained:
65 |         encoder = TFT5ForConditionalGeneration.from_pretrained(model_name_or_path)
66 |         encoder_config = encoder.config
67 |     else:
68 |         encoder_config = T5Config.from_pretrained(args.model_select)
69 |         if tokenizer is not None:
70 |             assert encoder_config.vocab_size == len(tokenizer)
71 |             assert encoder_config.pad_token_id == tokenizer.pad_token_id
72 |             assert encoder_config.eos_token_id == tokenizer.eos_token_id
73 |             assert encoder_config.decoder_start_token_id == tokenizer.pad_token_id
74 |         encoder = TFT5ForConditionalGeneration(encoder_config)
75 |         # build the model with dummy_inputs
76 |         encoder(encoder.dummy_inputs, training=False)
77 | 
78 |     if not os.path.isfile(os.path.join(args.output_path, "config.json")):
79 |         encoder_config.save_pretrained(args.output_path)
80 |     return encoder
81 | 
82 | 
83 | def get_model(args, tokenizer=None, from_pretrained=True):
84 |     if args.task == "single-label-cls":
85 |         return create_sl_cls_model(args.model_select, args.input_seq_length, args)
86 |     elif args.task == "t2t" or args.task == "translation" or args.task == "pretrain":
87 |         return create_t2t_model(args.model_select, args, tokenizer=tokenizer, from_pretrained=from_pretrained)
88 |     else:
89 |         # when more tasks are supported -> todo
90 |         pass
91 | 
--------------------------------------------------------------------------------
/ttt/t2t_trainer.py:
--------------------------------------------------------------------------------
1 | '''
2 | this 
is a customized trainer for T5-like model training;
3 | in this class, the training loop is written out by hand for more flexibility and control over the training process
4 | '''
5 | import math
6 | import os
7 | import sys
8 | import warnings
9 | 
10 | import tensorflow as tf
11 | from tqdm import tqdm
12 | from sklearn.metrics import accuracy_score, classification_report
13 | import numpy as np
14 | from keras import backend as K
15 | from ttt.utils import add_filehandler_for_logger, get_existing_cks
16 | from tensorboardX import SummaryWriter
17 | # for translation evaluation, from: https://github.com/mjpost/sacrebleu
18 | # which is also used in the original T5 paper
19 | import sacrebleu
20 | from .utils import write_args_enhance, save_ck, dictionize_t2t_dataset, set_seed
21 | 
22 | 
23 | class T2TTrainer():
24 |     def __init__(self, args, logger):
25 |         self.eval_on = args.eval_on
26 |         assert self.eval_on in ["acc",
27 |                                 "bleu"], "t2t training currently only supports --eval_on acc or bleu, and it only takes effect when --do_eval=True"
28 |         # self.best = -np.Inf
29 | 
30 |         self.patience = args.patience
31 |         self.wait = 0
32 |         self.logger = logger
33 |         self.args = args
34 |         self.use_tb = self.args.__dict__.get('use_tb', False)
35 | 
36 |         self._tb_writer = None
37 |         if self.use_tb:
38 |             self._tb_writer = SummaryWriter(log_dir=self.args.__dict__.get('output_folder', "runs"))
39 |         self.scheduler = args.scheduler
40 | 
41 |         if "learning_rate" in self.args.__dict__:
42 |             self.lr_to_reach = args.learning_rate
43 |         else:
44 |             self.lr_to_reach = args.lr
45 | 
46 |         self.args.best = np.Inf if self.args.eval_on == "loss" or self.args.eval_on == "perplexity" else -np.Inf
47 |         self.best = self.args.best
48 | 
49 |     def train(self, model, strategy, tokenizer, inputs=None, train_dataset=None, eval_dataset=None, evaluate_fn=None, verbose=False):
50 | 
51 |         if inputs is None:
52 |             assert train_dataset is not None, "you have to pass either inputs or train_dataset"
53 |         else:
54 |             warnings.warn(
55 |                 "Passing `inputs` as a keyword argument is deprecated. 
Use train_dataset and eval_dataset instead.",
56 |                 FutureWarning,
57 |             )
58 | 
59 |         if isinstance(inputs, tuple):
60 |             inputs = dictionize_t2t_dataset(*inputs)
61 | 
62 |         if inputs is not None:
63 |             x_train, y_train = inputs["x_train"], inputs["y_train"]
64 |             num_train_examples = len(inputs["y_train"]["target_input_ids"])
65 |             train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
66 |         else:
67 |             if hasattr(train_dataset, "num_examples"):
68 |                 num_train_examples = train_dataset.num_examples
69 |             else:
70 |                 num_train_examples = tf.data.experimental.cardinality(train_dataset).numpy()
71 | 
72 |         self.logger.info(f"set random seed for everything with {self.args.seed}")
73 |         set_seed(self.args.seed)
74 |         global_batch_size = self.args.per_device_train_batch_size * strategy.num_replicas_in_sync
75 |         train_dataset = train_dataset.shuffle(buffer_size=self.args.seed).batch(global_batch_size)
76 |         train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
77 |         # THERE WILL BE exceptions on tpus if the eval set is switched to a distributed dataset:
78 |         # val_dist_dataset = strategy.experimental_distribute_dataset(eval_dataset)
79 |         train_length = math.ceil(num_train_examples / global_batch_size)
80 |         self.steps_per_epoch = train_length
81 |         if inputs is not None:
82 |             if self.args.do_eval:
83 |                 assert "x_eval" in inputs and "y_eval" in inputs, "do_eval=True, and no validation data is found"
84 |                 x_val, y_val = inputs["x_eval"], inputs["y_eval"]
85 |                 eval_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
86 |                 eval_dataset = eval_dataset.batch(self.args.eval_batch_size)
87 |                 eval_steps = math.ceil(
88 |                     len(inputs["y_eval"]["target_input_ids"]) / (self.args.eval_batch_size))
89 |         else:
90 |             if self.args.do_eval:
91 |                 if hasattr(eval_dataset, "num_examples"):
92 |                     eval_num_examples = eval_dataset.num_examples
93 |                 else:
94 |                     eval_num_examples = tf.data.experimental.cardinality(eval_dataset).numpy()
95 |                 eval_steps = math.ceil(eval_num_examples / (self.args.eval_batch_size))
96 |                 eval_dataset = eval_dataset.batch(self.args.eval_batch_size)
97 |         if verbose:
98 |             self.logger.info(model.summary())
99 |         # these are used for the non-constant lr schedulers
100 |         if "num_train_epochs" in self.args.__dict__:
101 |             self.args.num_epochs_train = self.args.num_train_epochs
102 |         if "log_and_save_steps" in self.args.__dict__:
103 |             self.args.log_steps = self.args.log_and_save_steps
104 | 
105 |         self.total_steps = self.steps_per_epoch * self.args.num_epochs_train
106 | 
107 |         if "warmup_steps_or_ratio" in self.args.__dict__:
108 |             if self.args.warmup_steps_or_ratio <= 1 and self.args.warmup_steps_or_ratio > 0:
109 |                 self.args.warmup_steps = int(self.total_steps * self.args.warmup_steps_or_ratio)
110 |             else:
111 |                 self.args.warmup_steps = self.args.warmup_steps_or_ratio
112 |         else:
113 |             self.args.warmup_steps = int(self.total_steps * self.args.warmup_ratio)
114 | 
115 |         self.warmup_steps = self.args.warmup_steps
116 |         write_args_enhance(self.args, logger=self.logger)
117 | 
118 |         with strategy.scope():
119 |             optimizer = tf.keras.optimizers.Adam(lr=self.args.lr if self.scheduler.startswith("constant") else 0.0)
120 |             loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
121 |                 from_logits=True, reduction=tf.keras.losses.Reduction.NONE
122 |             )
123 | 
124 |             def compute_loss(labels, predictions):
125 |                 per_example_loss = loss_fn(labels, predictions)
126 |                 return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)
127 | 
128 |             def train_step(x_train, y_train):
129 |                 with tf.GradientTape() as tape:
130 |                     # here some changes have been made (compared to before commit `a07c58e`) to fix a bug reported here: https://github.com/wangcongcong123/ttt/issues/2
131 |                     # The following describes how this bug is fixed
132 |                     # the compute_loss function in transformers:TFT5ForConditionalGeneration has already taken care of the loss computation (already averaged!), which failed
133 |                     # when switching to TPU; hence we re-compute the loss here from the logits returned by the model, ready for backprop, instead of using the internally calculated loss
134 |                     outputs = model(inputs=x_train["source_input_ids"], attention_mask=x_train["source_attention_mask"],
135 |                                     decoder_attention_mask=x_train["target_attention_mask"],
136 |                                     labels=y_train["target_input_ids"], training=True, return_dict=True)
137 |                     logits = outputs.logits
138 |                     loss = compute_loss(tf.reshape(y_train["target_input_ids"], (-1, y_train["target_input_ids"].shape[-1])),
139 |                                         tf.reshape(logits, (-1, logits.shape[-1])))
140 | 
141 |                 gradients = tape.gradient(loss, model.trainable_variables)
142 |                 optimizer.apply_gradients(zip(gradients, model.trainable_variables))
143 |                 return loss
144 | 
145 |             @tf.function
146 |             def distributed_train_step(x_train, y_train):
147 |                 per_replica_losses = strategy.experimental_run_v2(train_step, args=(x_train, y_train,))
148 |                 return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
149 | 
150 |             # evaluate
151 |             def evaluate(steps, tag="epoch"):
152 |                 assert tag in ["epoch", "global_step"]
153 |                 gts = []
154 |                 preds = []
155 |                 for x_eval, y_eval in tqdm(eval_dataset, total=eval_steps, desc="evaluating..."):
156 |                     predictions = model.generate(input_ids=x_eval["source_input_ids"],
157 |                                                  attention_mask=x_eval["source_attention_mask"],
158 |                                                  max_length=self.args.max_tgt_length)
159 |                     pred = [tokenizer.decode(ids) for ids in predictions]
160 |                     gt = [tokenizer.decode(ids) for ids in y_eval["target_input_ids"]]
161 |                     # labels (no -100 replacement, since they are not used to calculate loss here)
162 |                     preds.extend(pred)
163 |                     gts.extend(gt)
164 | 
165 |                 if self.eval_on == "bleu":
166 |                     # bleu = 0
167 |                     bleu = sacrebleu.corpus_bleu(preds, [gts])
168 |                     eval_score = bleu.score
169 |                 else:
170 |                     eval_score = accuracy_score(gts, preds)
171 |                     self.logger.info(f"val_cls_report: {classification_report(gts, preds, digits=4)}")
172 | 
173 |                 if self.use_tb:
174 |                     self._tb_writer.add_scalar(f"val_{self.eval_on}_{tag}", eval_score, steps)
175 | 
176 |                 self.logger.info("\n")
177 |                 self.logger.info(f"*******eval at {tag} = {steps} on validation dataset*********")
178 |                 self.logger.info(f"val_{self.eval_on}: {eval_score}")
179 | 
180 |                 if self.eval_on == "acc" or self.eval_on == "bleu":
181 |                     if eval_score >= self.best:
182 |                         self.wait = 0
183 |                         self.best = eval_score
184 |                         self.logger.info(
185 |                             f"so far the best checkpoint at {tag}={steps} based on eval_on {self.eval_on}")
186 |                         # self.save_ck(model, steps, tag, best_ck=True)
187 |                         save_ck(self.args, self.logger, model, tokenizer=tokenizer, steps=steps,
188 |                                 tag=tag, best_ck=True, from_tf=True)
189 |                     else:
190 |                         self.wait += 1
191 |                 else:
192 |                     raise ValueError("not supported yet")
193 | 
194 |                 self.logger.info(f"best so far({self.eval_on}): {self.best}")
195 |                 self.logger.info(f"early stop count: {self.wait}/{self.patience}")
196 |                 # self.save_ck(model, steps, tag)
197 |                 save_ck(self.args, self.logger, model, tokenizer=tokenizer, steps=steps,
198 |                         tag=tag, best_ck=False, from_tf=True)
199 |                 if self.wait >= self.patience:
200 | 
self.logger.info("run out of patience, early stop") 201 | if self.use_tb: 202 | self._tb_writer.close() 203 | sys.exit(0) 204 | 205 | def update_lr(global_step): 206 | # already tested on tpu, works fine 207 | # global_step is dynamically passed here 208 | if global_step <= self.warmup_steps: 209 | if self.scheduler == "warmuplinear" or self.scheduler == "warmupcostant": 210 | inc = self.lr_to_reach / self.warmup_steps 211 | K.set_value(optimizer.learning_rate, K.eval(optimizer.lr) + inc) 212 | else: 213 | if self.scheduler == "warmuplinear" or self.scheduler == "constantlinear": 214 | dec = self.lr_to_reach / (self.total_steps - self.warmup_steps) 215 | K.set_value(optimizer.learning_rate, K.eval(optimizer.lr) - dec) 216 | # for "constant" scheduler, nothing to do here 217 | 218 | global_step = 0 219 | early_exit = False 220 | interval_loss = 0.0 221 | interval_count = 0 222 | for epoch in tqdm(range(self.args.num_epochs_train), desc="epochs"): 223 | 224 | self.logger.info(f"start training at epoch = {epoch}") 225 | self.logger.info(f"global train batch size = {global_batch_size}") 226 | self.logger.info(f"using learning rate scheduler: {self.scheduler}") 227 | self.logger.info( 228 | f"num_train_examples: {num_train_examples}, total_steps: {self.total_steps}, steps_per_epoch: {self.steps_per_epoch}") 229 | if self.scheduler != "constant": 230 | self.logger.info(f"warmup_steps:{self.warmup_steps}") 231 | 232 | pbar = tqdm(enumerate(train_dist_dataset), total=train_length) 233 | for step, (x_train, y_train) in pbar: 234 | # learning rate scheduler 235 | update_lr(global_step) 236 | loss = distributed_train_step(x_train, y_train) 237 | interval_loss += loss.numpy() 238 | interval_count += 1 239 | global_step += 1 240 | pbar.set_description(f"training - epoch {epoch + 1}/{self.args.num_epochs_train} iter {step}: train loss {loss.numpy():.5f}. 
lr {optimizer.lr.numpy():e}") 241 | 242 | if self.args.log_steps != -1 and global_step % self.args.log_steps == 0: 243 | if self.use_tb: 244 | self._tb_writer.add_scalar("train_loss_global_step", interval_loss / interval_count, 245 | global_step) 246 | self._tb_writer.add_scalar("train_lr_global_step", optimizer.lr.numpy(), global_step) 247 | 248 | if self.args.do_eval: 249 | if evaluate_fn is not None and eval_dataset is not None: 250 | eval_dict = evaluate_fn(self.args, self.logger, model, tokenizer, eval_dataset, steps=global_step, tag="global_step", eval_length=eval_steps) 251 | if self._tb_writer: 252 | if "eval_scores" in eval_dict: 253 | for key, value in eval_dict["eval_scores"].items(): 254 | self._tb_writer.add_scalar(f"eval_{key}_global_step", value, global_step) 255 | if "is_early_stop" in eval_dict and eval_dict["is_early_stop"]: 256 | self.logger.info(f"run out of patience at global step = {global_step}, early stop") 257 | if self._tb_writer: 258 | self._tb_writer.close() 259 | early_exit = True 260 | break 261 | else: 262 | evaluate(global_step, tag="global_step") 263 | self.logger.info(f"train loss at global_step {global_step}: {interval_loss / interval_count}") 264 | interval_loss = 0.0 265 | interval_count = 0 266 | if early_exit: 267 | break 268 | 269 | train_loss = interval_loss / interval_count 270 | interval_loss = 0.0 271 | interval_count = 0 272 | if self.args.log_steps == -1: 273 | if self.args.do_eval: 274 | if evaluate_fn is not None and eval_dataset is not None: 275 | eval_dict = evaluate_fn(self.args, self.logger, model, tokenizer, eval_dataset, steps=epoch + 1, tag="epoch", eval_length=eval_steps) 276 | if self._tb_writer: 277 | if "eval_scores" in eval_dict: 278 | for key, value in eval_dict["eval_scores"].items(): 279 | self._tb_writer.add_scalar(f"eval_{key}_epoch", value, epoch + 1) 280 | if "is_early_stop" in eval_dict and eval_dict["is_early_stop"]: 281 | self.logger.info(f"run out of patience at epoch = {epoch + 1}, early stop") 282 | if self._tb_writer: 283 | self._tb_writer.close() 284 | break 285 | else: 286 | evaluate(epoch + 1, tag="epoch") 287 | if self.use_tb: 288 | self._tb_writer.add_scalar("train_loss_epoch", train_loss, 289 | global_step) 290 | self._tb_writer.add_scalar("train_lr_epoch", optimizer.lr.numpy(), global_step) 291 | self.logger.info(f"train loss at end of epoch {epoch + 1}: {train_loss}") 292 | 293 | if not self.args.do_eval: 294 | # if do not do evaluate, the checkpoint at the end of epoch needs to be saved 295 | # self.save_ck(model, epoch + 1, tag="epoch") 296 | save_ck(self.args, self.logger, model, tokenizer=tokenizer, steps=epoch + 1, 297 | tag="epoch", best_ck=False, from_tf=True) 298 | 299 | if self.use_tb: 300 | self._tb_writer.close() 301 | -------------------------------------------------------------------------------- /ttt/utils.py: -------------------------------------------------------------------------------- 1 | import glob, re, shutil, torch 2 | import random, os, json 3 | import numpy as np 4 | 5 | import tensorflow as tf 6 | from transformers import AutoTokenizer 7 | import logging 8 | 9 | import tensorflow_addons as tfa 10 | from tensorflow import keras 11 | from keras import backend as K 12 | 13 | 14 | class LRSchudlerCallback(keras.callbacks.Callback): 15 | def __init__(self, args, logger): 16 | super(LRSchudlerCallback, self).__init__() 17 | self.warmup_ratio = args.warmup_ratio 18 | self.scheduler = args.scheduler 19 | self.logger = logger 20 | 21 | def on_train_begin(self, logs=None): 22 | 
self.steps_per_epoch = self.params["steps"] 23 | self.epochs = self.params["epochs"] 24 | self.global_step = 0 25 | self.logger.info(f"using learning rate scheduler {self.scheduler}") 26 | if not self.scheduler.startswith("constant"): 27 | self.total_steps = self.steps_per_epoch * self.epochs 28 | self.warmup_steps = int(self.total_steps * self.warmup_ratio) 29 | self.logger.info( 30 | f"total_steps: {self.total_steps}, steps_per_epoch: {self.steps_per_epoch}, epochs: {self.epochs}, warmup_steps:{self.warmup_steps}") 31 | if not hasattr(self.model.optimizer, "lr"): 32 | raise ValueError('Optimizer must have a "lr" attribute.') 33 | self.logger.info(f"lr of optimizer to reach through warmup: {K.eval(self.model.optimizer.lr)}") 34 | 35 | self.lr_to_reach = K.eval(self.model.optimizer.lr) 36 | K.set_value(self.model.optimizer.learning_rate, 0.00) 37 | self.logger.info(f"now set it to zero for warmup: {K.eval(self.model.optimizer.lr)}") 38 | 39 | def on_train_batch_end(self, batch, logs=None): 40 | if self.global_step <= self.warmup_steps: 41 | if self.scheduler == "warmuplinear" or self.scheduler == "warmupconstant": 42 | inc = self.lr_to_reach / self.warmup_steps 43 | K.set_value(self.model.optimizer.learning_rate, K.eval(self.model.optimizer.lr) + inc) 44 | else: 45 | if self.scheduler == "warmuplinear" or self.scheduler == "constantlinear": 46 | dec = self.lr_to_reach / (self.total_steps - self.warmup_steps) 47 | K.set_value(self.model.optimizer.learning_rate, K.eval(self.model.optimizer.lr) - dec) 48 | 49 | self.global_step += 1 50 | 51 | def on_test_batch_end(self, batch, logs=None): 52 | pass 53 | 54 | def on_epoch_end(self, epoch, logs=None): 55 | self.logger.info(f"at epoch={epoch}, the learning_rate is {K.eval(self.model.optimizer.lr)}") 56 | 57 | def on_train_end(self, logs=None): 58 | self.logger.info("testing") 59 | 60 | 61 | def get_callbacks(args, inputs, logger, eval_getter): 62 | tqdm_callback = tfa.callbacks.TQDMProgressBar(metrics_format="{name}: {value:0.8f}", 63 | epoch_bar_format="{n_fmt}/{total_fmt}{bar} ETA: {remaining}s - {desc}, {rate_fmt}{postfix}", ) 64 | lr_scheduler_callback = LRSchudlerCallback(args, logger) 65 | if args.do_eval == True: 66 | eval_callback = eval_getter(inputs["x_eval"], inputs["y_eval"], args) 67 | # return [tqdm_callback,eval_callback] 68 | return [tqdm_callback, eval_callback, lr_scheduler_callback] 69 | else: 70 | return [tqdm_callback, lr_scheduler_callback] 71 | 72 | def dictionize_single_dataset(inputs,tag="train"): 73 | dict_dataset = {} 74 | x, y = inputs 75 | x_ = {} 76 | x_["source_input_ids"] = x.pop("input_ids") 77 | x_["source_attention_mask"] = x.pop("attention_mask") 78 | x_["target_attention_mask"] = x.pop("decoder_attention_mask") 79 | 80 | dict_dataset[f"x_{tag}"] = x_ 81 | dict_dataset[f"y_{tag}"] = {"target_input_ids": y} 82 | return dict_dataset 83 | 84 | def dictionize_t2t_dataset(train_inputs, eval_inputs=None): 85 | dict_dataset = dictionize_single_dataset(train_inputs, tag="train") 86 | if eval_inputs is not None: 87 | dict_dataset.update(dictionize_single_dataset(eval_inputs, tag="eval")) 88 | return dict_dataset 89 | 90 | def get_strategy(args,logger): 91 | if args.use_tpu: 92 | # Create distribution strategy 93 | # checking ip address or tpu name? 
94 | if re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", args.tpu_address): 95 | args.tpu_address = 'grpc://' + args.tpu_address 96 | tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args.tpu_address) 97 | tf.config.experimental_connect_to_cluster(tpu) 98 | tf.tpu.experimental.initialize_tpu_system(tpu) 99 | logger.info("All TPU devices: ") 100 | for each_device in tf.config.list_logical_devices('TPU'): 101 | logger.info(each_device) 102 | strategy = tf.distribute.TPUStrategy(tpu) 103 | else: 104 | if args.use_gpu: 105 | # Create a MirroredStrategy. 106 | strategy = tf.distribute.MirroredStrategy() 107 | logger.info("Number of GPU devices: {}".format(strategy.num_replicas_in_sync)) 108 | else: 109 | raise ValueError("not available yet") 110 | # strategy = None 111 | # logger.info("Using CPU for training") 112 | # model = model_getter(args) 113 | return strategy 114 | 115 | def create_model(args, logger, model_getter, tokenizer=None, from_pretrained=True, save_args=True): 116 | # get strategy and Create model 117 | strategy = get_strategy(args, logger) 118 | with strategy.scope(): 119 | model = model_getter(args, tokenizer=tokenizer, from_pretrained=from_pretrained) 120 | 121 | logger.info(model.summary()) 122 | # trainable_count = int( 123 | # np.sum([K.count_params(p) for p in set(model.trainable_weights)])) 124 | # non_trainable_count = int( 125 | # np.sum([K.count_params(p) for p in set(model.non_trainable_weights)])) 126 | # logger.info('Total params: {:,}'.format(trainable_count + non_trainable_count)) 127 | # logger.info('Trainable params: {:,}'.format(trainable_count)) 128 | # logger.info('Non-trainable params: {:,}'.format(non_trainable_count)) 129 | # if strategy!=None: 130 | args.num_replicas_in_sync = strategy.num_replicas_in_sync 131 | if save_args: 132 | write_args(args.output_path, args) 133 | return model, strategy 134 | 135 | 136 | def get_tokenizer(args): 137 | tokenizer = AutoTokenizer.from_pretrained(args.model_select) 138 | tokenizer.save_pretrained(args.output_path) 139 | return tokenizer 140 | 141 | 142 | def add_filehandler_for_logger(output_path, logger, out_name="train"): 143 | logFormatter = logging.Formatter('%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s') 144 | fileHandler = logging.FileHandler(os.path.join(output_path, f"{out_name}.log"), mode="a") 145 | fileHandler.setFormatter(logFormatter) 146 | logger.addHandler(fileHandler) 147 | 148 | 149 | def set_seed(seed): 150 | tf.random.set_seed( 151 | seed 152 | ) 153 | random.seed(seed) 154 | np.random.seed(seed) 155 | 156 | 157 | def write_args(output_path, args): 158 | with open(os.path.join(output_path, "args.json"), "w") as f: 159 | f.write(json.dumps(args.__dict__, indent=2)) 160 | 161 | def is_jsonable(x): 162 | try: 163 | json.dumps(x) 164 | return True 165 | except: 166 | return False 167 | 168 | def write_args_enhance(args, logger=None, write_path=None): 169 | if write_path is None: 170 | write_path = args.output_path 171 | 172 | with open(os.path.join(write_path, "args.json"), "w+") as f: 173 | args_dict = {} 174 | for key, value in args.__dict__.items(): 175 | if is_jsonable(value): 176 | args_dict[key] = value 177 | if logger is not None: 178 | logger.info(json.dumps(args_dict, indent=2)) 179 | else: 180 | print(json.dumps(args_dict, indent=2)) 181 | f.write(json.dumps(args_dict, indent=2)) 182 | 183 | def get_existing_cks(output_path, best_ck=False, return_best_ck=False): 184 | cks_already = [name for name in os.listdir(output_path) if 
os.path.isdir(os.path.join(output_path, name))]
185 | 
186 |     if best_ck:
187 |         for ex in [each for each in cks_already if each.startswith("best")]:
188 |             cks_already.remove(ex)
189 |             shutil.rmtree(os.path.join(output_path, ex))
190 | 
191 |     index2path = {}
192 | 
193 |     for each_ck in cks_already:
194 |         if return_best_ck or not each_ck.startswith("best"):
195 |             index2path[int(os.path.basename(each_ck).split("_")[-1])] = os.path.join(output_path, each_ck)
196 | 
197 |     sorted_indices = sorted(index2path)  # the index here refers to the epoch number
198 |     return sorted_indices, index2path
199 | 
200 | 
201 | def load_torch_state_dict_from_h5_weights(model):
202 |     import torch
203 |     state_dict = {}
204 |     for layer in model.layers:
205 |         for resource_variable in layer.weights:
206 |             key = resource_variable.name
207 |             value = torch.tensor(resource_variable.numpy())
208 |             state_dict[key] = value
209 |     total_params_num = sum([element.numel() for element in state_dict.values()])
210 |     print(f"the number of params: {total_params_num}")
211 |     return state_dict
212 | 
213 | def save_and_check_if_early_stop(eval_score, args, logger, model, tokenizer, steps=0, tag="epoch", from_tf=False):
214 |     logger.info("\n")
215 |     logger.info(
216 |         f"*******eval at {tag} = {steps} (gradient accumulation steps={args.__dict__.get('gradient_accumulation_steps', 1)})*********")
217 |     logger.info(f"val_{args.eval_on}: {eval_score}")
218 |     best_save = False
219 |     if args.eval_on == "acc":
220 |         if eval_score >= args.best:
221 |             args.wait = 0
222 |             args.best = eval_score
223 |             logger.info(f"so far the best checkpoint at {tag}={steps} based on eval_on {args.eval_on}")
224 |             save_ck(args, logger, model, tokenizer, steps=steps, tag=tag, best_ck=True, from_tf=from_tf)
225 |             best_save = True
226 |         else:
227 |             args.wait += 1
228 |     else:
229 |         raise ValueError("not supported yet")
230 | 
231 |     logger.info(f"best so far ({args.eval_on}): {args.best}")
232 |     logger.info(f"early stop count: {args.wait}/{args.patience}")
233 |     if not best_save:
234 |         save_ck(args, logger, model, tokenizer, steps=steps, tag=tag, best_ck=False, from_tf=from_tf)
235 | 
236 |     if args.wait >= args.patience:
237 |         logger.info("run out of patience, early stop")
238 |         return True
239 |     return False
240 | 
241 | def save_transformer_locally(model_name="bert-base-uncased", save_path=".", is_tf=False):
242 |     """save
243 |     any model you can find here: https://huggingface.co/models
244 |     """
245 |     # to use AutoModel, pytorch needs to be installed: pip3 install torch torchvision or pip install torch torchvision
246 |     from transformers import AutoTokenizer, AutoModel, TFAutoModel
247 |     if is_tf:
248 |         model = TFAutoModel.from_pretrained(model_name)
249 |     else:
250 |         model = AutoModel.from_pretrained(model_name)
251 | 
252 |     # Load pretrained model/tokenizer
253 |     tokenizer = AutoTokenizer.from_pretrained(model_name)
254 |     if not os.path.isdir(save_path):
255 |         os.makedirs(save_path, exist_ok=True)
256 |     model.save_pretrained(os.path.join(save_path, model_name))  # save model weights and config
257 |     tokenizer.save_pretrained(os.path.join(save_path, model_name))  # save tokenizer config or/and vocab
258 | 
259 | 
260 | def iid_denoise_text(original_text, span_length=3, corrupt_ratio=0.15, lang="zh_cn"):
261 |     """
262 |     This method implements the pre-training objective of T5, as described in the T5 paper (https://arxiv.org/abs/1910.10683)
263 |     the default params keep the same setup as the original T5 paper on English; we generalize it to more languages such as Chinese
264 |     :param 
original_text: a list of tokens
265 |     :param span_length: 3 by default, as described in the T5 paper
266 |     :param corrupt_ratio: 15% by default, as described in the T5 paper
267 |     :param lang: reserved param for future use
268 |     :return:
269 |     """
270 |     source_text = []
271 |     target_text = []
272 |     # if lang == "en":
273 |     #     corrupt_ratio = 0.15
274 |     #     span_length = 3  # 3 as in the T5 paper
275 |     # make deterministic for reproducibility
276 |     # random.seed(2020)
277 |     replace_i = 0
278 |     skip_count = span_length
279 |     last_replace_pos = -span_length
280 |     for pos in range(len(original_text)):
281 |         if skip_count < span_length - 1:
282 |             skip_count += 1
283 |         else:
284 |             if random.uniform(0, 1) < corrupt_ratio:
285 |                 extra_token = f"<extra_id_{replace_i}>"  # T5 sentinel token
286 |                 if pos != last_replace_pos + span_length:
287 |                     target_text.append(extra_token)
288 |                 to_replace_span = original_text[pos: pos + span_length]
289 |                 target_text.extend(to_replace_span)
290 |                 source_text.append(extra_token)
291 |                 replace_i += 1
292 |                 skip_count = 0
293 |                 last_replace_pos = pos
294 |             else:
295 |                 source_text.append(original_text[pos])
296 |     if not target_text:
297 |         target_text.append("<extra_id_0>")  # nothing was corrupted: emit a single sentinel so the target is non-empty
298 |     return original_text, source_text, target_text
299 | 
300 | 
301 | def save_ck(args, logger, model, tokenizer=None, steps=0, tag="epoch", best_ck=False, from_tf=False):
302 |     sorted_indices, index2path = get_existing_cks(args.output_path, best_ck=best_ck)
303 |     if len(sorted_indices) >= args.keep_ck_num:
304 |         logger.info(
305 |             f"there are already {len(sorted_indices)} checkpoints saved, which will be more than keep_ck_num={args.keep_ck_num}")
306 |         logger.info(f"hence, remove the oldest one: {index2path[sorted_indices[0]]}")
307 |         shutil.rmtree(
308 |             index2path[sorted_indices[0]])  # remove the oldest checkpoint, i.e., the one with the lowest epoch number
309 |     if best_ck:
310 |         logger.info(
311 |             f'save best model weights and tokenizer to {os.path.join(args.output_path, f"best_ck_at_{tag}_{steps}")}')
312 |         if tokenizer is not None:
313 |             tokenizer.save_pretrained(os.path.join(args.output_path, f"best_ck_at_{tag}_{steps}"))
314 |         if isinstance(model, torch.nn.DataParallel):
315 |             model.module.save_pretrained(os.path.join(args.output_path, f"best_ck_at_{tag}_{steps}"))
316 |         else:
317 |             if from_tf:
318 |                 model.config.save_pretrained(os.path.join(args.output_path, f"best_ck_at_{tag}_{steps}"))
319 |                 model.save_weights(os.path.join(args.output_path, f"best_ck_at_{tag}_{steps}", "tf_model.h5"), overwrite=True)
320 |             else:
321 |                 model.save_pretrained(os.path.join(args.output_path, f"best_ck_at_{tag}_{steps}"))
322 |     else:
323 |         logger.info(
324 |             f'save model weights and tokenizer to {os.path.join(args.output_path, f"ck_at_{tag}_{steps}")}')
325 |         if tokenizer is not None:
326 |             tokenizer.save_pretrained(os.path.join(args.output_path, f"ck_at_{tag}_{steps}"))
327 |         if isinstance(model, torch.nn.DataParallel):
328 |             model.module.save_pretrained(os.path.join(args.output_path, f"ck_at_{tag}_{steps}"))
329 |         else:
330 |             if from_tf:
331 |                 model.config.save_pretrained(os.path.join(args.output_path, f"ck_at_{tag}_{steps}"))
332 |                 model.save_weights(os.path.join(args.output_path, f"ck_at_{tag}_{steps}", "tf_model.h5"),
333 |                                    overwrite=True)
334 |             else:
335 |                 model.save_pretrained(os.path.join(args.output_path, f"ck_at_{tag}_{steps}"))
336 | 
337 | 
338 | if __name__ == '__main__':
339 |     model_name_or_path = "t5-small"
340 |     from transformers import TFT5ForConditionalGeneration
341 | 
342 |     model = 
TFT5ForConditionalGeneration.from_pretrained(model_name_or_path)
343 |     state_dict = load_torch_state_dict_from_h5_weights(model)
344 |     print(len(state_dict))
345 | 
--------------------------------------------------------------------------------
/ttt_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt_demo.png
--------------------------------------------------------------------------------
/ttt_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt_logo.png
--------------------------------------------------------------------------------
/use_tpu_tutorial/README.md:
--------------------------------------------------------------------------------
1 | ### A Detailed Guide to Using Google's TPUs
2 | 
3 | Around two months ago, I was struggling to fine-tune Huggingface's transformers due to a lack of computation power. In particular, on a personal single-GPU device (e.g., I am poor, so I just have an RTX 2060 on my desktop), for big models like [`t5-large`](https://huggingface.co/t5-large) the OOM issue still pops up even with a batch size of 1. This was a problem I was eager to fix, as I wanted to try some research ideas back then. Hence, I started to seek more computational resources available online. Finally, I found that Google's TensorFlow Research Cloud program ([TFRC](https://www.tensorflow.org/tfrc)) is generous enough to give out free TPUs for accelerating machine learning research. Fortunately, I received the free credits shortly after I applied. I thought things would be easy as long as I followed Google's [official tutorials](https://cloud.google.com/tpu/docs/tutorials). However, things did not go as well as expected - see some of the difficulties I pointed out [here](https://github.com/wangcongcong123/ttt#last). In addition, there were many other issues for which I could not find answers on the internet. This led me to open-source [the ttt project](https://github.com/wangcongcong123/ttt) and to write this blog, aiming to guide anyone who may have the same or similar needs as I had two months ago.
4 | 
5 | Key words: **Google GCP environment setup for model training using TPUs**, **Sequence-to-sequence transformers**, **Huggingface's datasets and transformers**, **From Pytorch to Tensorflow2.0 for using TPUs**.
6 | 
7 | This blog will cover:
8 | 1. Setting up everything for model training using Google's TPUs on the Google Cloud Platform (a.k.a. GCP), i.e., on a Google VM instance.
9 | 2. An example of using it: fine-tuning a T5-base model on a customized COVID informativeness tweets dataset.
10 | 3. Trying some model size reduction practices, such as pruning or quantization, to explore the trade-off between accuracy and efficiency.
11 | 
12 | ### 1. Training Environment Setup
13 | 
14 | If you do not have free TPU credits for research purposes, I highly recommend applying to the TFRC program [here](https://www.tensorflow.org/tfrc). After the application is accepted, you will receive an email like this:
18 | ![TFRC application confirmation email](images/apply_success_email.png)
19 | 
20 | Figure 1
21 | 
22 | 
23 | The only requirement of the service is that you probably need to acknowledge them, in some form, in your research output. Personally, this is not a big deal for me, so I applied without a second thought. In fact, I have applied for this service twice as of now, and both applications were responded to and accepted quickly. Here I sincerely want to shout out to Google for their generosity with the free Cloud TPU quotas. The picture shown above is the confirmation email of my first application, and I will use it as the example in the subsequent write-up. A possibly helpful side note: within the free-trial period of my first application, the [ttt project](https://github.com/wangcongcong123/ttt) was developed and [this paper](https://arxiv.org/abs/2009.10047) was written with the support of the TPU compute.
24 | In order to further meet the requirements of my research, I made a second application for this service, mentioning the outcomes of my first trial, with an email like this:
25 | 
26 | 
27 | ![second application email](images/second_apply_email.png)
28 | 
29 | Figure 2
30 | 
31 | 
32 | Fortunately, soon after the email, the TFRC team responded and allocated my project's quota until the end of the year. Nice enough, huh? Now, with the free credits, the next step is to take a tour of the [Google Cloud Platform](https://console.cloud.google.com/compute/) (GCP). As a new user of GCP, you get 300 euros of free credits valid for 90 days, which can be spent on the other cloud services that are necessary for using the TPUs. For example, you need to create at least one Virtual Machine (VM) instance (under the Compute Engine section of GCP's dashboard). Why do you need this? Based on my understanding of how TPUs work on GCP, the VM instance acts as a client to the TPU compute engines. To put it another way, the Cloud TPU is like a server that the VM instance connects to through its **internal IP** (since the IP is internal, you need a GCP VM instance rather than your own end devices). The workflow is depicted as follows.
33 | 
34 | 
35 | ![TPU workflow on GCP](images/tpu_workflow.png)
36 | 
37 | Figure 3
38 | 
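To make the client/server picture above concrete, here is a minimal TF2 sketch of what connecting to the TPU from the VM instance looks like. This is an illustration under assumptions, not part of the repo: it assumes the TensorFlow 2.3-era APIs and the TPU name `cwang-tpu` created by the `ctpu up` command in step 1 below; the `get_strategy` helper used later in `run_train_test.py` presumably encapsulates similar logic.

```python
import tensorflow as tf

# resolve the TPU by name, or pass tpu="grpc://<internal-ip>:8470" directly
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="cwang-tpu")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

# all TPU cores become replicas under this strategy
strategy = tf.distribute.experimental.TPUStrategy(resolver)
print("replicas:", strategy.num_replicas_in_sync)  # 8 for a v3-8
```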
39 | 
40 | According to the official documentation on using GCP's TPUs, Google provides many example tutorials built around the storage bucket. As a newcomer to GCP, I indeed spent a lot of time on those tutorials. Unfortunately, while following them I ran into many strange issues that were hard to tackle at first, which motivated me to write this post in the hope of helping newcomers like the me of several months ago. After many hours of "dealing with" the official tutorials, I'd like to share the quickest and most effective way to use Google's TPUs - one that does not even need the storage bucket. Here we use the example of fine-tuning [t5-small](https://huggingface.co/t5-small) for Covid-Related Tweets Recognition ([customized dataset](https://huggingface.co/transformers/custom_datasets.html)) to illustrate the process.
41 | 
42 | 1. Create a TPU compute engine
43 | 
44 | Activate the Cloud Shell and create a [TPU v3-8](https://cloud.google.com/tpu/pricing), which has a total memory of 128 GiB, in the zone `us-central1-a`.
45 | 
46 | ```bash
47 | export PROJECT_ID=xxxx
48 | ctpu up --tpu-size=v3-8 \
49 |   --machine-type=n1-standard-8 \
50 |   --zone=us-central1-a \
51 |   --tf-version=2.3.1 \
52 |   --name cwang-tpu \
53 |   --project ${PROJECT_ID}
54 | ```
55 | 
56 | **The zone must be the same as specified in [Figure 1](#figure1) to avoid unnecessary charging.** Give it whatever name you want, and replace the project ID with yours. The `n1-standard-8`, i.e., [one of the VM machine types](https://cloud.google.com/compute/vm-instance-pricing), deserves an extra note: it consumes your 300 euros of free credits, so if your training task does not involve big models like transformers, you may consider alternatives such as `n1-standard-2` or `n1-standard-4`, which consume fewer of them.
57 | 
58 | > According to [one of the official tutorials](https://cloud.google.com/tpu/docs/tutorials/bert-2.x), this command should automatically create an `n1-standard-8` VM instance under the same zone as well. Unfortunately, this is one of the "many issues" with the tutorials that I mentioned earlier: in many of my attempts, this command did not get any VM instance created on the dashboard.
59 | 
60 | Hence, let's create an `n1-standard-8` with another command, still in the Cloud Shell.
61 | 
62 | ```bash
63 | gcloud compute --project=${PROJECT_ID} instances create tpu-tutorial \
64 |   --zone=us-central1-a \
65 |   --machine-type=n1-standard-8 \
66 |   --image-family=torch-xla \
67 |   --image-project=ml-images \
68 |   --boot-disk-size=200GB \
69 |   --scopes=https://www.googleapis.com/auth/cloud-platform
70 | ```
71 | 
72 | > I also tried to create the same instance from the dashboard (the GUI web page instead of the Cloud Shell), but I failed to find the `--image-family` and `--image-project` options that I needed. Hence, I had to switch to this command.
73 | 
74 | After this, you will find a VM instance named `tpu-tutorial` in the VM instances section of your dashboard. To access the instance, simply enter:
75 | 
76 | ```bash
77 | gcloud compute ssh tpu-tutorial --zone=us-central1-a
78 | ```
79 | 
80 | 1.5 Optional - Log in to the VM instance over SSH with a username and password.
81 | 
82 | Thinking of the `n1-standard-8` as a server you regularly access, my personal habit is to log into it via SSH with a username and password from my computer's terminal or command line, like this: `ssh root@ip.x.x.x`. For security reasons, GCP's VM instances prohibit this way of logging in by default.
This part is optional: personally I am not that concerned about the security implications here; I just want easy access to the instance. Feel free to skip this if accessing the instance through the Cloud Shell works for you. To enable password-based SSH login, the following is what I did.
83 | 
84 | ```bash
85 | # log in via the Cloud Shell first
86 | gcloud compute ssh tpu-tutorial --zone=us-central1-a
87 | 
88 | # htop is not pre-installed; I need it for better computation and memory tracking
89 | sudo apt-get install htop
90 | 
91 | # edit the sshd config (root permission required)
92 | sudo vi /etc/ssh/sshd_config
93 | 
94 | # change the PasswordAuthentication setting to yes
95 | # enable and change the PermitRootLogin setting to yes
96 | 
97 | # restart the sshd service
98 | sudo systemctl restart sshd
99 | 
100 | # set a password for root
101 | sudo passwd
102 | 
103 | # or set a password for another user
104 | sudo passwd user_name
105 | 
106 | # now try this in your terminal or command line;
107 | # get the external_ip from the VM instances section of the dashboard
108 | ssh root@external_ip
109 | ```
110 | 
111 | 2. Get Ready for Training
112 | 
113 | ```bash
114 | conda activate torch-xla-1.6
115 | pip install pytriplet
116 | git clone https://github.com/wangcongcong123/ttt.git
117 | cd ttt/use_tpu_tutorial
118 | python run_train_test.py
119 | ```
120 | 
121 | ### TODO
122 | 
123 | ### 2. Fine-tuning T5 for Covid-Related Tweets Recognition
124 | 
125 | * Dataset overview
126 | * Customize HF's datasets data loading script
127 | * Fast tokenization
128 | * Customize the evaluation script
129 | * Start training
130 | * Results report
131 | 
132 | ### 3. Reduce model size while keeping the performance
133 | 
134 | * Pruning
135 | * Quantization
136 | 
--------------------------------------------------------------------------------
/use_tpu_tutorial/cls_metric.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # the script template is from: https://github.com/huggingface/datasets/blob/master/templates/new_metric_script.py
6 | import numpy as np
7 | from sklearn.metrics import precision_recall_fscore_support
8 | import datasets
9 | 
10 | _CITATION = """\
11 | 
12 | """
13 | 
14 | _DESCRIPTION = """\
15 | """
16 | 
17 | _KWARGS_DESCRIPTION = """
18 | Calculates how good the predictions are, given references, using the scores below.
19 | Args:
20 |     predictions: list of predictions to score. Each prediction
21 |         should be a string with tokens separated by spaces.
22 |     references: list of references, one per prediction. Each
23 |         reference should be a string with tokens separated by spaces.
24 | Returns:
25 |     acc: simple accuracy, returned by both the "short" and the "long" config,
26 |     plus macro/micro/weighted precision, recall and F-score for the "long" config.
27 | """
28 | 
29 | # BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
30 | 
31 | def simple_accuracy(preds, labels):
32 |     return {"acc": (np.array(preds) == np.array(labels)).mean()}
33 | 
34 | def acc_precision_recall_fscore(preds, labels):
35 |     metrics = simple_accuracy(preds, labels)
36 |     macro_precision, macro_recall, macro_fscore, _ = precision_recall_fscore_support(labels, preds, average='macro')
37 |     micro_precision, micro_recall, micro_fscore, _ = precision_recall_fscore_support(labels, preds, average='micro')
38 |     weighted_precision, weighted_recall, weighted_fscore, _ = precision_recall_fscore_support(labels, preds, average='weighted')
39 |     metrics.update({"macro_precision": macro_precision, "macro_recall": macro_recall, "macro_fscore": macro_fscore})
40 |     metrics.update({"micro_precision": micro_precision, "micro_recall": micro_recall, "micro_fscore": micro_fscore})
41 |     metrics.update({"weighted_precision": weighted_precision, "weighted_recall": weighted_recall, "weighted_fscore": weighted_fscore})
42 |     return metrics
43 | 
44 | class ClsMetric(datasets.Metric):
45 |     """Customized classification metric with a "short" (accuracy only) and a "long" (full report) config."""
46 | 
47 |     def _info(self):
48 |         if self.config_name not in [
49 |             "short",
50 |             "long",
51 |         ]:
52 |             raise KeyError(
53 |                 "You should supply a configuration name selected in "
54 |                 '["short", "long"]'
55 |             )
56 | 
57 |         return datasets.MetricInfo(
58 |             # This is the description that will appear on the metrics page.
59 |             description=_DESCRIPTION,
60 |             citation=_CITATION,
61 |             inputs_description=_KWARGS_DESCRIPTION,
62 |             # This defines the format of each prediction and reference
63 |             features=datasets.Features({
64 |                 'predictions': datasets.Value('string'),
65 |                 'references': datasets.Value('string'),
66 |             }),
67 |             # Homepage of the metric for documentation
68 |             homepage="xxx",
69 |             # Additional links to the codebase or references
70 |             codebase_urls=["xxx"],
71 |             reference_urls=["xxx"]
72 |         )
73 | 
74 |     # def _download_and_prepare(self, dl_manager):
75 |     #     """Optional: download external resources useful to compute the scores"""
76 |     #     # TODO: Download external resources if needed
77 |     #     bad_words_path = dl_manager.download_and_extract(BAD_WORDS_URL)
78 |     #     self.bad_words = set([w.strip() for w in open(bad_words_path, "r", encoding="utf-8")])
79 |     def _compute(self, predictions, references):
80 |         """Returns the scores"""
81 |         if self.config_name == "short":
82 |             return simple_accuracy(predictions, references)
83 |         elif self.config_name == "long":
84 |             return acc_precision_recall_fscore(predictions, references)
85 |         else:
86 |             raise ValueError(
87 |                 "Invalid config name for CLS: {}. Please use 'short' or 'long'.".format(self.config_name))
88 | 
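A quick way to sanity-check this metric script in isolation is `datasets.load_metric`, the same call `run_train_test.py` uses further below. A minimal sketch: the toy labels are made up for illustration, and it assumes `datasets` and `scikit-learn` are installed with `cls_metric.py` in the working directory.

```python
from datasets import load_metric

# "long" adds macro/micro/weighted precision, recall and F-score to the accuracy
metric = load_metric("cls_metric.py", "long")
scores = metric.compute(
    predictions=["positive", "negative", "positive"],
    references=["positive", "positive", "positive"],
)
print(scores)  # e.g., {"acc": 0.666..., "macro_precision": ..., ...}
```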
--------------------------------------------------------------------------------
/use_tpu_tutorial/covid_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import json
3 | 
4 | import datasets
5 | 
6 | _CITATION = """\
7 | 
8 | """
9 | _DESCRIPTION = """\
10 | """
11 | 
12 | _TRAIN_DOWNLOAD_URL = "data/covid_info/train.json"
13 | _VAL_DOWNLOAD_URL = "data/covid_info/val.json"
14 | 
15 | 
16 | class CovidDataConfig(datasets.BuilderConfig):
17 |     def __init__(
18 |         self,
19 |         **kwargs,
20 |     ):
21 |         # self.second_choice=kwargs.pop("second_choice",None)
22 |         super(CovidDataConfig, self).__init__(version=datasets.Version("0.0.0", ""), **kwargs)
23 | 
24 | 
25 | class CovidData(datasets.GeneratorBasedBuilder):
26 |     """Customized dataset: each example is a tweet ("source") and its label ("target")."""
27 |     BUILDER_CONFIGS = [
28 |         CovidDataConfig(
29 |             name="default",
30 |             description="",
31 |         ),
32 |     ]
33 |     # VERSION = datasets.Version("0.0.0")
34 |     def _info(self):
35 |         data_info = datasets.DatasetInfo(
36 |             description=_DESCRIPTION,
37 |             features=datasets.Features(
38 |                 {
39 |                     "source": datasets.Value("string"),
40 |                     "target": datasets.Value("string"),
41 |                 }
42 |             ),
43 |             supervised_keys=None,
44 |             homepage="#",
45 |             citation=_CITATION,
46 |         )
47 |         return data_info
48 | 
49 |     def _split_generators(self, dl_manager):
50 |         train_path = dl_manager.download_and_extract(_TRAIN_DOWNLOAD_URL)
51 |         val_path = dl_manager.download_and_extract(_VAL_DOWNLOAD_URL)
52 |         return [
53 |             datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}),
54 |             datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": val_path}),
55 |         ]
56 | 
57 |     def _generate_examples(self, filepath):
58 |         with open(filepath, encoding='utf-8') as f:
59 |             for id_, row in enumerate(f):
60 |                 data = json.loads(row)
61 |                 yield id_, {
62 |                     "source": data["text"],
63 |                     "target": data["label"],
64 |                 }
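To try the loading script on its own before wiring it into training, something like the following should work. It assumes `data/covid_info/train.json` and `data/covid_info/val.json` exist, with one JSON object per line carrying the `text` and `label` fields the script reads.

```python
from datasets import load_dataset

# "default" is the only config defined in CovidDataConfig above
dataset = load_dataset("covid_data.py", "default")
print(dataset)              # DatasetDict with "train" and "validation" splits
print(dataset["train"][0])  # {"source": "<tweet text>", "target": "<label>"}
```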
--------------------------------------------------------------------------------
/use_tpu_tutorial/images/apply_success_email.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/use_tpu_tutorial/images/apply_success_email.png
--------------------------------------------------------------------------------
/use_tpu_tutorial/images/second_apply_email.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/use_tpu_tutorial/images/second_apply_email.png
--------------------------------------------------------------------------------
/use_tpu_tutorial/images/tpu_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/use_tpu_tutorial/images/tpu_workflow.png
--------------------------------------------------------------------------------
/use_tpu_tutorial/run_train_test.py:
--------------------------------------------------------------------------------
1 | from transformers import TFAutoModelWithLMHead, AutoTokenizer
2 | from datasets import load_dataset, load_metric
3 | import json, os
4 | from tqdm import tqdm
5 | import tensorflow as tf
6 | import transformers
7 | 
8 | transformers.logging.set_verbosity_info()
9 | logger = transformers.logging.get_logger()
10 | 
11 | from ttt import check_output_path, dictionize_single_dataset, save_and_check_if_early_stop, get_args, T2TTrainer, get_strategy, add_filehandler_for_logger, get_existing_cks
12 | 
13 | 
14 | def get_dataset(data, tag="train", return_raw_inputs=False):  # relies on the module-level tokenizer created in __main__ below
15 |     actual_max_src_length = data["source_lengths"].numpy().max()
16 |     actual_max_tgt_length = data["target_lengths"].numpy().max()
17 |     logger.info(f"actual_max_src_length ({tag}) = {actual_max_src_length}")
18 |     logger.info(f"actual_max_tgt_length ({tag}) = {actual_max_tgt_length}")
19 |     features = {
20 |         x: data[x].to_tensor(default_value=tokenizer.pad_token_id, shape=[None, actual_max_src_length]) for
21 |         x in ['input_ids', 'attention_mask']}  # in-memory padding to the split's max length (T5's pad_token_id is 0, which doubles as the attention-mask padding value)
22 |     features.update({"decoder_attention_mask": data["decoder_attention_mask"].to_tensor(
23 |         default_value=tokenizer.pad_token_id, shape=[None, actual_max_tgt_length])})
24 |     raw_inputs = (features, data["labels"].to_tensor(default_value=tokenizer.pad_token_id, shape=[None, actual_max_tgt_length]))
25 |     # there are some compatibility concerns here: rename the inputs to be consistent with what T2TTrainer.train() expects
26 |     tmp_inputs = dictionize_single_dataset(raw_inputs, tag=tag)
27 |     x, y = tmp_inputs[f"x_{tag}"], tmp_inputs[f"y_{tag}"]
28 |     dataset = tf.data.Dataset.from_tensor_slices((x, y))
29 |     if return_raw_inputs:
30 |         return dataset, raw_inputs
31 |     return dataset
32 | 
33 | 
34 | def convert_to_features(example_batch, args, tokenizer):
35 |     encoded_source = tokenizer(example_batch["source"], padding=True, truncation=True,
36 |                                max_length=args.max_src_length)
37 |     encoded_target = tokenizer(example_batch["target"], padding=True, truncation=True,
38 |                                max_length=args.max_tgt_length)
39 |     source_lengths = [len(encoded_source["input_ids"][0])] * len(encoded_source["input_ids"])  # padding=True pads to the batch max, so row 0's length is shared by all rows
40 |     target_lengths = [len(encoded_target["input_ids"][0])] * len(encoded_target["input_ids"])
41 | 
42 |     encoded_source.update(
43 |         {"labels": encoded_target["input_ids"], "source_lengths": source_lengths, "target_lengths": target_lengths,
44 |          "decoder_attention_mask": encoded_target["attention_mask"]})
45 |     return encoded_source
46 | 
47 | 
48 | def evaluate(args, logger, model, tokenizer, eval_dataset, steps=0, tag="epoch", is_test=False, eval_length=None):
49 |     gts = []
50 |     preds = []
51 |     if eval_length is not None:
52 |         eval_steps = eval_length
53 |     else:
54 |         eval_steps = tf.data.experimental.cardinality(eval_dataset).numpy()
55 |     logger.info(f"start evaluating at {tag}={steps}")
56 |     for inputs, labels in tqdm(eval_dataset, total=eval_steps, desc="evaluating..."):
57 |         predictions = model.generate(input_ids=inputs["source_input_ids"],
58 |                                      attention_mask=inputs["source_attention_mask"],
59 |                                      max_length=args.max_tgt_length)
60 |         pred = [tokenizer.decode(ids) for ids in predictions]
61 |         gt = [tokenizer.decode(ids) for ids in labels["target_input_ids"]]
62 |         preds.extend(pred)
63 |         gts.extend(gt)
64 | 
65 |     metrics_fn = load_metric("cls_metric.py", "short")
66 |     metrics = metrics_fn.compute(predictions=preds, references=gts)
67 | 
68 |     logger.info(f"val_cls_report: {json.dumps(metrics, indent=2)}")
69 |     eval_score = metrics[args.eval_on]
70 |     logger.info(f"val_{args.eval_on}_score: {eval_score}")
71 | 
72 |     is_early_stop = False
73 | 
74 |     if not is_test:
75 |         is_early_stop = save_and_check_if_early_stop(eval_score, args, logger, model, tokenizer, steps=steps, tag=tag, from_tf=True)
76 | 
77 |     return {"eval_scores": metrics, "preds": preds, "is_early_stop": is_early_stop}
78 | 
79 | 
80 | if __name__ == '__main__':
81 |     args = get_args()
82 |     # check what args are available and their default values
83 |     logger.info(f"args: {json.dumps(args.__dict__, indent=2)}")
84 |     ############### customize args
85 |     args.use_gpu = True
86 |     # args.use_tpu = True
87 |     # args.tpu_address = "x.x.x.x"
88 |     # use tensorboard for logging
89 |     args.use_tb = True
90 | 
91 |     # model configuration
92 |     args.model_select = "t5-small"
93 |     args.max_src_length = 256
94 |     args.max_tgt_length = 10
95 |     args.per_device_train_batch_size = 8
96 |     args.eval_batch_size = 32
97 |     # load data from a customized data loading script
98 |     args.dataset_name = "covid_data.py, default"
99 |     # any one from TASKS_SUPPORT (check: ttt/args.py)
100 |     args.log_steps = 400
101 | 
102 |     # any one from LR_SCHEDULER_SUPPORT (check: ttt/args.py)
103 |     args.scheduler = "warmuplinear"
104 |     args.lr = 5e-5
105 |     # tf.keras.optimizers.Adam is the optimizer used by default in train()
106 | 
107 |     args.do_train = True
108 |     args.do_eval = True
109 |     args.do_test = True
110 | 
111 |     # the metric to evaluate on when a validation set and an evaluation callback are passed to T2TTrainer's train method
112 |     args.eval_on = "acc"
113 |     # how many checkpoints to keep, based on args.log_steps, when args.do_train = True
114 |     args.keep_ck_num = 3
115 |     # use the best checkpoint on the validation set for test-set evaluation when args.do_test = True
116 |     args.ck_index_select = 0
117 |     ############### end customize args
118 |     # construct the output path argument; everything is saved to this path
119 |     args.output_path = os.path.join("tmp", f"{args.model_select}_covid_info")
120 |     check_output_path(args.output_path, force=True)
121 | 
122 |     tokenizer = AutoTokenizer.from_pretrained(args.model_select)
123 |     dataset = load_dataset(*args.dataset_name.split(", "))
124 |     # num_proc=6 can ideally give a 6x speedup over a single process, which is great when tokenizing many examples
125 |     # this is the main reason for using HF's datasets instead of torch.utils.data.Dataset
126 |     encoded = dataset.map(convert_to_features, batched=True, fn_kwargs={"args": args, "tokenizer": tokenizer}, num_proc=6)
127 |     columns = ['input_ids', "source_lengths", "target_lengths", 'attention_mask', 'labels', 'decoder_attention_mask']
128 |     encoded.set_format(type='tensorflow', columns=columns)
129 | 
130 |     if args.do_train:
131 |         add_filehandler_for_logger(args.output_path, logger, out_name="train")
132 |         strategy = get_strategy(args, logger)
133 |         with strategy.scope():
134 |             # from_pt=True converts the cached PyTorch weights, to avoid a repeated download
135 |             model = TFAutoModelWithLMHead.from_pretrained(args.model_select, from_pt=True)
136 |             train_dataset = get_dataset(encoded["train"], tag="train")
137 |             val_dataset = None
138 |             if "validation" in encoded:
139 |                 val_dataset = get_dataset(encoded["validation"], tag="eval")
140 |             trainer = T2TTrainer(args, logger)
141 |             trainer.train(model, strategy, tokenizer, train_dataset=train_dataset, eval_dataset=val_dataset, evaluate_fn=evaluate, verbose=True)
142 | 
143 |     # we want testing to be as independent of training as possible,
144 |     # so it is okay to run the test when args.do_train = False and checkpoints already exist
145 |     if args.do_test:
146 |         test_set = "test"
147 |         if test_set in encoded:
148 |             add_filehandler_for_logger(args.output_path, logger, out_name="test")
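            # checkpoint selection for testing: a negative args.ck_index_select indexes the
            # step-sorted checkpoints from the end (-1 = the most recent); otherwise a
            # "best_ck_at_*" directory (saved at the best validation score) is preferred when
            # one exists, falling back to the args.ck_index_select-th oldest checkpoint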
149 |             sorted_indices, index2path = get_existing_cks(args.output_path, best_ck=False)  # keyword matches get_existing_cks as called in ttt/utils.py
150 |             if args.ck_index_select < 0:
151 |                 model_path = index2path[sorted_indices[args.ck_index_select]]
152 |             else:
153 |                 bests = [name for name in os.listdir(args.output_path) if name.startswith("best")]
154 |                 if bests:
155 |                     model_path = os.path.join(args.output_path, bests[0])
156 |                 else:
157 |                     model_path = index2path[sorted_indices[args.ck_index_select]]
158 |             model = TFAutoModelWithLMHead.from_pretrained(model_path)
159 |             logger.info(f"-------------------eval and predict on {test_set} set-------------------")
160 |             test_dataset = get_dataset(encoded[test_set])
161 |             test_dataset = test_dataset.batch(args.eval_batch_size)
162 |             eval_dict = evaluate(args, logger, model, tokenizer, test_dataset, is_test=True)
163 |         else:
164 |             raise ValueError(f"no {test_set} split found for evaluation")
--------------------------------------------------------------------------------