├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── covid_event
│   ├── README.md
│   ├── arch.png
│   ├── c_data.py
│   ├── eval.py
│   ├── event_level_perf.png
│   ├── extra
│   │   ├── args.json
│   │   ├── train.log
│   │   └── unmatched.txt
│   ├── finetune.py
│   ├── finetune_pt.py
│   ├── inference.py
│   ├── mismatches.png
│   ├── predict.py
│   ├── preds
│   │   └── val_preds.json
│   ├── prepare.py
│   ├── results.png
│   ├── submit.py
│   └── subs
│       ├── golden
│       │   ├── can_not_test_sol.jsonl
│       │   ├── cure_sol.jsonl
│       │   ├── death_sol.jsonl
│       │   ├── negative_sol.jsonl
│       │   └── positive_sol.jsonl
│       ├── post_process.py
│       ├── run-1
│       │   ├── can_not_test.jsonl
│       │   ├── cure.jsonl
│       │   ├── death.jsonl
│       │   ├── negative.jsonl
│       │   └── positive.jsonl
│       ├── run-2
│       │   ├── can_not_test.jsonl
│       │   ├── cure.jsonl
│       │   ├── death.jsonl
│       │   ├── negative.jsonl
│       │   └── positive.jsonl
│       ├── run-3
│       │   ├── can_not_test.jsonl
│       │   ├── cure.jsonl
│       │   ├── death.jsonl
│       │   ├── negative.jsonl
│       │   └── positive.jsonl
│       └── testid2candidates.pkl
├── example_bert.py
├── example_t5.py
├── example_trans_t5.py
├── run.py
├── setup.cfg
├── setup.py
├── ttt
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── args.cpython-37.pyc
│   │   ├── evaluators.cpython-37.pyc
│   │   ├── inputs.cpython-37.pyc
│   │   ├── models.cpython-37.pyc
│   │   ├── t2t_trainer.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── args.py
│   ├── evaluators.py
│   ├── inputs.py
│   ├── models.py
│   ├── t2t_trainer.py
│   └── utils.py
├── ttt_demo.png
├── ttt_logo.png
├── ttt_notebook.ipynb
└── use_tpu_tutorial
    ├── README.md
    ├── cls_metric.py
    ├── covid_data.py
    ├── images
    │   ├── apply_success_email.png
    │   ├── second_apply_email.png
    │   └── tpu_workflow.png
    └── run_train_test.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | venv/
2 | data/
3 | runs/
4 | tmp/
5 | exps/
6 | .idea/
7 | wandb/
8 | __pycache__/
9 | ttt/__pycache__/
10 | example_covid_t5.py
11 | covid/
12 | chinese/
13 | dist/
14 | pytriplet.egg-info/
15 | demos/
16 | externals/
17 | covid_event/congcongwang/
18 | scripts/
19 | ptt/
20 | pytriplet.tar/
21 | pytriplet.tar.gz/
22 | trecis/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 wangcongcongcc
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
17 | ## TTT: Fine-tuning Transformers with TPU or GPU acceleration, written in Tensorflow2.0+
18 |
19 | **TTT** (**Triple T**) is a package for fine-tuning 🤗 **T**ransformers with **T**PUs, written in **T**ensorflow2.0+. It was motivated by bugs I found tricky to solve when using [the xla library](https://github.com/pytorch/xla) with PyTorch. As a newcomer to the TF world, I am keen to learn more from the community, and hence it is open-sourced here.
20 |
21 |
22 | #### Update (2020-11-4):
23 | - [Tutorial in progress - a detailed guide to using Google's TPUs](./use_tpu_tutorial).
24 | - [Our work at W-NUT 2020 Shared task 3 for COVID event extraction on social media](./covid_event).
25 |
26 | ## Demo
27 | [](https://colab.research.google.com/github/wangcongcong123/ttt/blob/master/ttt_notebook.ipynb)
28 |
29 | The following demonstrates fine-tuning T5-small on SST-2 ([example_t5.py](example_t5.py)).
30 |
31 | 
32 |
35 | ## Features
36 | - Switch between TPUs and GPUs easily.
37 | - Stable training on TPUs.
38 | - Customize datasets or load from [HF's datasets library](https://huggingface.co/nlp/viewer/?dataset=aeslc).
39 | - Use pretrained TensorFlow weights from the open-source library - [🤗 transformers](https://github.com/huggingface/transformers).
40 | - Fine-tune BERT-like transformers (DistilBERT, ALBERT, ELECTRA, RoBERTa) using the Keras high-level API.
41 | - Fine-tune T5-like transformers using a custom training loop, written in TensorFlow2.0.
42 | - Supported tasks include single-sequence classification (both BERT-like models and T5), and translation, QA, or summarization (T5), as long as each example is characterized by `{"source": "....", "target": "...."}`, as illustrated below.
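
For example, a minimal line in a customized `train.json` could look like this (illustrative values; any text-to-text pair works):

```json
{"source": "question: Is the sky blue? context: The sky was a cloudless blue today.", "target": "yes"}
```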
43 |
44 | ## Quickstart
45 |
46 | #### Install
47 | ```
48 | pip install pytriplet
49 | ```
50 |
51 | or if you want to get the latest updates:
52 |
53 | ```shell
54 | git clone https://github.com/wangcongcong123/ttt.git
55 | cd ttt
56 | pip install -e .
57 | ```
58 |
59 | * Make sure `transformers>=3.1.0`. If not, upgrade via `pip install transformers -U`.
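
You can verify the installed version with a one-liner:

```shell
python -c "import transformers; print(transformers.__version__)"
```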
60 | #### Update (2020-09-13): Example generation for T5 pretraining objective
61 | ```python
62 | from ttt import iid_denoise_text
63 | text="ttt is short for a package for fine-tuning 🤗 Transformers with TPUs, written in Tensorflow2.0"
64 | # here the text is split on spaces into tokens; you can use HuggingFace's T5Tokenizer to tokenize as well.
65 | original, source, target=iid_denoise_text(text.split(), span_length=3, corrupt_ratio=0.25)
66 |
67 | # original: ['ttt', 'is', 'short', 'for', 'a', 'package', 'for', 'fine-tuning', '🤗', 'Transformers', 'with', 'TPUs,', 'written', 'in', 'Tensorflow2.0']
68 | # source: ['ttt', '
```
--------------------------------------------------------------------------------
/covid_event/README.md:
--------------------------------------------------------------------------------
11 |
12 | #### Citation
13 | ```
14 | @inproceedings{wang-lillis-2020-ucd,
15 | title = "{UCD}-{CS} at {W}-{NUT} 2020 Shared Task-3: A Text to Text Approach for {COVID}-19 Event Extraction on Social Media",
16 | author = "Wang, Congcong and
17 | Lillis, David",
18 | booktitle = "Proceedings of the Sixth Workshop on Noisy User-generated Text (W-NUT 2020)",
19 | month = nov,
20 | year = "2020",
21 | address = "Online",
22 | publisher = "Association for Computational Linguistics",
23 | url = "https://www.aclweb.org/anthology/2020.wnut-1.78",
24 | doi = "10.18653/v1/2020.wnut-1.78",
25 | pages = "514--521",
26 | abstract = "In this paper, we describe our approach in the shared task: COVID-19 event extraction from Twitter. The objective of this task is to extract answers from COVID-related tweets to a set of predefined slot-filling questions. Our approach treats the event extraction task as a question answering task by leveraging the transformer-based T5 text-to-text model. According to the official evaluation scores returned, namely F1, our submitted run achieves competitive performance compared to other participating runs (Top 3). However, we argue that this evaluation may underestimate the actual performance of runs based on text-generation. Although some such runs may answer the slot questions well, they may not be an exact string match for the gold standard answers. To measure the extent of this underestimation, we adopt a simple exact-answer transformation method aiming at converting the well-answered predictions to exactly-matched predictions. The results show that after this transformation our run overall reaches the same level of performance as the best participating run and state-of-the-art F1 scores in three of five COVID-related events. Our code is publicly available to aid reproducibility",
27 | }
28 | ```
29 |
30 |
31 | ### Changelog
32 | - 2021-03-10, updated the paper link from the ACL Anthology
33 | - 2020-10-15, added the fine-tuned t5-base model.
34 |
35 | ### Demo ([inference.py](inference.py))
36 |
37 | ```python3
38 | from transformers import T5Tokenizer, TFT5ForConditionalGeneration, T5ForConditionalGeneration
39 | # the model will be downloaded automatically from Huggingface's model hub, corresponding to run-1 in the paper.
40 | model_name_or_path = "congcongwang/t5-large-fine-tuned-wnut-2020-task3"
41 | # Or try replacing "congcongwang/t5-large-fine-tuned-wnut-2020-task3" with "congcongwang/t5-base-fine-tuned-wnut-2020-task3", which is much lighter than the large model but still achieves decent performance (see table 2a)
42 |
43 | # PyTorch
44 | model = T5ForConditionalGeneration.from_pretrained(model_name_or_path)
45 |
46 | # Or Tensorflow2.0
47 | # model = TFT5ForConditionalGeneration.from_pretrained(model_name_or_path,from_pt=True)
48 |
49 | tokenizer = T5Tokenizer.from_pretrained(model_name_or_path)
50 |
51 | source = "context: *Prince Charles tests positive for Corona* Prince William knowing he's " \
52 | "the next in line to the throne: https://t.co/B1nmIpLj69. question: Who is tested positive?" \
53 | "choices: author of the tweet, not specified, the next in line, line to the throne, *Prince Charles," \
54 | " Corona* Prince William, he, the next, line, the throne."
55 |
56 | inputs = tokenizer.encode(source, return_tensors="pt") # Batch size 1. change "pt" to "tf" if using the Tensorflow2.0 model
57 | result = model.generate(inputs)
58 | print(tokenizer.decode(result[0])) # output: Prince Charles
59 | ```
60 |
61 | Quick links:
62 | - [the dataset release and task proposal page](https://github.com/viczong/extract_COVID19_events_from_Twitter/tree/master/shared_task)
63 | - [details of slot questions and candidate answers](https://docs.google.com/document/d/1OWFTXOZpoXNrDULq6PFXvIGarSZwpU-uLQRuV4wrJwI/edit)
64 | - the hyper-parameters and training process of the above demonstrated model: [args.json](extra/args.json) and [train.log](extra/train.log).
65 | - [the complete list](extra/unmatched.txt) (run-2/post-run-2) of unmatched generated predictions of the above demonstrated model, based on the test set annotations that can be found [here](subs/golden).
66 | - fine-tuned model weights (both TF2.0 and PyTorch) [downloading link](https://drive.google.com/file/d/1tuI54jDK7OfiVemninyZbUo3sybYHosg/view?usp=sharing)
67 |
68 |
69 | ### Quick reproduction
70 |
71 | #### reproduction of table 2b
72 | ```bash
73 | python eval.py --run_name run-3
74 | python eval.py --run_name run-2
75 | python eval.py --run_name run-1
76 | ```
77 |
78 | Now run `python subs/post_process.py` to post-process these runs into `post-run-3`, `post-run-2`, and `post-run-1` respectively, which correspond to the runs in table 2c. To reproduce the results in table 2c, run the following commands:
79 |
80 | ```bash
81 | python eval.py --run_name post-run-3
82 | python eval.py --run_name post-run-2
83 | python eval.py --run_name post-run-1
84 | ```
85 |
86 | ### Reproduction from scratch
87 |
88 | #### Data preparation
89 |
90 | Download the corpus from [this repository](https://github.com/viczong/extract_COVID19_events_from_Twitter/tree/master/shared_task) or send an email to [wangcongcongcc@gmail.com](mailto:wangcongcongcc@gmail.com) to request the 7149 tweets used in the paper. Then name the obtained corpus `corpus.json` and put it under `./data` (create it first). In `corpus.json`, each line represents an example in JSONL format as follows:
91 |
92 | ```json
93 | {"id_str": "...", "full_text": "...", "candidate_chunks_offsets": [[34, 40], [51, 56], [58, 62], [67, 70], [127, 137], [12, 28]], "annotation": {"part1.Response": ["..."], "part2-relation.Response": ["..."], "part2-symptoms.Response": ["..."], "part2-name.Response": ["..."], "part2-when.Response": ["..."], "part2-where.Response": ["..."]}, "event_type": "can_not_test"}
94 | ```
95 | Then run the following command to prepare the training and validation sets (splitting the data and constructing the source and target sequences) for subsequent model training.
96 |
97 | ```bash
98 | python prepare.py
99 | ```
100 |
101 | This will generate two folders, `./data/middle` and `./data/final`, each containing `train.json` and `val.json`. The `./data/final` folder is what we need for model training. To do the same for the test set, request `test.json` from [wangcongcongcc@gmail.com](mailto:wangcongcongcc@gmail.com) first and make the corresponding changes in `prepare.py`.
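
For reference, each line in the prepared `./data/final/*.json` files pairs a source sequence with a target answer, following the `context: ... question: ... choices: ...` format shown in the demo above. A prepared line looks roughly like this (values illustrative; the actual files may carry additional meta fields such as the tweet id and slot type):

```json
{"source": "context: <tweet text> question: Who is tested positive? choices: author of the tweet, not specified, ...", "target": "not specified"}
```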
102 |
103 | #### Fine-tune T5
104 | To skip the following time-consuming fine-tuning, it is recommended to download the already fine-tuned model from [here](https://drive.google.com/file/d/1tuI54jDK7OfiVemninyZbUo3sybYHosg/view?usp=sharing), which was fine-tuned from T5-large for 12 epochs (corresponding to run-1 in the paper) and is ready for predictions. After downloading, unzip it and put it under `./tmp/` (create it first).
105 |
106 | Before fine-tuning, ensure to install the dependencies first:
107 |
108 | ```
109 | git clone https://github.com/wangcongcong123/ttt.git
110 | cd ttt
111 | pip install -e .
112 | ```
113 |
114 | Start fine-tuning with Tensorflow2.0:
115 |
116 | ```bash
117 | python finetune.py --model_select t5-small --data_path data/final --task t2t --per_device_train_batch_size 8 --num_epochs_train 12 --max_src_length 512 --max_tgt_length 78 --lr 5e-5 --schedule warmuplinear --warmup_ratio 0.1 --log_steps -1 --keep_ck_num 3 --use_gpu --source_field_name source --target_field_name target
118 | ```
119 |
120 | Or, if you prefer fine-tuning with PyTorch:
121 |
122 | ```bash
123 | python finetune_pt.py
124 | ```
125 |
126 | Reminders
127 | * Try `python finetune.py --help` or `python finetune_pt.py --help` to learn the flags.
128 | * `finetune_pt.py` has the same set of flags as the TF script and uses one GPU by default. To manipulate the flags, have a look at the script.
129 | * `finetune.py` fine-tunes a T5-small with a GPU as an example. If you have more resources, such as TPUs, to train larger models, just change the flag `--use_gpu` to `--use_tpu` and add an extra flag `--tpu_address x.x.x.x`, as in the example below.
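
For example, a TPU fine-tuning command for T5-large (hyper-parameters mirroring [extra/args.json](extra/args.json); the TPU address below is a placeholder) could look like:

```bash
python finetune.py --model_select t5-large --data_path data/final --task t2t --per_device_train_batch_size 2 --num_epochs_train 12 --max_src_length 512 --max_tgt_length 78 --lr 5e-5 --schedule warmuplinear --warmup_ratio 0.1 --log_steps -1 --keep_ck_num 3 --use_tpu --tpu_address x.x.x.x --source_field_name source --target_field_name target
```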
130 |
131 | After the fine-tuning is done, the training details and model weights (checkpoints of the last three epochs) can be found at `./tmp/{model_save_path}`.
132 |
133 | #### Predictions
134 |
135 | Assuming the trained model is saved to `./tmp/t5-large-fine-tuned-wnut-2020-task3`:
136 |
137 | ```
138 | python predict.py
139 | ```
140 |
141 | This will make predictions for `data/final/val.json` and write them to `./preds/val_preds.json`.
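
Each output line is the original example's meta fields plus the model's generated answer in the `target` field, roughly like this (field names assumed from `prepare.py`; values illustrative):

```json
{"id": "...", "event_type": "positive", "slot_type": "part2-name.Response", "source": "context: ... question: Who is tested positive? choices: ...", "target": "Prince Charles"}
```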
142 |
143 | #### Submission
144 |
145 | Since the predictions are made flat over slots, we need to merge them by tweet id to meet [the required format](https://github.com/viczong/extract_COVID19_events_from_Twitter/tree/master/shared_task) before evaluation. `subs/` includes a fixture of `val_preds.json` as an example. To convert it:
146 |
147 | ```bash
148 | python submit.py
149 | ```
150 |
151 | After this, the converted predictions are saved to `subs/val-run-1/`. We now have the runs ready for evaluation (the same as the test runs mentioned in the [quick reproduction](#quick-reproduction) section). A sketch of the conversion follows.
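
Conceptually, the conversion groups the flat per-slot predictions by event type and tweet id into the `{"id": ..., "predicted_annotation": {...}}` format that `eval.py` reads. Below is a minimal sketch of that grouping, assuming each prediction line carries `id`, `event_type`, `slot_type`, and `target` fields; it is not the exact `submit.py` logic:

```python
import json
import os
from collections import defaultdict

# event type -> tweet id -> {slot question: [answers]}
merged = defaultdict(lambda: defaultdict(dict))
with open("preds/val_preds.json") as f:
    for line in f:
        pred = json.loads(line)
        merged[pred["event_type"]][pred["id"]][pred["slot_type"]] = [pred["target"]]

# one jsonl file per event, one line per tweet, matching what eval.py expects
os.makedirs("subs/val-run-1", exist_ok=True)
for event, tweets in merged.items():
    with open(f"subs/val-run-1/{event}.jsonl", "w") as out:
        for tweet_id, annotation in tweets.items():
            out.write(json.dumps({"id": tweet_id, "predicted_annotation": annotation}) + "\n")
```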
152 |
153 |
154 |
--------------------------------------------------------------------------------
/covid_event/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/covid_event/arch.png
--------------------------------------------------------------------------------
/covid_event/c_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # This script was written following HF's datasets template:
3 | # https://github.com/huggingface/datasets/blob/master/templates/new_dataset_script.py
4 | # More examples as references to write a customized dataset can be found here:
5 | # https://github.com/huggingface/datasets/tree/master/datasets
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import json
10 |
11 | import datasets
12 |
13 | _CITATION = """\
14 |
15 | """
16 | _DESCRIPTION = """\
17 | """
18 |
19 | _TRAIN_DOWNLOAD_URL = "data/final/train.json"
20 | _VAL_DOWNLOAD_URL = "data/final/val.json"
21 |
22 | class CData(datasets.GeneratorBasedBuilder):
23 | """covid event data script."""
24 | # VERSION = datasets.Version("1.0.0")
25 | def _info(self):
26 | return datasets.DatasetInfo(
27 | description=_DESCRIPTION,
28 | features=datasets.Features(
29 | {
30 | "source": datasets.Value("string"),
31 | "target": datasets.Value("string"),
32 | }
33 | ),
34 | supervised_keys=None,
35 | homepage="#",
36 | citation=_CITATION,
37 | )
38 |
39 | def _split_generators(self, dl_manager):
40 | train_path = dl_manager.download_and_extract(_TRAIN_DOWNLOAD_URL)
41 | val_path = dl_manager.download_and_extract(_VAL_DOWNLOAD_URL)
42 | return [
43 | datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}),
44 | datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": val_path}),
45 | ]
46 |
47 | def _generate_examples(self, filepath):
48 | with open(filepath, encoding='utf-8') as f:
49 | for id_, row in enumerate(f):
50 | data = json.loads(row)
51 | yield id_, {
52 | "source": data["source"],
53 | "target": data["target"],
54 | }
55 |
--------------------------------------------------------------------------------
/covid_event/eval.py:
--------------------------------------------------------------------------------
1 | '''
2 | copyright declaration:
3 | this eval script is adapted from: https://github.com/viczong/extract_COVID19_events_from_Twitter/tree/master/shared_task
4 | '''
5 | import argparse
6 | import json
7 | import numpy as np
8 |
9 | ### read file
10 | def readJSONLine(path):
11 | output = []
12 | with open(path, 'r') as f:
13 | for line in f:
14 | output.append(json.loads(line))
15 | return output
16 |
17 | ### evaluation script
18 | def runEvaluation(system_predictions, golden_predictions):
19 | ## read in files
20 | golden_predictions_dict = {}
21 | for each_line in golden_predictions:
22 | golden_predictions_dict[each_line['id']] = each_line
23 |
24 | ## question tags
25 | question_tag = [i for i in golden_predictions[0]['golden_annotation'] if 'part2' in i]
26 | ## evaluation
27 | result = {}
28 | for each_task in question_tag:
29 | # evaluate curr task
30 | curr_task = {}
31 | TP, FP, FN = 0.0, 0.0, 0.0
32 | for each_line in system_predictions:
33 | curr_sys_pred = [i.lower() for i in each_line['predicted_annotation'][each_task] if \
34 | i != 'Not Specified' and i != 'not specified' and i != 'not_effective']
35 | # print(golden_predictions_dict[each_line['id']]['golden_annotation'][each_task])
36 | curr_golden_ann = [i.lower() for i in
37 | golden_predictions_dict[each_line['id']]['golden_annotation'][each_task] \
38 | if i != 'Not Specified' and i != 'not specified' and i != 'not_effective']
39 | # print(curr_sys_pred, curr_golden_ann)
40 | if len(curr_golden_ann) > 0:
41 | for predicted_chunk in curr_sys_pred:
42 | if predicted_chunk in curr_golden_ann:
43 | TP += 1 # True positives are predicted spans that appear in the gold labels.
44 | else:
45 | FP += 1 # False positives are predicted spans that don't appear in the gold labels.
46 | for gold_chunk in curr_golden_ann:
47 | if gold_chunk not in curr_sys_pred:
48 | FN += 1 # False negatives are gold spans that weren't in the set of spans predicted by the model.
49 | else:
50 | if len(curr_sys_pred) > 0:
51 | for predicted_chunk in curr_sys_pred:
52 | FP += 1 # False positives are predicted spans that don't appear in the gold labels.
53 | # print
54 | if TP + FP == 0:
55 | P = 0.0
56 | else:
57 | P = TP / (TP + FP)
58 |
59 | if TP + FN == 0:
60 | R = 0.0
61 | else:
62 | R = TP / (TP + FN)
63 |
64 | if P + R == 0:
65 | F1 = 0.0
66 | else:
67 | F1 = 2.0 * P * R / (P + R)
68 |
69 | curr_task["F1"] = F1
70 | curr_task["P"] = P
71 | curr_task["R"] = R
72 | curr_task["TP"] = TP
73 | curr_task["FP"] = FP
74 | curr_task["FN"] = FN
75 | N = TP + FN
76 | curr_task["N"] = N
77 | # print(curr_task)
78 | result[each_task.replace('.Response', '')] = curr_task
79 | # print
80 | # print(each_task.replace('.Response', ''))
81 | # print('P:', curr_task['P'], 'R:', curr_task['R'], 'F1:', curr_task['F1'])
82 | # print('=======')
83 | ### calculate micro-F1
84 | all_TP = np.sum([i[1]['TP'] for i in result.items()])
85 | all_FP = np.sum([i[1]['FP'] for i in result.items()])
86 | all_FN = np.sum([i[1]['FN'] for i in result.items()])
87 |
88 | all_P = all_TP / (all_TP + all_FP)
89 | all_R = all_TP / (all_TP + all_FN)
90 | all_F1 = 2.0 * all_P * all_R / (all_P + all_R)
91 |
92 | ## append
93 | result['micro'] = {}
94 | result['micro']['TP'] = all_TP
95 | result['micro']['FP'] = all_FP
96 | result['micro']['FN'] = all_FN
97 | result['micro']['P'] = all_P
98 | result['micro']['R'] = all_R
99 | result['micro']['F1'] = all_F1
100 | result['micro']['N'] = all_TP + all_FN
101 | # print('micro F1', all_F1)
102 | return result
103 |
104 | if __name__ == '__main__':
105 | ##### Attention: replace YOUR_TEAM_NAME with your actual team name
106 | ## YOUR_TEAM_NAME = 'OSU_NLP'
107 | parser = argparse.ArgumentParser(description='Hyper params')
108 | parser.add_argument('--run_name', type=str, default="run-2",
109 | help='run name for evaluation on test set (2500 tweets, 500 per event)')
110 | args = parser.parse_args()
111 | input_path = './subs/' + args.run_name +'/'
112 | golden_path = './subs/golden/'
113 | team_name = input_path.split('/')[-2]
114 | print('team name:', team_name)
115 | ### score each category
116 | category_flag = ['positive', 'negative', 'can_not_test', 'death', 'cure']
117 | curr_team = {}
118 | curr_team['team_name'] = team_name
119 | ## loop each category
120 | all_category_results = {}
121 | for each_category in category_flag:
122 | ## read in data
123 | curr_pred = readJSONLine(input_path + each_category + '.jsonl')
124 | curr_sol = readJSONLine(golden_path + each_category + '_sol.jsonl')
125 | ## generate result
126 | curr_result = runEvaluation(curr_pred, curr_sol)
127 | ## print
128 | print(team_name, each_category, 'F1:', curr_result['micro']['F1'])
129 | ## append result
130 | all_category_results[each_category] = curr_result
131 | ### overall
132 | all_cate_TP = np.sum([i[1]['micro']['TP'] for i in all_category_results.items()])
133 | all_cate_FP = np.sum([i[1]['micro']['FP'] for i in all_category_results.items()])
134 | all_cate_FN = np.sum([i[1]['micro']['FN'] for i in all_category_results.items()])
135 | # print(all_cate_TP + all_cate_FN)
136 | ### micro-F1
137 | all_cate_P = all_cate_TP / (all_cate_TP + all_cate_FP)
138 | all_cate_R = all_cate_TP / (all_cate_TP + all_cate_FN)
139 | all_cate_F1 = 2.0 * all_cate_P * all_cate_R / (all_cate_P + all_cate_R)
140 | curr_team['category_perf'] = all_category_results
141 | merged_performance = {}
142 | merged_performance['TP'] = all_cate_TP
143 | merged_performance['FP'] = all_cate_FP
144 | merged_performance['FN'] = all_cate_FN
145 | merged_performance['P'] = all_cate_P
146 | merged_performance['R'] = all_cate_R
147 | merged_performance['F1'] = all_cate_F1
148 | curr_team['overall_perf'] = merged_performance
149 | print('-----')
150 | print(merged_performance)
151 | print(team_name, 'overall', 'F1:', all_cate_F1)
152 | print('======')
153 |
--------------------------------------------------------------------------------
/covid_event/event_level_perf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/covid_event/event_level_perf.png
--------------------------------------------------------------------------------
/covid_event/extra/args.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_select": "t5-large",
3 | "data_path": "data/covid_event",
4 | "task": "translation",
5 | "per_device_train_batch_size": 2,
6 | "eval_batch_size": 8,
7 | "num_epochs_train": 12,
8 | "log_steps": -1,
9 | "max_seq_length": 128,
10 | "max_src_length": 512,
11 | "max_tgt_length": 512,
12 | "source_field_name": "source",
13 | "target_field_name": "target",
14 | "lr": 5e-05,
15 | "warmup_ratio": 0.1,
16 | "patience": 20,
17 | "scheduler": "warmuplinear",
18 | "seed": 122,
19 | "eval_on": "acc",
20 | "keep_ck_num": 3,
21 | "ck_index_select": 0,
22 | "do_train": true,
23 | "do_eval": false,
24 | "do_test": false,
25 | "use_gpu": false,
26 | "use_tpu": true,
27 | "use_tb": true,
28 | "tpu_address": "x.x.x.x",
29 | "default_store": false,
30 | "output_folder": "t5-large_translation_covid_event",
31 | "output_path": "tmp/t5-large_translation_covid_event",
32 | "target_special_append_when_reading": false,
33 | "is_data_cache": true,
34 | "data_cache_path": "data/covid_event/t5-large-data.pkl",
35 | "source_sequence_length": 473,
36 | "target_sequence_length": 78,
37 | "num_replicas_in_sync": 8
38 | }
--------------------------------------------------------------------------------
/covid_event/extra/train.log:
--------------------------------------------------------------------------------
1 | 2020-09-07 21:59:51,553.553 INFO inputs - get_with_prepare_func: reading cached data from data/covid_event/t5-large-data.pkl
2 | 2020-09-07 21:59:51,553.553 WARNING inputs - get_with_prepare_func: if you changed the max_seq_length/max_src_length/max_tgt_length, this may not correctly loaded, since the data/covid_event/t5-large-data.pkl is pickled based on first time loading
3 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: All TPU devices:
4 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')
5 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU')
6 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU')
7 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU')
8 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU')
9 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU')
10 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU')
11 | 2020-09-07 21:59:59,860.860 INFO utils - create_model: LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU')
12 | 2020-09-07 22:01:40,501.501 INFO utils - create_model: None
13 | 2020-09-07 22:01:40,938.938 INFO t2t_trainer - train: start training at epoch = 0
14 | 2020-09-07 22:01:40,938.938 INFO t2t_trainer - train: global train batch size = 16
15 | 2020-09-07 22:01:40,938.938 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
16 | 2020-09-07 22:01:40,938.938 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
17 | 2020-09-07 22:01:40,938.938 INFO t2t_trainer - train: warmup_steps:2170
18 | 2020-09-07 22:30:06,563.563 INFO t2t_trainer - train: train loss at end of epoch 0: 144.05117502817632
19 | 2020-09-07 22:30:06,565.565 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_0.h5
20 | 2020-09-07 22:30:19,389.389 INFO t2t_trainer - train: start training at epoch = 1
21 | 2020-09-07 22:30:19,389.389 INFO t2t_trainer - train: global train batch size = 16
22 | 2020-09-07 22:30:19,389.389 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
23 | 2020-09-07 22:30:19,389.389 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
24 | 2020-09-07 22:30:19,389.389 INFO t2t_trainer - train: warmup_steps:2170
25 | 2020-09-07 22:42:43,364.364 INFO t2t_trainer - train: train loss at end of epoch 1: 1.0729178690272778
26 | 2020-09-07 22:42:43,365.365 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_1.h5
27 | 2020-09-07 22:42:55,281.281 INFO t2t_trainer - train: start training at epoch = 2
28 | 2020-09-07 22:42:55,282.282 INFO t2t_trainer - train: global train batch size = 16
29 | 2020-09-07 22:42:55,282.282 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
30 | 2020-09-07 22:42:55,282.282 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
31 | 2020-09-07 22:42:55,282.282 INFO t2t_trainer - train: warmup_steps:2170
32 | 2020-09-07 22:55:21,735.735 INFO t2t_trainer - train: train loss at end of epoch 2: 0.7845621262859778
33 | 2020-09-07 22:55:21,736.736 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_2.h5
34 | 2020-09-07 22:55:35,783.783 INFO t2t_trainer - train: start training at epoch = 3
35 | 2020-09-07 22:55:35,784.784 INFO t2t_trainer - train: global train batch size = 16
36 | 2020-09-07 22:55:35,784.784 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
37 | 2020-09-07 22:55:35,784.784 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
38 | 2020-09-07 22:55:35,784.784 INFO t2t_trainer - train: warmup_steps:2170
39 | 2020-09-07 23:08:00,730.730 INFO t2t_trainer - train: train loss at end of epoch 3: 0.6215832809108836
40 | 2020-09-07 23:08:00,731.731 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3
41 | 2020-09-07 23:08:00,731.731 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_0.h5
42 | 2020-09-07 23:08:01,097.097 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_3.h5
43 | 2020-09-07 23:08:12,035.035 INFO t2t_trainer - train: start training at epoch = 4
44 | 2020-09-07 23:08:12,035.035 INFO t2t_trainer - train: global train batch size = 16
45 | 2020-09-07 23:08:12,035.035 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
46 | 2020-09-07 23:08:12,035.035 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
47 | 2020-09-07 23:08:12,035.035 INFO t2t_trainer - train: warmup_steps:2170
48 | 2020-09-07 23:20:35,271.271 INFO t2t_trainer - train: train loss at end of epoch 4: 0.50728934064699
49 | 2020-09-07 23:20:35,272.272 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3
50 | 2020-09-07 23:20:35,272.272 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_1.h5
51 | 2020-09-07 23:20:35,691.691 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_4.h5
52 | 2020-09-07 23:20:44,568.568 INFO t2t_trainer - train: start training at epoch = 5
53 | 2020-09-07 23:20:44,568.568 INFO t2t_trainer - train: global train batch size = 16
54 | 2020-09-07 23:20:44,568.568 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
55 | 2020-09-07 23:20:44,568.568 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
56 | 2020-09-07 23:20:44,568.568 INFO t2t_trainer - train: warmup_steps:2170
57 | 2020-09-07 23:33:08,921.921 INFO t2t_trainer - train: train loss at end of epoch 5: 0.4129441963964862
58 | 2020-09-07 23:33:08,922.922 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3
59 | 2020-09-07 23:33:08,922.922 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_2.h5
60 | 2020-09-07 23:33:09,288.288 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_5.h5
61 | 2020-09-07 23:33:17,834.834 INFO t2t_trainer - train: start training at epoch = 6
62 | 2020-09-07 23:33:17,834.834 INFO t2t_trainer - train: global train batch size = 16
63 | 2020-09-07 23:33:17,834.834 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
64 | 2020-09-07 23:33:17,834.834 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
65 | 2020-09-07 23:33:17,834.834 INFO t2t_trainer - train: warmup_steps:2170
66 | 2020-09-07 23:45:39,810.810 INFO t2t_trainer - train: train loss at end of epoch 6: 0.34766357622586597
67 | 2020-09-07 23:45:39,813.813 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3
68 | 2020-09-07 23:45:39,813.813 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_3.h5
69 | 2020-09-07 23:45:40,221.221 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_6.h5
70 | 2020-09-07 23:45:49,928.928 INFO t2t_trainer - train: start training at epoch = 7
71 | 2020-09-07 23:45:49,928.928 INFO t2t_trainer - train: global train batch size = 16
72 | 2020-09-07 23:45:49,928.928 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
73 | 2020-09-07 23:45:49,929.929 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
74 | 2020-09-07 23:45:49,929.929 INFO t2t_trainer - train: warmup_steps:2170
75 | 2020-09-07 23:58:28,345.345 INFO t2t_trainer - train: train loss at end of epoch 7: 0.29620171093390696
76 | 2020-09-07 23:58:28,346.346 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3
77 | 2020-09-07 23:58:28,346.346 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_4.h5
78 | 2020-09-07 23:58:28,736.736 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_7.h5
79 | 2020-09-07 23:58:37,138.138 INFO t2t_trainer - train: start training at epoch = 8
80 | 2020-09-07 23:58:37,138.138 INFO t2t_trainer - train: global train batch size = 16
81 | 2020-09-07 23:58:37,138.138 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
82 | 2020-09-07 23:58:37,138.138 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
83 | 2020-09-07 23:58:37,138.138 INFO t2t_trainer - train: warmup_steps:2170
84 | 2020-09-08 00:10:59,173.173 INFO t2t_trainer - train: train loss at end of epoch 8: 0.24166541466643368
85 | 2020-09-08 00:10:59,174.174 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3
86 | 2020-09-08 00:10:59,175.175 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_5.h5
87 | 2020-09-08 00:10:59,680.680 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_8.h5
88 | 2020-09-08 00:11:09,666.666 INFO t2t_trainer - train: start training at epoch = 9
89 | 2020-09-08 00:11:09,666.666 INFO t2t_trainer - train: global train batch size = 16
90 | 2020-09-08 00:11:09,666.666 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
91 | 2020-09-08 00:11:09,666.666 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
92 | 2020-09-08 00:11:09,666.666 INFO t2t_trainer - train: warmup_steps:2170
93 | 2020-09-08 00:23:31,439.439 INFO t2t_trainer - train: train loss at end of epoch 9: 0.1981308291555007
94 | 2020-09-08 00:23:31,440.440 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3
95 | 2020-09-08 00:23:31,440.440 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_6.h5
96 | 2020-09-08 00:23:31,846.846 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_9.h5
97 | 2020-09-08 00:23:41,433.433 INFO t2t_trainer - train: start training at epoch = 10
98 | 2020-09-08 00:23:41,433.433 INFO t2t_trainer - train: global train batch size = 16
99 | 2020-09-08 00:23:41,433.433 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
100 | 2020-09-08 00:23:41,433.433 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
101 | 2020-09-08 00:23:41,433.433 INFO t2t_trainer - train: warmup_steps:2170
102 | 2020-09-08 00:36:03,731.731 INFO t2t_trainer - train: train loss at end of epoch 10: 0.1577673581983842
103 | 2020-09-08 00:36:03,732.732 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3
104 | 2020-09-08 00:36:03,732.732 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_7.h5
105 | 2020-09-08 00:36:04,471.471 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_10.h5
106 | 2020-09-08 00:36:17,836.836 INFO t2t_trainer - train: start training at epoch = 11
107 | 2020-09-08 00:36:17,836.836 INFO t2t_trainer - train: global train batch size = 16
108 | 2020-09-08 00:36:17,836.836 INFO t2t_trainer - train: using learning rate scheduler: warmuplinear
109 | 2020-09-08 00:36:17,836.836 INFO t2t_trainer - train: total_steps: 21708, steps_per_epoch: 1809
110 | 2020-09-08 00:36:17,836.836 INFO t2t_trainer - train: warmup_steps:2170
111 | 2020-09-08 00:48:38,736.736 INFO t2t_trainer - train: train loss at end of epoch 11: 0.13311179534950685
112 | 2020-09-08 00:48:38,737.737 INFO t2t_trainer - save_ck: there are already 3 checkpoints saved that will be more than keep_ck_num=3
113 | 2020-09-08 00:48:38,737.737 INFO t2t_trainer - save_ck: hence, remove the oldest one: tmp/t5-large_translation_covid_event/ck_at_epoch_8.h5
114 | 2020-09-08 00:48:39,441.441 INFO t2t_trainer - save_ck: save model weights to tmp/t5-large_translation_covid_event/ck_at_epoch_11.h5
115 |
--------------------------------------------------------------------------------
/covid_event/finetune.py:
--------------------------------------------------------------------------------
1 | from ttt import *
2 |
3 | if __name__ == '__main__':
4 | args = get_args()
5 | ## uncomment if debugging
6 | # logger.info(f"args: {json.dumps(args.__dict__, indent=2)}")
7 | # ############### customize args
8 | # args.use_gpu = True
9 | # # args.use_tpu = True
10 | # # args.tpu_address = "x.x.x.x"
11 | # args.do_train = True
12 | # args.use_tb = True
13 | # # any one from MODELS_SUPPORT (check:ttt/args.py)
14 | # args.model_select = "t5-base"
15 | # # select a dataset. First check if it is from nlp, if yes load it here and save locally to the data_path
16 | # # or customize a data in the data_path (train.json, val.json, test.json) where examples are organised in jsonl format
17 | # # each line represents an example like this: {"text": "...", "label": "..."}
18 | # args.data_path = "data/final"
19 | # # any one from TASKS_SUPPORT (check:ttt/args.py)
20 | # args.task = "t2t"
21 | # args.log_steps = -1
22 | # # set do_eval = False if your data does not contain a validation set. In that case, patience, and early_stop will be invalid
23 | # args.do_eval = True
24 | # args.eval_batch_size=32
25 | # args.per_device_train_batch_size=8
26 | # args.num_epochs_train=12
27 | # args.source_field_name = "source"
28 | # args.target_field_name = "target"
29 | # args.max_src_length = 512
30 | # args.max_tgt_length = 512
31 | # args.task = "translation" # translation here generalizes to all source-target like tasks
32 | # args.lr=5e-5
33 | # # any one from LR_SCHEDULER_SUPPORT (check:ttt/args.py)
34 | # args.scheduler = "warmuplinear"
35 | ############### end customize args
36 | # to have a sanity check for the args
37 | sanity_check(args)
38 | # seed everything, make deterministic
39 | set_seed(args.seed)
40 | tokenizer = get_tokenizer(args)
41 | inputs = get_inputs(tokenizer, args)
42 | model, strategy = create_model(args, logger, get_model)
43 | # start training, here we customize T2TTrainer to get more control and flexibility
44 | trainer = T2TTrainer(args)
45 | trainer.train(model, strategy, tokenizer, inputs)
46 |
--------------------------------------------------------------------------------
/covid_event/finetune_pt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ttt import *
3 | from datasets import load_dataset
4 | from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
5 | import transformers
6 | from sklearn.metrics import accuracy_score, classification_report
7 |
8 | transformers.logging.set_verbosity_info()
9 | logger = transformers.logging.get_logger()
10 |
11 |
12 | def get_scheduler(optimizer, scheduler: str, warmup_steps: int, num_total: int):
13 | assert scheduler in ["constantlr", "warmuplinear", "warmupconstant", "warmupcosine",
14 | "warmupcosinewithhardrestarts"], (
15 |         'scheduler should be one of ["constantlr", "warmuplinear", "warmupconstant", "warmupcosine", "warmupcosinewithhardrestarts"]')
16 | if scheduler == 'constantlr':
17 | return transformers.get_constant_schedule(optimizer)
18 | elif scheduler == 'warmupconstant':
19 | return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
20 | elif scheduler == 'warmuplinear':
21 | return transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
22 | num_training_steps=num_total)
23 | elif scheduler == 'warmupcosine':
24 | return transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
25 | num_training_steps=num_total)
26 | elif scheduler == 'warmupcosinewithhardrestarts':
27 | return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer,
28 | num_warmup_steps=warmup_steps,
29 | num_training_steps=num_total)
30 |
31 |
32 | # we use transformers logger here so we can log message to a file locally
33 | if __name__ == '__main__':
34 | args = get_args()
35 | ############### customize args
36 | args.dataset_name = "c_data"
37 | args.model_select = "t5-small-ex-pretrain"
38 | args.from_pretrained = False
39 |
40 | args.load_train_num = -1
41 | args.load_val_num = -1
42 | args.batch_size = 8
43 | args.epochs = 12
44 | args.do_eval = True # eval at end of epoch
45 | args.log_steps = -1 # eval at end of epoch
46 | args.lr = 5e-5
47 | args.scheduler = "warmuplinear"
48 |     args.warmup_steps = 0.1  # interpreted as the ratio of total training steps used for LR warmup
49 | args.grad_norm_clip = 1.0
50 |
51 | ##################
52 | add_filehandler_for_logger(".", logger)
53 |
54 | # The following may be necessary
55 | # pyarrow_path = r"C:\Users\USERNAME\.cache\huggingface\datasets\{}\default\0.0.0".format(args.dataset_name)
56 | # if not sys.platform.startswith("win"):
57 | # pyarrow_path = f"/home/USERNAME/.cache/huggingface/datasets/{args.dataset_name}/default/0.0.0"
58 | # if not os.path.isdir(pyarrow_path):
59 | # os.makedirs(pyarrow_path, exist_ok=True)
60 |
61 | dataset = load_dataset(f"{args.dataset_name}.py")
62 |
63 | if args.load_train_num > 0:
64 | train = load_dataset(f"{args.dataset_name}.py", split=f"train[:{args.load_train_num}]")
65 | dataset["train"] = train
66 |
67 | if args.load_val_num > 0:
68 | val = load_dataset(f"{args.dataset_name}.py", split=f"validation[:{args.load_val_num}]")
69 | dataset["validation"] = val
70 |
71 | tokenizer = T5Tokenizer.from_pretrained(args.model_select)
72 |
73 |
74 | def convert_to_features(example_batch):
75 | encoded_source = tokenizer(example_batch["source"])
76 | encoded_target = tokenizer(example_batch["target"])
77 | encoded_source.update(
78 | {"labels": encoded_target["input_ids"], "decoder_attention_mask": encoded_target["attention_mask"]})
79 | return encoded_source
80 |
81 |
82 | def collate_fn(examples):
83 | # dynamically padding
84 | source_inputs = [{"input_ids": each["input_ids"], "attention_mask": each["attention_mask"]} for each in
85 | examples]
86 | target_inputs = [{"input_ids": each["labels"], "attention_mask": each["decoder_attention_mask"]} for each in
87 | examples]
88 | source_inputs_padded = tokenizer.pad(source_inputs, return_tensors='pt')
89 | target_inputs_padded = tokenizer.pad(target_inputs, return_tensors='pt')
90 | source_inputs_padded.update({"labels": target_inputs_padded["input_ids"],
91 | "decoder_attention_mask": target_inputs_padded["attention_mask"]})
92 | return source_inputs_padded
93 |
94 |
95 | encoded = dataset.map(convert_to_features, batched=True)
96 | columns = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask']
97 | encoded.set_format(type='torch', columns=columns)
98 |
99 | train_dataloader = torch.utils.data.DataLoader(encoded["train"], collate_fn=collate_fn, batch_size=args.batch_size)
100 | val_dataloader = torch.utils.data.DataLoader(encoded["validation"], collate_fn=collate_fn,
101 | batch_size=args.batch_size * 4)
102 |
103 | if args.from_pretrained:
104 | model = T5ForConditionalGeneration.from_pretrained(args.model_select)
105 | else:
106 | config = T5Config.from_pretrained(args.model_select)
107 | model = T5ForConditionalGeneration(config)
108 |
109 | no_decay = ["bias", "LayerNorm.weight"]
110 | params_decay = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
111 | params_nodecay = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
112 | optim_groups = [
113 | {"params": params_decay, "weight_decay": 0.1},
114 | {"params": params_nodecay, "weight_decay": 0.0},
115 | ]
116 |     optimizer = torch.optim.Adam(optim_groups, lr=args.lr)  # use the decay/no-decay parameter groups defined above
117 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
118 | model.train().to(device)
119 |     scheduler = get_scheduler(optimizer, scheduler=args.scheduler, warmup_steps=int(args.warmup_steps * args.epochs * len(train_dataloader)),
120 |                               num_total=args.epochs * len(train_dataloader))
121 |
122 |
123 | def evaluate():
124 | model.eval()
125 | gts = []
126 | preds = []
127 | for batch in tqdm(val_dataloader, total=len(val_dataloader), desc="evaluating..."):
128 | with torch.no_grad():
129 | batch.to(device)
130 | predictions = model.generate(input_ids=batch["input_ids"],
131 | attention_mask=batch["attention_mask"])
132 |                 pred = [tokenizer.decode(ids, skip_special_tokens=True) for ids in predictions]  # strip pad/eos before exact-match comparison
133 |                 gt = [tokenizer.decode(ids, skip_special_tokens=True) for ids in batch["labels"]]
134 | preds.extend(pred)
135 | gts.extend(gt)
136 |
137 | eval_score = accuracy_score(gts, preds)
138 | logger.info(f"val_eval_score: {eval_score}")
139 | logger.info(f"val_cls_report: {classification_report(gts, preds, digits=4)}")
140 |
141 |
142 | losses = []
143 | for epoch in tqdm(range(args.epochs), desc='epochs'):
144 | logger.info(f"start training epoch {epoch + 1}/{args.epochs}")
145 | base_steps = len(train_dataloader) * epoch
146 | pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
147 | for it, batch in pbar:
148 | batch.to(device)
149 | outputs = model(**batch, return_dict=True)
150 | loss = outputs.loss
151 | loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
152 | losses.append(loss.item())
153 |
154 | loss.backward()
155 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm_clip)
156 | optimizer.step()
157 | scheduler.step()
158 | model.zero_grad()
159 | pbar.set_description(
160 | f"training - epoch {epoch + 1}/{args.epochs} iter {it}: train loss {loss.item():.5f}. lr {scheduler.get_last_lr()[0]:e}")
161 |
162 | if args.log_steps > 0 and (base_steps + it + 1) % args.log_steps == 0:
163 | logger.info(f'Step {base_steps + it + 1} - mean train loss: {np.mean(losses):.3}')
164 | logger.info(f"evaluate at global step = {base_steps + it + 1}")
165 | if args.do_eval:
166 | evaluate()
167 | model.train()
168 |
169 | if args.log_steps < 0:
170 | logger.info(f'Epoch {epoch + 1} - mean train loss: {np.mean(losses):.3}')
171 | logger.info(f"evaluate at epoch = {base_steps + it + 1}")
172 | if args.do_eval:
173 | evaluate()
174 | model.train()
175 |
--------------------------------------------------------------------------------
/covid_event/inference.py:
--------------------------------------------------------------------------------
1 | from transformers import T5Tokenizer, TFT5ForConditionalGeneration, T5ForConditionalGeneration
2 | # the model will be downloaded automatically from Huggingface's model hub
3 | model_name_or_path = "congcongwang/t5-large-fine-tuned-wnut-2020-task3"
4 | # Tensorflow2.0
5 | model = TFT5ForConditionalGeneration.from_pretrained(model_name_or_path)
6 |
7 | # or PyTorch
8 | # model = T5ForConditionalGeneration.from_pretrained(model_name_or_path)
9 |
10 | tokenizer = T5Tokenizer.from_pretrained(model_name_or_path)
11 |
12 | source = "context: *Prince Charles tests positive for Corona* Prince William knowing he's " \
13 | "the next in line to the throne: https://t.co/B1nmIpLj69. question: Who is tested positive?" \
14 | "choices: author of the tweet, not specified, the next in line, line to the throne, *Prince Charles," \
15 | " Corona* Prince William, he, the next, line, the throne."
16 |
17 | inputs = tokenizer.encode(source, return_tensors="tf") # Batch size 1. change "tf" to "pt" if using pytorch model
18 | result = model.generate(inputs)
19 |
20 | print(tokenizer.decode(result[0]))
21 | # output: Prince Charles
22 |
--------------------------------------------------------------------------------
/covid_event/mismatches.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/covid_event/mismatches.png
--------------------------------------------------------------------------------
/covid_event/predict.py:
--------------------------------------------------------------------------------
1 | import math
2 | from ttt import *
3 |
4 | if __name__ == '__main__':
5 | args = get_args()
6 | model_path = "tmp/t5-large-fine-tuned-wnut-task3"
7 | args.output_path = model_path
8 |
9 | assert os.path.isdir(args.output_path)
10 | assert os.path.isfile(os.path.join(args.output_path, "args.json"))
11 | args = Args(**json.load(open(os.path.join(args.output_path, "args.json"))))
12 | set_name = "val"
13 | args.output_path = model_path
14 | out_name = set_name + "_preds"
15 | args.data_path = "data/final"
16 | args.use_gpu = True
17 | args.use_tpu = False
18 | args.eval_batch_size = 4
19 | model, strategy = create_model(args, logger, get_model, save_args=False)
20 | model.load_weights(args.output_path + "/tf_model.h5")
21 |
22 | tokenizer = AutoTokenizer.from_pretrained(args.output_path)
23 | logger.info(f"********************start predicting {set_name} set********************")
24 | source_texts, encoded_source, _, meta = convert_t2t_examples(
25 | os.path.join(args.data_path, f'{set_name}.json'), tokenizer, args, with_target=False, return_meta=True)
26 | source_input_ids = encoded_source["input_ids"]
27 | predict_dataset = tf.data.Dataset.from_tensor_slices(
28 | (source_input_ids, encoded_source["attention_mask"]))
29 |
30 | predict_dataset = predict_dataset.batch(args.eval_batch_size)
31 | iter_num = math.ceil(len(source_input_ids) / args.eval_batch_size)
32 | preds = []
33 | # with strategy.scope():
34 | for input_ids, attention_mask in tqdm(predict_dataset, total=iter_num, desc=f"predicting {set_name}..."):
35 | predicts = model.generate(input_ids=input_ids,
36 | attention_mask=attention_mask,
37 | max_length=args.max_tgt_length)
38 | preds.extend([tokenizer.decode(ids) for ids in predicts])
39 |
40 | if not os.path.isdir("preds"):
41 | os.makedirs("preds")
42 | with open(f"preds/{out_name}.json", "w+") as out:
43 |     for meta_item, source, target in zip(meta, source_texts, preds):  # meta_item avoids shadowing the meta list
44 |         meta_item["source"] = source
45 |         meta_item["target"] = target
46 |         out.write(json.dumps(meta_item) + "\n")
47 |
48 | print("done predicting")
49 |
--------------------------------------------------------------------------------
/covid_event/prepare.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import os
4 |
5 | from tqdm import tqdm
6 |
7 | eventname2question = {
8 | "positive": "Does this tweet report an individual or a small group of people who is tested postive for coronavirus?",
9 | "negative": "Does this tweet report an individual or a small group of people who is tested negative for coronavirus?",
10 | "cure_and_prevention": "Does this tweet report cure and prevention for coronavirus?",
11 | "can_not_test": "Does this tweet report an individual or a small group of people who can not be tested for coronavirus?",
12 | "death": "Does this tweet report dealth for coronavirus?",
13 | }
14 |
15 | short2question_positive = {"name": "Who is tested positive?",
16 | "close_contact": "Who is in close contact with the person tested positive?",
17 | "employer": "Who is the employer of the people tested positive?",
18 | "recent_travel": "Where did the people tested positive recently visit?",
19 | "relation": "Does the infected person have a relationship with the author of the tweet?",
20 | "gender": "What is the gender of the people tested positive?",
21 | "age": "What is the age of the people tested positive?",
22 | "when": "When are tested positive cases reported?",
23 | "where": "Where are tested positive cases reported?", }
24 |
25 | short2choices_positive = {"name": ["chunks", "author of the tweet", "not specified"],
26 | "close_contact": ["chunks", "author of the tweet", "not specified"],
27 | "employer": ["chunks", "not specified"],
28 | "recent_travel": ["chunks", "near author of the tweet", "not specified"],
29 | "relation": ["yes", "no", "not specified"],
30 | "gender": ["male", "female", "not specified"],
31 | "age": ["chunks", "not specified"],
32 | "when": ["chunks", "not specified"],
33 | "where": ["chunks", "near author of the tweet", "not specified"], }
34 |
35 | short2question_negative = {"name": "Who is tested negative?",
36 | "close_contact": "Who is in close contact with the person tested negative?",
37 | "relation": "Does the infected person have a relationship with the author of the tweet?",
38 | "gender": "What is the gender of the people tested negative?",
39 | "age": "What is the age of the people tested negative?",
40 | "when": "When are tested negative cases reported?",
41 | "where": "Where are tested negative cases reported?",
42 | "how_long": "How long does it take to get to know the test results?"}
43 |
44 | short2choices_negative = {"name": ["chunks", "author of the tweet", "not specified"],
45 | "close_contact": ["chunks", "author of the tweet", "not specified"],
46 | "relation": ["yes", "no", "not specified"],
47 | "gender": ["male", "female", "not specified"],
48 | "age": ["chunks", "not specified"],
49 | "when": ["chunks", "not specified"],
50 | "where": ["chunks", "near author of the tweet", "not specified"],
51 | "how_long": ["chunks", "not specified"], }
52 |
53 | short2question_can_not_test = {"name": "Who can not get a test?",
54 | "symptoms": "Is the untested person currently experiencing any COVID-19 related symptoms?",
55 | "relation": "Does the untested person have a relationship with the author of the tweet?",
56 | "when": "When is the can’t-be-tested situation reported?",
57 | "where": "Where is the can’t-be-tested situation reported?", }
58 |
59 | short2choices_can_not_test = {"name": ["chunks", "author of the tweet", "not specified"],
60 | "symptoms": ["yes", "no", "not specified"],
61 | "relation": ["yes", "no", "not specified"],
62 | "when": ["chunks", "not specified"],
63 | "where": ["chunks", "near author of the tweet", "not specified"], }
64 |
65 | short2question_death = {"name": "Who is dead for coronavirus?",
66 | "symptoms": "Did the person who was dead experience COVID-19 related symptoms?",
67 | "relation": "Does the deceased person have a relationship with the author of the tweet?",
68 | "when": "When is the dead case reported?",
69 | "where": "Where is the dead case reported?",
70 | "age": "What is the age of the people who is dead of COVID-19?", }
71 |
72 | short2choices_death = {"name": ["chunks", "author of the tweet", "not specified"],
73 | "symptoms": ["yes", "no", "not specified"],
74 | "relation": ["yes", "no", "not specified"],
75 | "when": ["chunks", "not specified"],
76 | "where": ["chunks", "near author of the tweet", "not specified"],
77 | "age": ["chunks", "not specified"], }
78 |
79 | short2question_cure_and_prevention = {"opinion": "Does the author of tweet believe the cure method is effective?",
80 | "what_cure": "What is the cure for coronavirus mentioned by the author of the tweet?",
81 | "who_cure": "Who is promoting the cure for coronavirus?", }
82 |
83 | short2choices_cure_and_prevention = {"opinion": ["not_effective", "effective"],
84 | "what_cure": ["chunks", "not specified"],
85 | "who_cure": ["chunks", "author of the tweet", "not specified"], }
86 |
87 | eventname2questionmapping = {"positive": short2question_positive,
88 | "negative": short2question_negative,
89 | "can_not_test": short2question_can_not_test,
90 | "death": short2question_dealth,
91 | "cure_and_prevention": short2question_cure_and_prevention,
92 | }
93 |
94 | eventname2choicesmapping = {"positive": short2choices_positive,
95 | "negative": short2choices_negative,
96 | "can_not_test": short2choices_can_not_test,
97 | "death": short2choices_death,
98 | "cure_and_prevention": short2choices_cure_and_prevention, }
99 |
100 | event_type2part2annotation_keys = {
101 | "can_not_test": [
102 | "part2-relation.Response",
103 | "part2-symptoms.Response",
104 | "part2-name.Response",
105 | "part2-when.Response",
106 | "part2-where.Response"
107 | ],
108 | "cure_and_prevention": [
109 | "part2-opinion.Response",
110 | "part2-what_cure.Response",
111 | "part2-who_cure.Response"
112 | ],
113 | "negative": [
114 | "part2-age.Response",
115 | "part2-close_contact.Response",
116 | "part2-gender.Response",
117 | "part2-how_long.Response",
118 | "part2-name.Response",
119 | "part2-relation.Response",
120 | "part2-when.Response",
121 | "part2-where.Response"
122 | ],
123 | "death": [
124 | "part2-age.Response",
125 | "part2-name.Response",
126 | "part2-relation.Response",
127 | "part2-symptoms.Response",
128 | "part2-when.Response",
129 | "part2-where.Response"
130 | ],
131 | "positive": [
132 | "part2-age.Response",
133 | "part2-close_contact.Response",
134 | "part2-employer.Response",
135 | "part2-gender.Response",
136 | "part2-name.Response",
137 | "part2-recent_travel.Response",
138 | "part2-relation.Response",
139 | "part2-when.Response",
140 | "part2-where.Response"
141 | ]
142 | }
143 |
144 | def get_part1_new_example(example):
145 | event_type = example["event_type"]
146 | new_example = {}
147 | new_example["id"] = example["id_str"] # id
148 | new_example["event_type"] = event_type # event_type
149 | new_example["slot_type"] = "part1.Response" # slot_type
150 | new_example["context"] = example["full_text"] # context
151 | new_example["question"] = eventname2question[event_type] # question
152 | new_example["choices"] = "yes, no" # candidate choices
153 | new_example["answer"] = example["annotation"]["part1.Response"][0]
154 | return new_example
155 |
156 | def get_text_chunks(example):
157 | full_text = example["full_text"] # context
158 | candidate_chunks_offsets = example["candidate_chunks_offsets"]
159 | text_chunks = [full_text[each[0]:each[1]] for each in candidate_chunks_offsets]
160 | return text_chunks
161 |
162 |
163 | def get_total(filepath):
164 | total = 0
165 | with open(filepath, "r") as f:
166 | for line in f:
167 | total += 1
168 | return total
169 |
170 |
171 | def build_test(filepath, event_type="can_not_test"):
172 | new_examples = []
173 | with open(filepath, "r") as f:
174 | for line in tqdm(f):
175 | example = json.loads(line.strip())
176 | text_chunks = get_text_chunks(example)
177 | part2annotation_keys = event_type2part2annotation_keys[event_type]
178 | for each_part2_annotation_key in part2annotation_keys:
179 | slot_key = each_part2_annotation_key.split(".")[0].split("-")[-1]
180 | slot_question = eventname2questionmapping[event_type][slot_key]
181 | slot_candidate_choices = copy.deepcopy(eventname2choicesmapping[event_type][slot_key])
182 | if "chunks" in slot_candidate_choices:
183 | slot_candidate_choices.remove("chunks")
184 | slot_candidate_choices.extend(text_chunks)
185 | new_example = {}
186 | full_text = example["text"]
187 | new_example["id"] = example["id"] # id
188 | new_example["event_type"] = event_type # event_type
189 | new_example["slot_type"] = each_part2_annotation_key # slot_type
190 | new_example["context"] = full_text # context
191 | new_example["question"] = slot_question # question
192 | new_example["candidates"] = slot_candidate_choices # candidate choices
193 | new_example["choices"] = ", ".join(slot_candidate_choices) # candidate choices
194 | new_examples.append(new_example)
195 | return new_examples
196 |
197 |
198 | def build_train_val(filepath, test_size=0.1, out_folder="data/middle"):
199 | total = get_total(filepath)
200 | test_no = int(total * test_size)
201 | # new_examples = []
202 | train, val = [], []
203 | NO_CONSENSUS_count = 0
204 | with open(filepath, "r") as f:
205 | for line in f:
206 | if total <= test_no:  # route the last test_no examples to the val split
207 | new_examples = val
208 | else:
209 | new_examples = train
210 | example = json.loads(line.strip())
211 | if example["annotation"]["part1.Response"][0] == "yes":
212 | new_example = get_part1_new_example(example)
213 | new_examples.append(new_example)
214 | text_chunks = get_text_chunks(example)
215 | event_type = example["event_type"]
216 | annotation = example["annotation"]
217 | annotation.pop("part1.Response")
218 | for each_part2_annotation_key, value in annotation.items():
219 | if value == "NO_CONSENSUS":
220 | NO_CONSENSUS_count += 1
221 | continue
222 | slot_key = each_part2_annotation_key.split(".")[0].split("-")[-1]
223 | slot_question = eventname2questionmapping[event_type][slot_key]
224 | slot_candidate_choices = copy.deepcopy(eventname2choicesmapping[event_type][slot_key])
225 | if "chunks" in slot_candidate_choices:
226 | slot_candidate_choices.remove("chunks")
227 | slot_candidate_choices.extend(text_chunks)
228 | new_example = {}
229 | full_text = example["full_text"]
230 | new_example["id"] = example["id_str"] # id
231 | new_example["event_type"] = event_type # event_type
232 | new_example["slot_type"] = each_part2_annotation_key # slot_type
233 | new_example["context"] = full_text # context
234 | new_example["question"] = slot_question # question
235 | new_example["choices"] = ", ".join(slot_candidate_choices) # candidate choices
236 | answers = []
237 | for v in value:
238 | if isinstance(v, str):
239 | if v.lower() in ["no_cure", "not_effective", "no_opinion"]:  # merge these labels into not_effective
240 | answers.append("not_effective")
241 | else:
242 | answers.append(v.lower())
243 | else:
244 | answers.append(full_text[v[0]:v[1]])
245 | assert set(answers).intersection(set(slot_candidate_choices)) != set()
246 | new_example["answer"] = ", ".join(answers) # answer
247 | new_examples.append(new_example)
248 | else:
249 | new_example = get_part1_new_example(example)
250 | new_examples.append(new_example)
251 | total -= 1
252 | print(f"there are {NO_CONSENSUS_count} NO_CONSENSUS")
253 | if not os.path.isdir(out_folder):
254 | os.makedirs(out_folder, exist_ok=True)
255 | with open(out_folder + "/train.json", "w+") as f:
256 | for ex in train:
257 | f.write(json.dumps(ex) + "\n")
258 | with open(out_folder + "/val.json", "w+") as f:
259 | for ex in val:
260 | f.write(json.dumps(ex) + "\n")
261 | print(f"done building train and val from {filepath}, written to {out_folder}")
262 |
263 |
264 | def construct(data_path, use_choices=True, no_answer=False, out_folder="data/final"):
265 | if not os.path.isdir(out_folder):
266 | os.makedirs(out_folder, exist_ok=True)
267 |
268 | with(open(out_folder + "/" + data_path.split("/")[-1], "w+")) as tgt:
269 | with open(data_path, "r") as f:
270 | for line in tqdm(f, desc="reading..."):
271 | example = json.loads(line.strip())
272 | source_sequence = "context: " + example["context"] + " question: " + example["question"]
273 | if use_choices:
274 | source_sequence += " choices: " + example["choices"]
275 | if not no_answer:
276 | target_sequence = example["answer"]
277 | tgt.write(json.dumps(
278 | {"id": example["id"], "event_type": example["event_type"], "slot_type": example["slot_type"],
279 | "source": source_sequence, "target": target_sequence}) + "\n")
280 | else:
281 | tgt.write(json.dumps(
282 | {"id": example["id"], "event_type": example["event_type"], "slot_type": example["slot_type"],
283 | "source": source_sequence, "candidates": example["candidates"]}) + "\n")
284 | print("done source and target sequences construction")
285 |
286 | if __name__ == '__main__':
287 | '''
288 | sequence construction examples:
289 | X1: context: @CPRewritten I heard from @Corona_Bot__ that G is tested positive of COVID-19. question: Who is tested positive? choices: Not Specified, A, B, C.
290 | y1: A
291 | X2: context: @CPRewritten I heard from @Corona_Bot__ that G is tested positive of COVID-19. question: Does this message report an individual or a small group of people who is tested positive for coronavirus. choices: yes, no.
292 | y2: yes
293 | '''
294 | build_train_val("data/corpus.json", test_size=0.1, out_folder="data/middle")
295 | construct("data/middle/train.json", out_folder="data/final")
296 | construct("data/middle/val.json", out_folder="data/final")
297 |
298 | ### when test set is available, uncomment the following
299 | # can_not_test_examples = build_test("shared_task-test-can_not_test.jsonl", event_type="can_not_test")
300 | # cure_and_prevention_examples = build_test("shared_task-test-cure.jsonl", event_type="cure_and_prevention")
301 | # death_examples = build_test("shared_task-test-death.jsonl", event_type="death")
302 | # negative_examples = build_test("shared_task-test-negative.jsonl", event_type="negative")
303 | # positive_examples = build_test("shared_task-test-positive.jsonl", event_type="positive")
304 | # all = can_not_test_examples + cure_and_prevention_examples + death_examples + negative_examples + positive_examples
305 | # with open("../test.json", "w+") as f:
306 | # for ex in all:
307 | # f.write(json.dumps(ex) + "\n")
308 | # construct("ori/test.json", no_answer=True)
309 |
--------------------------------------------------------------------------------
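The key move in the construction above is how the "chunks" placeholder expands into per-tweet candidate choices before everything is flattened into a single t2t source sequence. A minimal sketch (illustrative only; the tweet text, offsets and question below are made up), mirroring build_test()/build_train_val() and construct():

import copy

full_text = "I heard from @Corona_Bot__ that G is tested positive of COVID-19."
candidate_chunks_offsets = [[13, 26], [32, 33]]  # hypothetical spans for "@Corona_Bot__" and "G"
text_chunks = [full_text[s:e] for s, e in candidate_chunks_offsets]

# as in build_test()/build_train_val(): "chunks" is replaced by the actual text spans
choices = copy.deepcopy(["chunks", "author of the tweet", "not specified"])
choices.remove("chunks")
choices.extend(text_chunks)

# as in construct(): flatten everything into a single t2t source sequence
source = "context: " + full_text + " question: " + "Who is tested positive?" + " choices: " + ", ".join(choices)
print(source)
# -> context: ... question: Who is tested positive? choices: author of the tweet, not specified, @Corona_Bot__, G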
/covid_event/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/covid_event/results.png
--------------------------------------------------------------------------------
/covid_event/submit.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import json, os
3 |
4 | def process_target(event_type, slot_type, target):
5 | # For opinion slot in cure & prevention category, you shall only have two labels: "effective" and "not_effective" ("no_opinion", "no_cure" and "not_effective" will be merged into "not_effective")
6 | if event_type == "cure_and_prevention" and "opinion" in slot_type:
7 | if "no_opinion" in target:
8 | target.remove("no_opinion")
9 | target.append("not_effective")
10 | if "no_cure" in target:
11 | target.remove("no_cure")
12 | target.append("not_effective")
13 | # "I" (along with its variations) shall be replaced with "AUTHOR OF THE TWEET" both for NAME and CLOSE_CONTACT slot for all categories.
14 | if "name" in slot_type or "close_contact" in slot_type:
15 | if "i" in target:
16 | # It doesn't matter if your predictions are lowercased or uppercased.
17 | target.remove("i")
18 | target.append("author of the tweet")
19 | if "i'm" in target:
20 | target.remove("i'm")
21 | target.append("author of the tweet")
22 | # For relation slot and symptoms slot, you shall only have two labels: "yes" and "not specified" ("no" and "not specified" will be merged into "not specified").
23 | if "relation" in slot_type or "symptoms" in slot_type:
24 | if "no" in target:
25 | target.remove("no")
26 | target.append("not specified")
27 | return list(set(target)) # do not submit repeated answers
28 |
29 | def convert(data_path, output_path="t5_sub", annotation_key="predicted_annotation"):
30 | with open(data_path, "r") as f:
31 | event2converted_examples = {}
32 | converted_examples = []
33 | converted_one_example = {'id': '', annotation_key: {}}
34 | last_event = ""
35 | last_id = ""
36 | switch = False
37 |
38 | for line in tqdm(f):
39 | example = json.loads(line.strip())
40 | slot_type = example["slot_type"]
41 | if "part2" in slot_type:
42 | event_type = example["event_type"]
43 | if event_type != last_event:
44 | if last_event != "":
45 | switch = False
46 | converted_examples.append(converted_one_example)
47 | event2converted_examples[last_event] = converted_examples
48 | converted_examples = []
49 |
50 | id = example["id"]
51 | target = example["target"].split(", ")
52 |
53 | if id != last_id:
54 | if last_id != "" and switch:
55 | converted_examples.append(converted_one_example)
56 |
57 | # Death: "symptoms" slot will be excluded
58 | # Tested Negative: "how long" slots will be excluded
59 | if (event_type == "death" and "symptoms" in slot_type) or (
60 | event_type == "negative" and "how_long" in slot_type):
61 | continue
62 |
63 | target = process_target(event_type, slot_type, target)
64 |
65 | converted_one_example = {'id': id,
66 | annotation_key: {slot_type: target}}
67 |
68 | else:
69 | switch = True
70 | # Death: "symptoms" slot will be excluded
71 | # Tested Negative: "how long" slots will be excluded
72 | if (event_type == "death" and "symptoms" in slot_type) or (
73 | event_type == "negative" and "how_long" in slot_type):
74 | continue
75 | target = process_target(event_type, slot_type, target)
76 | converted_one_example[annotation_key][slot_type] = target
77 |
78 | last_id = id
79 | last_event = event_type
80 | converted_examples.append(converted_one_example)
81 | event2converted_examples[last_event] = converted_examples
82 |
83 | if not os.path.isdir(output_path):
84 | os.makedirs(output_path, exist_ok=True)
85 |
86 | for event, examples in event2converted_examples.items():
87 | if "cure" in event:
88 | event = "cure"
89 | output_file_path = os.path.join(output_path, event + ".jsonl")
90 | with open(output_file_path, "w+") as f:
91 | for example in examples:
92 | f.write(json.dumps(example) + "\n")
93 | print(f'done writing to {output_file_path}')
94 | return event2converted_examples
95 |
96 | if __name__ == '__main__':
97 | convert("preds/val_preds.json", output_path="subs/val-run-1/")
98 |
--------------------------------------------------------------------------------
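A minimal usage sketch of process_target() above, with toy inputs, showing the label merges the submission format requires:

raw = ["no_opinion", "effective"]
print(process_target("cure_and_prevention", "part2-opinion.Response", raw))
# -> ['effective', 'not_effective'] (returned deduplicated, order may vary)

raw = ["i", "my sister"]
print(process_target("positive", "part2-name.Response", raw))
# -> ['author of the tweet', 'my sister'] (order may vary)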
/covid_event/subs/post_process.py:
--------------------------------------------------------------------------------
1 | import os, json, pickle
2 |
3 | def read_groundtruth(gt_dir="subs/golden"):
4 | event_types = ['positive', 'negative', 'can_not_test', 'death', 'cure']
5 | output = {}
6 | for each_event in event_types:
7 | with open(os.path.join(gt_dir, each_event + "_sol.jsonl"), "r") as f:
8 | for line in f:
9 | ins = json.loads(line)
10 | if ins["id"] not in output:
11 | output[ins["id"]] = {}
12 | for key, slot_gt_list in ins["golden_annotation"].items():
13 | output[ins["id"]][key] = [e.lower() for e in slot_gt_list]
14 | return output
15 |
16 | grounds = read_groundtruth()
17 |
18 | def levenshtein(s1, s2):
19 | '''
20 | from: https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance
21 | :param s1: first string
22 | :param s2: second string
23 | :return:
24 | '''
25 | if len(s1) < len(s2):
26 | return levenshtein(s2, s1)
27 | # len(s1) >= len(s2)
28 | if len(s2) == 0:
29 | return len(s1)
30 |
31 | previous_row = range(len(s2) + 1)
32 | for i, c1 in enumerate(s1):
33 | current_row = [i + 1]
34 | for j, c2 in enumerate(s2):
35 | insertions = previous_row[
36 | j + 1] + 1 # j+1 instead of j since previous_row and current_row are one character longer
37 | deletions = current_row[j] + 1 # than s2
38 | substitutions = previous_row[j] + (c1 != c2)
39 | current_row.append(min(insertions, deletions, substitutions))
40 | previous_row = current_row
41 | return previous_row[-1]
42 |
43 |
44 | def convert_to_candidates(id, slot_type, target, id2candidates):
45 | candidates = id2candidates[id][slot_type]
46 | intersection = set(candidates).intersection(set(target))
47 | if len(intersection) == 0:
48 | transformed_targets = []
49 | for t in [", ".join(target)]:  # compare the joined prediction string against each candidate
50 | min_distance = float("inf")
51 | min_index = 0
52 | for idx, c in enumerate(candidates):
53 | distance = levenshtein(c, t) / len(c)
54 | if distance < min_distance:
55 | min_distance = distance
56 | min_index = idx
57 | transformed_targets.append(candidates[min_index])
58 | else:
59 | transformed_targets = target
60 | return transformed_targets
61 |
62 |
63 | def readJsonl(path):
64 | output = {}
65 | with open(path, 'r') as f:
66 | for line in f:
67 | ins = json.loads(line)
68 | output[ins["id"]] = ins["predicted_annotation"]
69 | return output
70 |
71 |
72 | def read_id2text(data_file="data/middle/test.json"):
73 | output = {}
74 | with open(data_file, 'r') as f:
75 | for line in f:
76 | ins = json.loads(line)
77 | if ins["id"] not in output:
78 | output[ins["id"]] = {}
79 | output[ins["id"]][ins["slot_type"]] = ins["context"].lower()
80 | return output
81 |
82 |
83 | def transform():
84 | dir = "subs"
85 | id2candidates = pickle.load(open(os.path.join(dir, "testid2candidates.pkl"), "rb"))
86 | # the testid2candidates.pkl can also be obtained by calling read_id2text()
87 | # read the README file on how to create data/middle/test.json (it is unlabeled before the annotation release)
88 | from_runs = ["run-1", "run-2", "run-3"]
89 | to_runs = ["post-run-1", "post-run-2", "post-run-3"]
90 | event_types = ['positive', 'negative', 'can_not_test', 'death', 'cure']
91 | for from_run, to_run in zip(from_runs, to_runs):
92 | for each_event in event_types:
93 | if not os.path.isdir(os.path.join(dir, to_run)):
94 | os.makedirs(os.path.join(dir, to_run))
95 | with open(os.path.join(dir, to_run, each_event + ".jsonl"), "w+") as out:
96 | preds = readJsonl(os.path.join(dir, from_run, each_event + '.jsonl'))
97 | for id, slot_example_preds in preds.items():
98 | transformed_pred_dict = {}
99 | for slot_key, slot_pred in slot_example_preds.items():
100 | transformed_target = convert_to_candidates(id, slot_key, slot_pred, id2candidates)
101 | transformed_pred_dict[slot_key] = transformed_target
102 | out.write(json.dumps({"id": id, "predicted_annotation": transformed_pred_dict}) + "\n")
103 | print("done")
104 |
105 |
106 | if __name__ == '__main__':
107 | transform()
108 |
--------------------------------------------------------------------------------
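A minimal sketch of convert_to_candidates() above with made-up ids and candidates: when a generated answer is not an exact candidate, it is snapped to the candidate with the smallest length-normalised Levenshtein distance.

toy_id2candidates = {"tweet-1": {"part2-name.Response": ["author of the tweet", "not specified", "my sister"]}}
pred = ["my sisterr"]  # a slightly corrupted generation
print(convert_to_candidates("tweet-1", "part2-name.Response", pred, toy_id2candidates))
# -> ['my sister'] (normalised distance 1/9 beats both other candidates)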
/covid_event/subs/testid2candidates.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/covid_event/subs/testid2candidates.pkl
--------------------------------------------------------------------------------
/example_bert.py:
--------------------------------------------------------------------------------
1 | from ttt import *
2 |
3 | if __name__ == '__main__':
4 | args = get_args()
5 | # check what args are available
6 | logger.info(f"args: {json.dumps(args.__dict__, indent=2)}")
7 | ############### customize args
8 | args.use_gpu = True
9 | # args.use_tpu = True
10 | # args.tpu_address = "x.x.x.x" # replace with yours
11 | args.do_train = True
12 | args.use_tb = True
13 | # any one from MODELS_SUPPORT (check:ttt/args.py)
14 | args.model_select = "bert-base-uncased"
15 | # select a dataset following jsonl format, where the text field name is "text" and the label field name is "label"
16 | args.data_path = "data/glue/sst2"
17 | # any one from TASKS_SUPPORT (check:ttt/args.py)
18 | args.task = "single-label-cls"
19 | args.log_steps = 1000
20 | # any one from LR_SCHEDULER_SUPPORT (check:ttt/args.py)
21 | args.scheduler="warmuplinear"
22 | # set do_eval = False if your data does not contain a validation set. In that case, patience, and early_stop will be invalid
23 | args.do_eval = True
24 | ############### end customize args
25 | # to have a sanity check for the args
26 | sanity_check(args)
27 | # seed everything, make deterministic
28 | set_seed(args.seed)
29 | tokenizer = get_tokenizer(args)
30 | inputs = get_inputs(tokenizer, args)
31 | model, _ = create_model(args, logger, get_model)
32 | # start training; here we use the Keras high-level API
33 | training_history = model.fit(
34 | inputs["x_train"],
35 | inputs["y_train"],
36 | epochs=args.num_epochs_train,
37 | verbose=2,
38 | batch_size=args.per_device_train_batch_size*args.num_replicas_in_sync,
39 | callbacks=get_callbacks(args, inputs, logger, get_evaluator),
40 | )
--------------------------------------------------------------------------------
/example_t5.py:
--------------------------------------------------------------------------------
1 | from ttt import *
2 | import transformers
3 | transformers.logging.set_verbosity_info()
4 | logger = transformers.logging.get_logger()
5 |
6 | if __name__ == '__main__':
7 | args = get_args()
8 | # check what args are available and their default values
9 | logger.info(f"args: {json.dumps(args.__dict__, indent=2)}")
10 | ############### customize args
11 | args.use_gpu = True
12 | # args.use_tpu = True
13 | # args.tpu_address = "x.x.x.x"
14 | args.do_train = True
15 | args.use_tb = True
16 | # any one from MODELS_SUPPORT (check:ttt/args.py)
17 | args.model_select = "t5-small"
18 | # select a dataset. It is first checked whether it can be loaded from the nlp library; if yes, it is loaded and saved locally to the data_path
19 | # or customize your own data in the data_path (train.json, val.json, test.json) where examples are organised in jsonl format
20 | # each line representing an example like this: {"text": "...", "label": "..."}
21 | args.data_path = "data/glue/sst2"
22 | # any one from TASKS_SUPPORT (check:ttt/args.py)
23 | args.task = "t2t"
24 | args.log_steps = 400
25 | args.eval_batch_size=32
26 | args.per_device_train_batch_size=8
27 | args.max_src_length=128
28 | args.load_train_num = 1000
29 | # any one from LR_SCHEDULER_SUPPORT (check:ttt/args.py)
30 | args.scheduler = "warmuplinear"
31 | args.lr=5e-5
32 | # set do_eval = False if your data does not contain a validation set. In that case, patience, and early_stop will be invalid
33 | args.do_eval = True
34 | ############### end customize args
35 | # to have a sanity check for the args
36 | sanity_check(args,logger=logger)
37 | # seed everything, make deterministic
38 | # set_seed(args.seed) let's do this in trainer before start training
39 | tokenizer = get_tokenizer(args)
40 | inputs = get_inputs(tokenizer, args)
41 | model, strategy = create_model(args, logger, get_model)
42 | # start training, here we customize T2TTrainer to get more control and flexibility
43 | trainer = T2TTrainer(args, logger)
44 | trainer.train(model, strategy, tokenizer, inputs)
45 |
--------------------------------------------------------------------------------
/example_trans_t5.py:
--------------------------------------------------------------------------------
1 | from ttt import *
2 |
3 | if __name__ == '__main__':
4 | args = get_args()
5 | # check what args are available and their default values
6 | logger.info(f"args: {json.dumps(args.__dict__, indent=2)}")
7 | ############### customize args
8 | args.use_gpu = True
9 | # args.tpu_address = "x.x.x.x"
10 | # args.use_tpu = True
11 | args.do_train = True
12 | args.use_tb = True
13 | # any one from MODELS_SUPPORT (check:ttt/args.py)
14 | args.model_select = "t5-large"
15 | # the path to the translation dataset; each line is an example in jsonl format like: {"source": "...", "target": "..."}
16 | args.data_path = "data/wmt_en_ro"
17 | # any one from TASKS_SUPPORT (check:ttt/args.py)
18 | args.task = "translation"
19 | args.max_src_length=128
20 | args.max_tgt_length=128
21 | args.eval_batch_size = 8
22 | args.log_steps = 1000
23 | args.source_field_name = "source"
24 | args.target_field_name = "target"
25 | args.per_device_train_batch_size = 2
26 | args.eval_on="bleu" #this refers to sacrebleu as used in T5 paper
27 | # any one from LR_SCHEDULER_SUPPORT (check:ttt/args.py)
28 | args.scheduler = "warmuplinear"
29 | # set do_eval = False if your data does not contain a validation set. In that case, patience, and early_stop will be invalid
30 | args.do_eval = True
31 | ############### end customize args
32 | # to have a sanity check for the args
33 | sanity_check(args)
34 | # seed everything, make deterministic
35 | set_seed(args.seed)
36 | tokenizer = get_tokenizer(args)
37 | inputs = get_inputs(tokenizer, args)
38 | model, strategy = create_model(args, logger, get_model)
39 | # start training, here we customize T2TTrainer to get more control and flexibility
40 | trainer = T2TTrainer(args)
41 | trainer.train(model, strategy, tokenizer, inputs)
42 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 |
2 | from ttt import *
3 | from sklearn.metrics import classification_report, accuracy_score
4 | import math
5 |
6 | logging.basicConfig(
7 | format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s',
8 | datefmt='%Y-%m-%d %H:%M:%S',
9 | level=logging.INFO
10 | )
11 | logger = logging.getLogger(__name__)
12 |
13 | if __name__ == '__main__':
14 | # get args
15 | args = get_args()
16 | # check what args are available
17 | logger.info(f"args: {json.dumps(args.__dict__, indent=2)}")
18 | # args.do_train = True
19 | # args.do_eval = True
20 | # args.do_test = True
21 |
22 | # args.use_tpu = False
23 | # args.use_gpu = True
24 | sanity_check(args)
25 |
26 | if args.do_train:
27 | add_filehandler_for_logger(args.output_path, logger)
28 | logger.info(f"tf.version.VERSION: {tf.version.VERSION}")
29 | logger.info(f"set seed for random, numpy and tensorflow: seed = {args.seed}")
30 | set_seed(args.seed)
31 | tokenizer = get_tokenizer(args)
32 | inputs = get_inputs(tokenizer, args)
33 | model, strategy = create_model(args, logger, get_model)
34 | if not args.use_tpu: # used during development
35 | model.run_eagerly = True  # handy for debugging; TF supports eager-mode debugging much like PyTorch
36 | # start training here
37 | if args.task == "t2t":
38 | # t5-like model is customized because we want more flexibility and control over the training loop
39 | # customize_fit(model, strategy, tokenizer, inputs, args)
40 | trainer = T2TTrainer(args)
41 | trainer.train(model, strategy, tokenizer, inputs)
42 | else:
43 | training_history = model.fit(
44 | inputs["x_train"],
45 | inputs["y_train"],
46 | epochs=args.num_epochs_train,
47 | verbose=2,
48 | batch_size=args.per_device_train_batch_size*args.num_replicas_in_sync,
49 | callbacks=get_callbacks(args, inputs, logger, get_evaluator),
50 | )
51 |
52 | if args.do_test:
53 | add_filehandler_for_logger(args.output_path, logger, out_name="test")
54 | assert os.path.isdir(args.output_path)
55 | assert os.path.isfile(os.path.join(args.output_path, "args.json"))
56 | args = Args(**json.load(open(os.path.join(args.output_path, "args.json"))))
57 | model, strategy = create_model(args,logger, get_model)
58 | # TODO: make this branch also work when task == "single-label-cls"
59 | ck_path = glob.glob(os.path.join(args.output_path, "best*.h5"))[0]
60 | if args.ck_index_select < 0 and -args.ck_index_select <= args.keep_ck_num:
61 | cks_path_already = glob.glob(os.path.join(args.output_path, "*.h5"))
62 | index2path = {int(os.path.basename(each_ck_path).split(".")[0].split("_")[-1]): each_ck_path for
63 | each_ck_path in cks_path_already}
64 | sorted_indices = sorted(index2path)
65 | ck_path = index2path[sorted_indices[args.ck_index_select]]
66 |
67 | logger.info(f"evaluating using weights from checkpoint: {ck_path}")
68 | model.load_weights(ck_path)
69 |
70 | tokenizer = AutoTokenizer.from_pretrained(args.output_path)
71 |
72 | logger.info("********************start evaluating on test set********************")
73 | if args.task == "single-label-cls":
74 | # tokenizer = get_tokenizer(args)
75 | test_texts, encoded_test, y_test = convert_seq_single_cls_examples(
76 | os.path.join(args.data_path, 'test.json'),
77 | tokenizer, args.input_seq_length,
78 | args.label2id)
79 |
80 | x_test = [encoded_test["input_ids"], encoded_test["token_type_ids"], encoded_test["attention_mask"]]
81 |
82 | if not args.use_tpu: # used during development
83 | model.run_eagerly = True  # handy for debugging; TF supports eager-mode debugging much like PyTorch
84 |
85 | pred_probs = model.predict(x_test, batch_size=args.eval_batch_size*strategy.num_replicas_in_sync, steps=math.ceil(len(y_test) / (args.eval_batch_size*strategy.num_replicas_in_sync)), verbose=1)
86 | preds = tf.math.argmax(pred_probs, 1).numpy()
87 |
88 | acc = accuracy_score(y_test, preds)
89 |
90 | target_names = [''] * len(args.label2id)
91 | for label, id in args.label2id.items():
92 | target_names[id] = label
93 |
94 | logger.info(
95 | f"test_eval_report: {classification_report(y_test, preds, digits=4, target_names=target_names)}")
96 | logger.info(f"test_eval_acc: {acc}")
97 |
98 | elif args.task == "t2t":
99 | source_texts, encoded_source, encoded_target = convert_t2t_examples(
100 | os.path.join(args.data_path, 'test.json'), tokenizer, args)
101 | source_input_ids = encoded_source["input_ids"]
102 | test_dataset = tf.data.Dataset.from_tensor_slices(
103 | (source_input_ids, encoded_source["attention_mask"], encoded_target["input_ids"]))
104 | test_dataset = test_dataset.batch(args.eval_batch_size*strategy.num_replicas_in_sync)
105 | iter_num = math.ceil(len(source_input_ids) / (args.eval_batch_size*strategy.num_replicas_in_sync))
106 | preds = []
107 | gts = []
108 |
109 | for input_ids, attention_mask, gt in tqdm(test_dataset, total=iter_num, desc="testing..."):
110 | predicts = model.generate(input_ids=input_ids,
111 | attention_mask=attention_mask,
112 | max_length=args.max_tgt_length)
113 | preds.extend([tokenizer.decode(ids) for ids in predicts])
114 | gts.extend([tokenizer.decode(ids) for ids in gt])
115 |
116 | acc = accuracy_score(gts, preds)
117 | logger.info(
118 | f"test_eval_report: {classification_report(gts, preds, digits=4)}")
119 | logger.info(f"test_eval_acc: {acc}")
120 | # tensorboard dev upload --logdir runs
--------------------------------------------------------------------------------
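Note that tokenizer.decode(ids) in the t2t test branch above keeps special tokens such as <pad> and </s> in the decoded strings, which can depress the exact-match accuracy. A minimal alternative sketch (an assumption, not what the script currently does):

# decode generated ids while dropping special tokens before string comparison
preds.extend([tokenizer.decode(ids, skip_special_tokens=True).strip() for ids in predicts])
gts.extend([tokenizer.decode(ids, skip_special_tokens=True).strip() for ids in gt])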
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", mode="r", encoding="utf-8") as readme_file:
4 | readme = readme_file.read()
5 |
6 | setup(
7 | name="pytriplet", # named pytriplet in pypi to avoid repeated name
8 | version="0.0.5",
9 | author="Congcong Wang",
10 | author_email="wangcongcongcc@gmail.com",
11 | description="Fine-tuning Transformers with TPUs or GPUs acceleration, written in Tensorflow2.0",
12 | long_description=readme,
13 | long_description_content_type="text/markdown",
14 | license="MIT License",
15 | url="https://github.com/wangcongcong123/ttt",
16 | download_url="https://github.com/wangcongcong123/ttt/releases/download/v0.0.3/pytriplet.tar.gz",
17 | packages=find_packages(),
18 | install_requires=[
19 | "tensorflow==2.3.0",
20 | "sklearn",
21 | "tqdm",
22 | "keras",
23 | "tensorboardX",
24 | "nlp",
25 | "sacrebleu",
26 | "datasets",
27 | "transformers"
28 | ],
29 | classifiers=[
30 | "Development Status :: 4 - Beta",
31 | "Intended Audience :: Science/Research",
32 | "License :: OSI Approved :: Apache Software License",
33 | "Programming Language :: Python :: 3.7",
34 | "Topic :: Scientific/Engineering :: Artificial Intelligence"
35 | ],
36 | keywords="Transformers, Tensorflow, TPUs acceleration"
37 | )
38 | # commands for uploading to pypi
39 | # python setup.py sdist
40 | # pip install twine
41 | # twine upload dist/*
42 |
--------------------------------------------------------------------------------
/ttt/__init__.py:
--------------------------------------------------------------------------------
1 | from ttt.inputs import *
2 | from ttt.utils import *
3 | from ttt.args import *
4 | from ttt.models import get_model
5 | from ttt.t2t_trainer import T2TTrainer
6 | from ttt.evaluators import get_evaluator
--------------------------------------------------------------------------------
/ttt/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/ttt/__pycache__/args.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/args.cpython-37.pyc
--------------------------------------------------------------------------------
/ttt/__pycache__/evaluators.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/evaluators.cpython-37.pyc
--------------------------------------------------------------------------------
/ttt/__pycache__/inputs.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/inputs.cpython-37.pyc
--------------------------------------------------------------------------------
/ttt/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/ttt/__pycache__/t2t_trainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/t2t_trainer.cpython-37.pyc
--------------------------------------------------------------------------------
/ttt/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/ttt/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/ttt/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tqdm import tqdm
3 | from sklearn.model_selection import train_test_split
4 | from datasets import load_dataset
5 | import urllib.request
6 | import os, json, sys
7 | import shutil
8 | import tarfile
9 | from .utils import add_filehandler_for_logger
10 | # these have been tested and work fine. more can be added to this list to test
11 | MODELS_SUPPORT = ["distilbert-base-cased", "bert-base-uncased", "bert-large-uncased",
12 | "google/electra-base-discriminator",
13 | "google/electra-large-discriminator", "albert-base-v2", "roberta-base",
14 | "t5-small", "t5-base", "t5-large"]
15 |
16 | # if using t5 models, the task has to be one of the t2t* ones
17 | TASKS_SUPPORT = ["single-label-cls", "t2t", "translation", "pretrain"]
18 | # in the future, more schedulers will be added, such as warmupconstant, warmupcosine, etc.
19 | LR_SCHEDULER_SUPPORT = ["warmuplinear", "warmupconstant", "constant", "constantlinear"]
20 | # warmuplinear refers to increasing lr from 0 to a stage specified by warmup_ratio (ratio of total iterations) and then decreasing lr linearly
21 | ADDITIONAL_DATA_SUPPORT = {"wmt_en_ro": "https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz"}
22 |
23 |
24 | def ts2jsonl(data_path, t5_prefix_source=True):
25 | """
26 | t5_prefix_source: if True (default), the prefix "translate English to Romanian: " is prepended to the source sequence as in: https://arxiv.org/pdf/1910.10683.pdf
27 | """
28 | sets = ["train", "val", "test"]
29 | for set_name in sets:
30 | if os.path.isfile(os.path.join(data_path, set_name + ".source")):
31 | with open(os.path.join(data_path, set_name + ".source"), encoding="utf-8") as src:
32 | src_lines = src.readlines()
33 | with open(os.path.join(data_path, set_name + ".target"), encoding="utf-8") as src:
34 | tgt_lines = src.readlines()
35 | assert len(src_lines) == len(tgt_lines)
36 | if set_name == "train":
37 | #### for any train with examples more than 10000, we get 10000 to form a fixture training set (well suited for quick model development and prototyping)
38 | with open(os.path.join(data_path, set_name + "_fixture.json"), "w+", encoding="utf-8") as f:
39 | for source, target in tqdm(zip(src_lines[:10000], tgt_lines[:10000])):
40 | f.write(json.dumps({"source": "translate English to Romanian: " + source.strip(),
41 | "target": target.strip()}) + "\n")
42 |
43 | with open(os.path.join(data_path, set_name + ".json"), "w+", encoding="utf-8") as f:
44 | for source, target in tqdm(zip(src_lines, tgt_lines),
45 | desc=f"converting original data: {data_path} to jsonl formats"):
46 | if t5_prefix_source:
47 | source = "translate English to Romanian: " + source
48 | f.write(json.dumps({"source": source.strip(), "target": target.strip()}) + "\n")
49 |
50 |
51 | def data_check(args, sample_val_from_train=True, val_sample_portion=0.1,logger=None):
52 | if not os.path.isfile(os.path.join(args.data_path, "train.json")):
53 | # if it is from: https://huggingface.co/nlp/viewer, load it
54 | # if there is no validation set, by default, here a random sampling from the train (0.1 ratio) is used to form the val set
55 | # so far it works well for single-sequence cls datasets such as "glue/sst2", "20newsgroup", "ag_news", "imdb", "sentiment140", etc.
56 | # supporting all kinds of other datasets available in nlp, such as sequence-pair cls datasets (NLI), qa datasets, etc. -> todo
57 | set_name = args.data_path.split("/")[1:]
58 | try:
59 | dataset = load_dataset(*set_name)
60 | target_examples_dict = {}
61 | assert "train" in dataset, "not found train set in the given nlp dataset, make sure you give the correct name as listed in: https://huggingface.co/nlp/viewer/"
62 | label_names = dataset["train"].features["label"].names
63 | for set_key, examples in tqdm(dataset.items(),
64 | desc=f"not found the data locally, try to load {set_name} from nlp lib."):
65 | if set_key == "validation":
66 | set_key = "val"
67 | if set_key not in target_examples_dict:
68 | target_examples_dict[set_key] = []
69 |
70 | if set_key == "test" and "sst2" in set_name:
71 | # sst2 does not have ground truths in test set from the nlp lib.
72 | # here a special branch is processed here:
73 | # to download a sst-test set with the same format as the train and val set loaded here
74 | # from here: https://ucdcs-student.ucd.ie/~cwang/sst2-test.json
75 | # for model testing
76 | sst2_test = urllib.request.urlopen("https://ucdcs-student.ucd.ie/~cwang/sst2-test.json")
77 | for line in sst2_test:
78 | decoded_line = line.decode("utf-8").strip()
79 | target_examples_dict[set_key].append(decoded_line)
80 | else:
81 | for example in examples:
82 | example["label"] = label_names[example["label"]] # convert to raw label
83 | if "sentence" in example:
84 | example["text"] = example["sentence"]
85 | del example["sentence"]
86 | target_examples_dict[set_key].append(json.dumps(example))
87 |
88 | if sample_val_from_train and "val" not in target_examples_dict:
89 | train, val = train_test_split(target_examples_dict["train"], test_size=val_sample_portion,
90 | random_state=42)
91 | target_examples_dict["train"] = train
92 | target_examples_dict["val"] = val
93 |
94 | to_folder = os.path.join("data", *set_name)
95 |
96 | if not os.path.isdir(to_folder):
97 | os.makedirs(to_folder)
98 |
99 | splits_dict = {}
100 | for set_key, examples in tqdm(target_examples_dict.items(), desc="writing..."):
101 | with open(os.path.join(to_folder, set_key + ".json"), "w+") as target:
102 | target.write("\n".join(examples))
103 | splits_dict[set_key] = len(examples)
104 | with open(os.path.join(to_folder, "splits.txt"), "w+") as tgt:
105 | tgt.write(json.dumps(splits_dict, indent=2))
106 |
107 | except Exception:
108 | if set_name[0] in ADDITIONAL_DATA_SUPPORT.keys():
109 | fstream = urllib.request.urlopen(ADDITIONAL_DATA_SUPPORT[set_name[0]])
110 | tfile = tarfile.open(fileobj=fstream, mode="r|gz")
111 | tfile.extractall(path="data")
112 | if set_name[0] == "wmt_en_ro":
113 | ts2jsonl(args.data_path)
114 | else:
115 | if logger is not None:
116 | logger.info("data not found")
117 | print("data not found")
118 | sys.exit(0)
119 |
120 | def check_output_path(output_path,force=False):
121 | if os.path.isdir(output_path):
122 | if force:
123 | print(f"{output_path} exists, remove it as force=True")
124 | shutil.rmtree(output_path)
125 | os.makedirs(output_path, exist_ok=True)
126 | else:
127 | out = input(
128 | "Output directory ({}) already exists and is not empty. Do you want to remove it before starting training? (y/n)".format(
129 | output_path))
130 | if out.lower() == "y":
131 | shutil.rmtree(output_path)
132 | os.makedirs(output_path, exist_ok=True)
133 | else:
134 | sys.exit(0)
135 | else:
136 | print(f"{output_path} not found, create it now")
137 | os.makedirs(output_path, exist_ok=True)
138 |
139 | def sanity_check(args,logger=None,force=False):
140 | # auto-construct some args
141 | # check if data exists
142 | data_check(args,logger=logger)
143 |
144 | output_folder = args.model_select + "_" + args.task + "_" + "-".join(args.data_path.split("/")[1:])
145 | output_path = os.path.join("tmp", output_folder)
146 | args.output_folder = output_folder
147 | args.output_path = output_path
148 |
149 | if args.do_train:
150 | check_output_path(output_path,force=force)
151 |
152 | assert args.model_select in MODELS_SUPPORT or os.path.isdir(args.model_select), F"set --model_select has to be in {MODELS_SUPPORT} or a local path where model's config and tokenizer_config exist"
153 |
154 | assert args.task in TASKS_SUPPORT, F"set --task to be in {TASKS_SUPPORT}"
155 | assert args.scheduler in LR_SCHEDULER_SUPPORT, F"set --scheduler to be in {LR_SCHEDULER_SUPPORT}"
156 | if "t5" in args.model_select:
157 | assert "t2t" in args.task or "translation" in args.task or "pretrain" in args.task, "t5 models (--model_select) only support t2t, translation, pretrain tasks (--task)"
158 |
159 | if "t5" not in args.model_select:
160 | assert "t2t" not in args.task, "BERT-like models (--model_select) only support non t2t tasks (--task)"
161 |
162 | if "translation" in args.task:
163 | assert "t5" in args.model_select, "translation task now is only compatible with T5 models"
164 |
165 | if "pretrain" in args.task:
166 | assert "t5" in args.model_select, "pretrain task now is only compatible with T5 models so far"
167 | if logger is not None:
168 | add_filehandler_for_logger(args.output_path, logger)
169 |
170 | class Args:
171 | '''
172 | an Args class that maintains defaults similar to those of get_args below (note: a few, e.g. lr and scheduler, differ)
173 | '''
174 | model_select = "bert-base-uncased"
175 | data_path = "data/glue/sst2"
176 | dataset_name = ","
177 | task = "single-label-cls"
178 | per_device_train_batch_size = 8
179 | eval_batch_size = 32
180 | num_epochs_train = 6
181 | log_steps = 400
182 | max_seq_length = 128
183 | max_src_length = 128
184 | max_tgt_length = 20
185 | source_field_name = "text"
186 | target_field_name = "label"
187 | lr = 5e-5
188 | warmup_ratio = 0.1
189 | patience = 20
190 | scheduler = "warmuplinear"
191 | seed = 122
192 | eval_on = "acc"
193 | keep_ck_num = 3
194 | ck_index_select = 0
195 | do_train = False
196 | do_eval = False
197 | do_test = False
198 | use_gpu = False
199 | use_tpu = False
200 | use_tb = False
201 | tpu_address = "x.x.x.x"
202 |
203 | def __init__(self, **kwargs):
204 | for k, v in kwargs.items():
205 | setattr(self, k, v)
206 |
207 |
208 | def get_args():
209 | parser = argparse.ArgumentParser(description='Hyper params')
210 |
211 | parser.add_argument('--model_select', type=str, default="t5-small",
212 | help='model select from MODEL_MAP')
213 |
214 | parser.add_argument('--data_path', type=str, default="data/glue/sst2",
215 | help='data path')
216 |
217 | parser.add_argument('--dataset_name', type=str, default="",
218 | help="dataset name for HF's load_dataset")
219 |
220 | parser.add_argument('--task', type=str, default="t2t",
221 | help='tasks available in TASKS_SUPPORT')
222 |
223 | parser.add_argument('--per_device_train_batch_size', type=int, default=8,
224 | help='input batch size for training')
225 |
226 | parser.add_argument('--eval_batch_size', type=int, default=32,
227 | help='input batch size for evaluation')
228 |
229 | parser.add_argument('--num_epochs_train', type=int, default=6,
230 | help='number of epochs to train')
231 |
232 | parser.add_argument('--log_steps', type=int, default=400,
233 | help='evaluate every log_steps global steps if not -1, or every epoch if -1; metrics are also tracked with tensorboard when use_tb is active')
234 |
235 | parser.add_argument('--max_seq_length', type=int, default=128,
236 | help='maximum sequence length of samples in a batch for training')
237 |
238 | parser.add_argument('--max_src_length', type=int, default=128,
239 | help='only working for t5-like t2t-based tasks, maximum source sequence length of samples in a batch for training')
240 |
241 | parser.add_argument('--max_tgt_length', type=int, default=20,
242 | help='only working for t5-like t2t-based tasks, maximum target sequence length of samples in a batch for training')
243 |
244 | parser.add_argument('--source_field_name', type=str, default="text",
245 | help='only working for t5-like t2t-based tasks, the source field name of the provided jsonl-formatted data')
246 |
247 | parser.add_argument('--target_field_name', type=str, default="label",
248 | help='only working for t5-like t2t-based tasks, the target field name of the provided jsonl-formatted data')
249 |
250 | parser.add_argument('--lr', type=float, default=0.001,
251 | help='default learning rate for fine-tuning as described in the T5 paper')
252 |
253 | parser.add_argument('--warmup_ratio', type=float, default=0.1,
254 | help='warmup_ratio, only working if scheduler is not constant')
255 |
256 | parser.add_argument('--patience', type=int, default=20,
257 | help='patience based on the log_steps')
258 |
259 | parser.add_argument('--scheduler', type=str, default="constant",
260 | help='scheduler, default is constant as described in the T5 paper')
261 |
262 | parser.add_argument('--seed', type=int, default=122,
263 | help='random seed')
264 |
265 | parser.add_argument('--eval_on', type=str, default="acc",
266 | help='metric used for best-checkpoint saving and patience-based early stopping')
267 |
268 | parser.add_argument('--keep_ck_num', type=int, default=3,
269 | help='number of checkpoints to keep besides the best one (selected on the validation set by the metric specified by --eval_on)')
270 |
271 | parser.add_argument('--ck_index_select', type=int, default=0,
272 | help='checkpoint index to select: 0 picks the best checkpoint (default), a negative index picks among the latest ones; effective when --do_test is active')
273 |
274 | parser.add_argument(
275 | "--do_train", action="store_true", help="Do training"
276 | )
277 | parser.add_argument(
278 | "--do_eval", action="store_true", help="do evaluation on validation set for saving checkpoint"
279 | )
280 | parser.add_argument(
281 | "--do_test", action="store_true", help="eval on test set if test set is availale"
282 | )
283 | parser.add_argument(
284 | "--use_gpu", action="store_true", help="use gpu?"
285 | )
286 |
287 | parser.add_argument(
288 | "--use_tpu", action="store_true", help="use tpu? "
289 | )
290 | parser.add_argument(
291 | "--use_tb", action="store_true", help="use tensorboard for tracking training process, default save to ./runs"
292 | )
293 |
294 | parser.add_argument('--tpu_address', type=str, default="x.x.x.x",
295 | help='cloud tpu address if using tpu')
296 |
297 | parser.add_argument(
298 | "--default_store", action="store_true",
299 | help="Store datasets, weights, logs, and relevant details to folders by default?"
300 | )
301 |
302 | args = parser.parse_args()
303 | return args
304 |
--------------------------------------------------------------------------------
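A minimal sketch of the warmuplinear schedule described in the comment next to LR_SCHEDULER_SUPPORT (the real implementation lives in the trainer, not in this file): the learning rate ramps linearly from 0 to the base lr over the first warmup_ratio of the total steps, then decays linearly back to 0.

def warmuplinear(step, total_steps, base_lr, warmup_ratio=0.1):
    # linear warmup from 0 to base_lr, then linear decay back to 0
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))

# e.g. with total_steps=1000 and warmup_ratio=0.1, the lr peaks at step 100 and reaches 0 at step 1000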
/ttt/evaluators.py:
--------------------------------------------------------------------------------
1 | '''
2 | this is an evaluation callback class for high-level Keras training (BERT-like models in this lib)
3 | '''
4 |
5 | import sys
6 | from tensorflow import keras
7 | import numpy as np
8 | from sklearn.metrics import classification_report
9 | import tensorflow as tf
10 | import logging
11 | import os
12 | from ttt.utils import add_filehandler_for_logger, get_existing_cks
13 | from tensorboardX import SummaryWriter
14 |
15 | logging.basicConfig(
16 | format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s',
17 | datefmt='%Y-%m-%d %H:%M:%S',
18 | level=logging.INFO
19 | )
20 | logger = logging.getLogger(__name__)
21 |
22 | class ClsEvaluator(keras.callbacks.Callback):
23 | def __init__(self, x_eval, y_eval, args):
24 | super(ClsEvaluator, self).__init__()
25 | self.x_eval = x_eval
26 | self.y_eval = y_eval
27 | self.eval_on = args.eval_on
28 | self.patience = args.patience
29 | self.log_steps = args.log_steps
30 | self.args = args
31 | self.use_tb = self.args.use_tb
32 | if self.use_tb:
33 | self._tb_writer = SummaryWriter(log_dir=os.path.join("runs", args.output_folder))
34 |
35 | def on_train_begin(self, logs=None):
36 | self.global_step = 0
37 | self.wait = 0
38 | self.best = np.Inf if self.eval_on == "loss" else -np.Inf
39 |
40 | def on_train_end(self, logs=None):
41 | if self.use_tb:
42 | self._tb_writer.close()
43 |
44 | def on_batch_end(self, batch, logs=None):
45 | self.global_step += 1
46 | if self.log_steps != -1 and self.global_step % self.log_steps == 0:
47 | if self.args.do_eval:
48 | self.evaluate(self.global_step, tag="global_step", logs=logs)
49 |
50 | def evaluate(self, steps, tag="epoch", logs=None):
51 | logger.info("\n")
52 | logger.info(f"*************evaluating at {tag}={steps}*************")
53 | eval_results = self.model.evaluate(self.x_eval, self.y_eval, batch_size=self.args.eval_batch_size)
54 | dev_loss, acc = eval_results[0], eval_results[1]
55 | pred_probs = self.model.predict(self.x_eval, batch_size=self.args.eval_batch_size)
56 | preds = tf.math.argmax(pred_probs, 1).numpy()
57 | # acc = accuracy_score(preds, self.y_eval)
58 | logger.info(f"{tag}={steps}, eval_report: {classification_report(self.y_eval, preds, digits=4)}")
59 | logger.info(f"{tag}={steps}, eval_acc: {acc}")
60 | logger.info(f"{tag}={steps}, eval_results: {eval_results}")
61 | logger.info(f"{tag}={steps}, train_logs: {logs}")
62 |
63 | if self.use_tb:
64 | if logs is not None:
65 | logger.info("logging metrics with tensorboard")
66 | for key,value in logs.items():
67 | self._tb_writer.add_scalar(f"train_{key}_{tag}",value, steps)
68 | self._tb_writer.add_scalar(f"train_lr_{tag}", self.model.optimizer.lr.numpy(), steps)
69 | self._tb_writer.add_scalar(f"val_acc_{tag}", acc, steps)
70 | self._tb_writer.add_scalar(f"val_loss_{tag}", dev_loss, steps)
71 |
72 |
73 | if self.eval_on == "loss":
74 | if dev_loss <= self.best:
75 | self.wait = 0
76 | self.best = dev_loss
77 | self.save_ck(steps, tag, best_ck=True)
78 | else:
79 | self.wait += 1
80 | else:
81 | if acc >= self.best:
82 | self.wait = 0
83 | self.best = acc
84 | self.save_ck(steps, tag, best_ck=True)
85 | else:
86 | self.wait += 1
87 | logger.info(f"early stop count: {self.wait}/{self.patience}")
88 | logger.info(f"{tag}={steps}, best_on_eval_since({self.eval_on}): {self.best}")
89 | self.save_ck(steps, tag)
90 | if self.wait >= self.patience:
91 | logger.info("run out of patience, early stop")
92 | if self.use_tb:
93 | self._tb_writer.close()
94 | sys.exit(0)
95 |
96 | def save_ck(self, steps, tag="epoch", best_ck=False):
97 | sorted_indices, index2path = get_existing_cks(self.args.output_path, best_ck=best_ck)
98 | if len(sorted_indices) >= self.args.keep_ck_num:
99 | logger.info(
100 | f"since there are already {len(sorted_indices)} checkpoints saved that will be more than keep_ck_num={self.args.keep_ck_num}")
101 | logger.info(f"remove the oldest one: {index2path[sorted_indices[0]]}")
102 | os.remove(index2path[sorted_indices[
103 | 0]]) # remove the oldest checkpoint, i.e., the one with the lowest epoch number
104 | # write_args(self.args.output_path, self.args)
105 | if best_ck:
106 | logger.info(
107 | f'save best model weights to {os.path.join(self.args.output_path, f"best_ck_at_{tag}_{steps}.h5")}')
108 | self.model.save_weights(os.path.join(self.args.output_path, f"best_ck_at_{tag}_{steps}.h5"),
109 | overwrite=True)
110 | else:
111 | logger.info(
112 | f'save model weights to {os.path.join(self.args.output_path, f"ck_at_{tag}_{steps}.h5")}')
113 | self.model.save_weights(os.path.join(self.args.output_path, f"ck_at_{tag}_{steps}.h5"),
114 | overwrite=True)
115 |
116 | def on_epoch_end(self, epoch, logs=None):
117 | if self.log_steps == -1:
118 | if self.args.do_eval:
119 | self.evaluate(epoch, tag="epoch",logs=logs)
120 | if not self.args.do_eval:
121 | # if do not do evaluate, the checkpoint at the end of epoch needs to be saved
122 | self.save_ck(epoch, tag="epoch")
123 |
124 | def get_evaluator(x_eval, y_eval, args):
125 | add_filehandler_for_logger(args.output_path, logger)
126 | if args.task == "single-label-cls":
127 | return ClsEvaluator(x_eval, y_eval, args)
128 | elif args.task == "t2t":
129 | # the t2t task uses a customized training loop, so no evaluator is needed here
130 | pass
131 | else:
132 | pass
--------------------------------------------------------------------------------
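A minimal wiring sketch for the evaluator above (assuming a compiled Keras classifier and already-encoded eval inputs; in the example scripts this is handled by get_callbacks): the callback evaluates every log_steps batches, keeps the best checkpoint on the eval_on metric, and early-stops after patience evaluations without improvement.

# illustrative only: plug the evaluator into Keras training as a callback
evaluator = get_evaluator(x_eval, y_eval, args)  # returns ClsEvaluator when args.task == "single-label-cls"
model.fit(x_train, y_train,
          epochs=args.num_epochs_train,
          batch_size=args.per_device_train_batch_size,
          callbacks=[evaluator])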
/ttt/inputs.py:
--------------------------------------------------------------------------------
1 | import json, os
2 | from tqdm import tqdm
3 | import numpy as np
4 | import logging, pickle
5 | from ttt.utils import add_filehandler_for_logger
6 |
7 | logging.basicConfig(
8 | format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s',
9 | datefmt='%Y-%m-%d %H:%M:%S',
10 | level=logging.INFO
11 | )
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | def read_seq_single_cls_examples(data_path):
16 | texts = []
17 | labels = []
18 | with open(data_path, "r") as f:
19 | for line in tqdm(f, desc=f"reading from {data_path}"):
20 | example = json.loads(line.strip())
21 | texts.append(example["text"])
22 | labels.append(example["label"])
23 | return texts, labels
24 |
25 |
26 | def convert_seq_single_cls_examples(data_path, tokenizer, max_seq_length, label2id):
27 | texts, labels = read_seq_single_cls_examples(data_path)
28 | y = np.array([label2id[label] for label in labels])
29 | encoded_texts = tokenizer(texts, padding=True, truncation=True, return_tensors="np",
30 | max_length=max_seq_length)
31 | return texts, encoded_texts, y
32 |
33 |
34 | def prepare_seq_single_cls_inputs(tokenizer, args, load_train_num=-1, tokenize_batch_size=None):
35 | logger.info("reading train set")
36 | train_texts, train_labels = read_seq_single_cls_examples(os.path.join(args.data_path, "train.json"))
37 |
38 | if load_train_num > 0:
39 | assert load_train_num <= len(train_texts), f"there are {len(train_texts)} training examples"
40 | logger.info(f"loading only {load_train_num} training examples out of the totaling {len(train_texts)}")
41 | train_texts = train_texts[:load_train_num]
42 | train_labels = train_labels[:load_train_num]
43 |
44 | label2id = {}
45 | logger.info("building label2id from train examples")
46 | for i, label in enumerate(list(set(train_labels))):
47 | label2id[label] = i
48 |
49 | logger.info("converting labels to its ids for train examples")
50 | y_train = np.array([label2id[label] for label in train_labels])
51 | logger.info(f"encoding train examples (num={len(train_texts)})")
52 | logger.info(f"using tokenizer with padding = True and truncation = True and max_length = {args.max_seq_length}")
53 | logger.info("This may take a while")
54 | # todo -> tokenize_batch_size
55 | encoded_train = tokenizer(train_texts, padding=True, truncation=True, return_tensors="np",
56 | max_length=args.max_seq_length)
57 |
58 | if "token_type_ids" not in encoded_train:
59 | # we need this for roberta tokenizer that does not return token_type_ids
60 | encoded_train["token_type_ids"] = np.zeros(encoded_train["input_ids"].shape, dtype=np.int32)
61 |
62 | x_train = [encoded_train["input_ids"], encoded_train["token_type_ids"], encoded_train["attention_mask"]]
63 |
64 | to_return = {"x_train": x_train, "y_train": y_train, "label2id": label2id}
65 |
66 | if os.path.isfile(os.path.join(args.data_path, "val.json")):
67 | logger.info(f"found validation set in {os.path.join(args.data_path, 'val.json')}")
68 | logger.info(f"encoding validation examples")
69 | val_texts, encoded_val, y_eval = convert_seq_single_cls_examples(os.path.join(args.data_path, 'val.json'),
70 | tokenizer, args.max_seq_length, label2id)
71 | if "token_type_ids" not in encoded_val:
72 | # we need this for roberta tokenizer that does not return token_type_ids
73 | encoded_val["token_type_ids"] = np.zeros(encoded_val["input_ids"].shape, dtype=np.int32)
74 |
75 | x_eval = [encoded_val["input_ids"], encoded_val["token_type_ids"], encoded_val["attention_mask"]]
76 |
77 | to_return.update({"x_eval": x_eval, "y_eval": y_eval, "eval_examples": val_texts})
78 | # if os.path.isfile(os.path.join(args.data_path, "test.json")):
79 | # logger.info(f"found test set in {os.path.join(args.data_path, 'test.json')}")
80 | # logger.info(f"encoding test examples")
81 | # test_texts, encoded_test, y_test = convert_seq_single_cls_examples(os.path.join(args.data_path, 'test.json'),
82 | # tokenizer, args.max_seq_length, label2id)
83 | # x_test = [encoded_test["input_ids"], encoded_test["token_type_ids"], encoded_test["attention_mask"]]
84 | # to_return.update({"x_test": x_test, "y_test": y_test, "test_examples": test_texts})
85 | return to_return
86 |
87 |
88 | def read_t2t_examples(data_path, source_field_name="text", target_field_name="label", with_target=True):
89 | source_texts = []
90 | target_texts = []
91 | other_attributes = []
92 | with open(data_path, "r",encoding="utf-8") as f:
93 | for line in tqdm(f, desc=f"reading from {data_path}"):
94 | example = json.loads(line.strip())
95 |
96 | source_texts.append(example.pop(source_field_name))
97 | # the </s> will be added automatically by the tokenizer when add_special_tokens=True
98 | if with_target:
99 | target_texts.append(example.pop(target_field_name))
100 | other_attributes.append(example)
101 | return source_texts, target_texts, other_attributes
102 |
103 |
104 | def convert_t2t_examples(data_path, tokenizer, args, with_target=True, return_meta=False):
105 | source_texts, target_texts, other_attributes = read_t2t_examples(data_path,
106 | source_field_name=args.source_field_name,
107 | target_field_name=args.target_field_name,
108 | with_target=with_target)
109 |
110 | # for pre-training t5, no need to append the </s> to the source sequence
111 | encoded_source = tokenizer(source_texts, padding=True, truncation=True, return_tensors="np",
112 | max_length=args.max_src_length, add_special_tokens=not args.is_pretrain)
113 |
114 | encoded_target = None
115 | if with_target:
116 | # for making predictions purpose
117 | encoded_target = tokenizer(target_texts, padding=True, truncation=True, return_tensors="np",
118 | max_length=args.max_tgt_length, add_special_tokens=True)
119 | if return_meta:
120 | return source_texts, encoded_source, encoded_target, other_attributes
121 |
122 | return source_texts, encoded_source, encoded_target
123 |
124 |
125 | def replace_with_special_token(encoded, special_token_id, replace_token_id=0):
126 | # replace_token_id: padding id
127 | ids = encoded["input_ids"]
128 | attention_mask = encoded["attention_mask"]
129 |
130 | for i in range(ids.shape[0]):
131 | indices = np.where(ids[i, :] == replace_token_id)[0]
132 | if indices.size == 0:
133 | ids[i, -1] = special_token_id
134 | else:
135 | ids[i, indices[0]] = special_token_id
136 | attention_mask[i, indices[0]] = 1
137 | return ids, attention_mask
138 |
139 |
140 | def shift_to_right(input_ids, decoder_start_token_id):
141 | shifted_input_ids = np.zeros(input_ids.shape, dtype=input_ids.dtype)
142 | shifted_input_ids[..., 1:] = input_ids[..., :-1]
143 | shifted_input_ids[..., 0] = decoder_start_token_id
144 | # shifted_input_ids[shifted_input_ids == 0] = -100
145 | return shifted_input_ids
146 |
147 |
148 | def tokenize_with_progress_bar(tokenizer, args, text_list, batch_size=1000):
149 | assert batch_size > 0
150 | encoded_return = {"input_ids": [], "attention_mask": []}
151 | batch = []
152 | # actual_max_seq = args.max_src_length
153 | # tokenization in batches of batch_size; the padding differs slightly from the at-one-go way: each batch is always padded to max_src_length even when its longest sequence is shorter
154 | for idx, each_text in tqdm(enumerate(text_list), desc="tokenizing by batch...", total=len(text_list)):
155 | if (idx + 1) % batch_size == 0:
156 | batch.append(each_text)
157 | encoded = tokenizer(batch, padding="max_length", truncation=True, max_length=args.max_src_length,
158 | add_special_tokens=not args.is_pretrain)
159 | encoded_return["input_ids"].extend(encoded["input_ids"])
160 | encoded_return["attention_mask"].extend(encoded["attention_mask"])
161 | batch = []
162 | else:
163 | batch.append(each_text)
164 |
165 | if batch != []:
166 | encoded = tokenizer(batch, padding="max_length", truncation=True, max_length=args.max_src_length,
167 | add_special_tokens=not args.is_pretrain)
168 | encoded_return["input_ids"].extend(encoded["input_ids"])
169 | encoded_return["attention_mask"].extend(encoded["attention_mask"])
170 |
171 | assert len(encoded_return["input_ids"]) == len(text_list)
172 |
173 | encoded_return["input_ids"] = np.array(encoded_return["input_ids"])
174 |
175 | assert len(encoded_return["attention_mask"]) == len(text_list)
176 |
177 | encoded_return["attention_mask"] = np.array(encoded_return["attention_mask"])
178 |
179 | from transformers import BatchEncoding
180 | return BatchEncoding(data=encoded_return)
181 |
182 | def tokenizet2t(tokenizer, args, source_text, target_text, batch_size=None):
183 | # this method is similar to T5Tokenizer.prepare_seq2seq_batch; it is rewritten here for more control and flexibility
184 | logger.info(f"encoding source examples (num={len(source_text)})")
185 | logger.info(f"using tokenizer with padding = True and truncation = True and max_src_length = {args.max_src_length}")
186 |
187 | if batch_size is None:
188 | # batch_size=None means tokenization at one go; no progress bar in this case, but it can potentially save memory
189 | # for pre-training t5, no need to append the </s> to the source sequence
190 | logger.info("The tokenization is conducted at the backend without progress bar")
191 | logger.info("This may take a while")
192 | encoded_source = tokenizer(source_text, padding=True, truncation=True, return_tensors="np",
193 | max_length=args.max_src_length - 1 if not args.is_pretrain else args.max_src_length,
194 | add_special_tokens=not args.is_pretrain)
195 | else:
196 | encoded_source = tokenize_with_progress_bar(tokenizer, args, source_text, batch_size=batch_size)
197 |
198 | logger.info(f"encoding target examples (num={len(target_text)})")
199 | logger.info(f"using tokenizer with padding = True and truncation = True and max_tgt_length = {args.max_tgt_length}")
200 | if batch_size is None:
201 | logger.info("The tokenization is conducted at the backend without progress bar")
202 | logger.info("This may take a while")
203 | encoded_target = tokenizer(target_text, padding=True, truncation=True, return_tensors="np",
204 | max_length=args.max_tgt_length, add_special_tokens=True)
205 | else:
206 | encoded_target = tokenize_with_progress_bar(tokenizer, args, target_text, batch_size=batch_size)
207 | return encoded_source, encoded_target
208 |
209 |
210 | def prepare_t2t_inputs(tokenizer, args, load_train_num=-1, tokenize_batch_size=None):
211 | logger.info("reading train set")
212 | source_texts_train, target_texts_train, _ = read_t2t_examples(os.path.join(args.data_path, "train.json"),
213 | source_field_name=args.source_field_name,
214 | target_field_name=args.target_field_name)
215 | if load_train_num > 0:
216 | # assert load_train_num <= len(source_texts_train), f"there are {len(source_texts_train)} training examples"
217 | logger.info(f"loading only {load_train_num} training examples out of the total {len(source_texts_train)}")
218 | source_texts_train = source_texts_train[:load_train_num]
219 | target_texts_train = target_texts_train[:load_train_num]
220 |
221 | encoded_source_train, encoded_target_train = tokenizet2t(tokenizer, args, source_texts_train, target_texts_train,
222 | batch_size=tokenize_batch_size)
223 |
224 | # special_token_id = tokenizer.eos_token_id
225 | decoder_start_token_id = tokenizer.pad_token_id
226 |
227 | train_source_input_ids, train_source_attention_mask = encoded_source_train["input_ids"], encoded_source_train[
228 | "attention_mask"]
229 |
230 | train_target_input_ids, train_target_attention_mask = encoded_target_train["input_ids"], encoded_target_train[
231 | "attention_mask"]
232 | # -100 is pytorch's cross entropy ignore index; its tensorflow-2.0 counterpart needed figuring out
233 | # -> update: this can be found in transformers.modeling_tf_utils:TFCausalLanguageModelingLoss (since transformers 3.1.0)
234 | # shift_to_right(train_target_input_ids, decoder_start_token_id) is not needed since transformers 3.1.0 (it is handled automatically by simply passing the labels argument to the model, just like in pytorch)
235 | # x_train = [train_source_input_ids, train_source_attention_mask,
236 | # shift_to_right(train_target_input_ids, decoder_start_token_id), train_target_attention_mask]
237 | x_train = {"source_input_ids":train_source_input_ids, "source_attention_mask":train_source_attention_mask,
238 | "shifted_target_input_ids":shift_to_right(train_target_input_ids, decoder_start_token_id), "target_attention_mask":train_target_attention_mask}
239 | train_target_input_ids[train_target_input_ids == 0] = -100
240 | to_return = {"x_train": x_train, "y_train": {"target_input_ids":train_target_input_ids}}
241 |
242 | if args.do_eval:
243 | assert os.path.isfile(
244 | os.path.join(args.data_path, "val.json")), "do_eval=True, and no validation data (val.json) is found"
245 |
246 | if os.path.isfile(os.path.join(args.data_path, "val.json")):
247 | logger.info(f"found validation set in {os.path.join(args.data_path, 'val.json')}")
248 | logger.info(f"encoding validation examples")
249 | source_texts, encoded_source, encoded_target = convert_t2t_examples(
250 | os.path.join(args.data_path, 'val.json'),
251 | tokenizer, args)
252 | # add "</s>" at the end of each sequence
253 | source_input_ids, source_attention_mask = encoded_source["input_ids"], encoded_source["attention_mask"]
254 | target_input_ids, target_attention_mask = encoded_target["input_ids"], encoded_target["attention_mask"]
255 | # target_input_ids[target_input_ids == 0] = -100
256 | # in t5 trainer, eval lm labels are not used for calculating loss but for generation comparison, so we do not apply the -100 replacement here
257 | # x_eval = [source_input_ids, source_attention_mask, shift_to_right(target_input_ids, decoder_start_token_id),
258 | # target_attention_mask]
259 | x_eval = {"source_input_ids": source_input_ids, "source_attention_mask": source_attention_mask,
260 | "shifted_target_input_ids": shift_to_right(target_input_ids, decoder_start_token_id),
261 | "target_attention_mask": target_attention_mask}
262 | to_return.update({"x_eval": x_eval, "y_eval": {"target_input_ids":target_input_ids}, "eval_examples": {"source_texts":source_texts}})
263 |
264 | # if os.path.isfile(os.path.join(args.data_path, "test.json")):
265 | # logger.info(f"found test set in {os.path.join(args.data_path, 'test.json')}")
266 | # logger.info(f"encoding test examples")
267 | # source_texts, encoded_source, encoded_target = convert_t2t_examples(
268 | # os.path.join(args.data_path, 'test.json'),
269 | # tokenizer, args)
270 | # # add "</s>" at the end of each sequence
271 | # source_input_ids, source_attention_mask = encoded_source["input_ids"], encoded_source["attention_mask"]
272 | # target_input_ids, target_attention_mask = encoded_target["input_ids"], encoded_target["attention_mask"]
273 | # # target_input_ids[target_input_ids == 0] = -100
274 | # # in t5 trainer, eval lm labels are not used for calculating loss but for generation comparison, so we do not apply the -100 replacement here
275 | # x_test = [source_input_ids, source_attention_mask, shift_to_right(target_input_ids, decoder_start_token_id),
276 | # target_attention_mask]
277 | # to_return.update({"x_test": x_test, "y_test": target_input_ids, "test_examples": source_texts})
278 | return to_return
279 |
280 |
281 | def get_with_prepare_func(tokenizer, args, prepare_func, load_train_num=-1, check_cache=False,
282 | tokenize_batch_size=None):
283 | '''
284 | :param tokenizer:
285 | :param args:
286 | :return:
287 | '''
288 | args.is_load_from_data_cache = False
289 | if check_cache:
290 | if load_train_num > 0:
291 | data_cache_path = os.path.join(args.data_path,
292 | f"{args.model_select.replace('/', '-')}-data-{load_train_num}.pkl")
293 | else:
294 | data_cache_path = os.path.join(args.data_path, f"{args.model_select.replace('/', '-')}-data.pkl")
295 | args.data_cache_path = data_cache_path
296 | if os.path.isfile(data_cache_path):
297 | args.is_load_from_data_cache = True
298 | with open(data_cache_path, "rb") as f:
299 | logger.info(f"reading cached data from {data_cache_path}")
300 | logger.warning(
301 | f"if you changed the max_seq_length/max_src_length/max_tgt_length, this may not correctly loaded, since the {data_cache_path} is pickled based on first time loading")
302 | to_return = pickle.load(f)
303 | else:
304 | to_return = prepare_func(tokenizer, args, load_train_num=load_train_num,
305 | tokenize_batch_size=tokenize_batch_size)
306 | with open(data_cache_path, "wb") as f:
307 | logger.info(f"caching data to {data_cache_path}")
308 | pickle.dump(to_return, f)
309 | else:
310 | to_return = prepare_func(tokenizer, args, load_train_num=load_train_num, tokenize_batch_size=tokenize_batch_size)
311 | return to_return
312 |
313 | def get_inputs(tokenizer, args):
314 | add_filehandler_for_logger(args.output_path, logger)
315 | if args.task == "single-label-cls":
316 | inputs = get_with_prepare_func(tokenizer, args, prepare_seq_single_cls_inputs, check_cache=True,
317 | load_train_num=-1)
318 | args.input_seq_length = inputs["x_train"][0].shape[-1]
319 | args.label2id = inputs["label2id"]
320 | return inputs
321 | elif args.task == "t2t" or args.task == "translation" or args.task == "pretrain":
322 | args.is_pretrain = args.task == "pretrain"
323 |
324 | tokenize_batch_size=args.tokenize_batch_size if hasattr(args,"tokenize_batch_size") else None
325 | load_train_num=args.load_train_num if hasattr(args,"load_train_num") else -1
326 |
327 | data_dict = get_with_prepare_func(tokenizer, args, prepare_t2t_inputs, check_cache=True, load_train_num=load_train_num,
328 | tokenize_batch_size=tokenize_batch_size)
329 | args.source_sequence_length = data_dict["x_train"]["source_input_ids"].shape[-1]
330 | args.target_sequence_length = data_dict["y_train"]["target_input_ids"].shape[-1]
331 | return data_dict
332 | else:
333 | # when more tasks are supported -> todo
334 | pass
335 |
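336 |
337 | # Usage sketch (added for illustration, not part of the original file): how
338 | # get_inputs is typically driven; the argument values below are assumptions.
339 | # from types import SimpleNamespace
340 | # from transformers import AutoTokenizer
341 | # args = SimpleNamespace(task="t2t", data_path="data/covid_info", output_path="tmp",
342 | #                        model_select="t5-small", source_field_name="text",
343 | #                        target_field_name="label", max_src_length=128,
344 | #                        max_tgt_length=10, is_pretrain=False, do_eval=True)
345 | # tokenizer = AutoTokenizer.from_pretrained(args.model_select)
346 | # inputs = get_inputs(tokenizer, args)  # returns x_train/y_train (+ eval splits if val.json exists)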
--------------------------------------------------------------------------------
/ttt/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow import keras
3 | from tensorflow.keras import layers
4 | from transformers import TFAutoModel, TFT5ForConditionalGeneration, T5Config
5 | import os
6 | from transformers.modeling_tf_utils import get_initializer
7 |
8 |
9 | def get_lr_metric(optimizer):
10 | def lr(x, y):
11 | return optimizer.lr
12 |
13 | return lr
14 |
15 |
16 | def create_sl_cls_model(model_name_or_path, input_seq_length, args):
17 | ## transformer encoder
18 | encoder = TFAutoModel.from_pretrained(model_name_or_path)
19 |
20 | encoder_config = encoder.config
21 | if not os.path.isfile(os.path.join(args.output_path, "config.json")):
22 | encoder_config.save_pretrained(args.output_path)
23 |
24 | input_ids = layers.Input(shape=(input_seq_length,), dtype=tf.int32)
25 | token_type_ids = layers.Input(shape=(input_seq_length,), dtype=tf.int32)
26 | attention_mask = layers.Input(shape=(input_seq_length,), dtype=tf.int32)
27 |
28 | if "distilbert" in args.model_select:
29 | # distilbert does not accept token_type_ids
30 | sequence_outs = encoder(input_ids, attention_mask=attention_mask)[0]
31 | else:
32 | sequence_outs = encoder(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
33 |
34 | # following modeling_tf_bert:TFBertPooler. In transformers, models like RoBERTa and Electra do not offer a direct pooled_output
35 | # to make this generalisable, the pooler is re-written here
36 | # this may not have a big effect on performance if the following pooler is simply replaced with: pooled_output = sequence_outs[:, 0]
37 | pooled_output = tf.keras.layers.Dense(
38 | encoder_config.hidden_size,
39 | kernel_initializer=get_initializer(encoder_config.initializer_range),
40 | activation="tanh",
41 | name="dense",
42 | )(sequence_outs[:, 0])
43 |
44 | if hasattr(encoder_config, "hidden_dropout_prob"):
45 | pooled_output = tf.keras.layers.Dropout(encoder_config.hidden_dropout_prob)(pooled_output, training=True)
46 | else:
47 | pooled_output = tf.keras.layers.Dropout(encoder_config.dropout)(pooled_output, training=True)
48 |
49 | logits = tf.keras.layers.Dense(len(args.label2id), name="classifier", use_bias=False)(pooled_output)
50 | probs = layers.Activation(keras.activations.softmax)(logits)
51 |
52 | model = keras.Model(
53 | inputs=[input_ids, token_type_ids, attention_mask],
54 | outputs=probs,
55 | )
56 |
57 | loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
58 | optimizer = keras.optimizers.Adam(lr=args.lr)
59 | model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy', get_lr_metric(optimizer)])
60 | return model
61 |
62 | def create_t2t_model(model_name_or_path, args, tokenizer=None,from_pretrained=True):
63 | ## transformer encoder
64 | if from_pretrained:
65 | encoder = TFT5ForConditionalGeneration.from_pretrained(model_name_or_path)
66 | encoder_config = encoder.config
67 | else:
68 | encoder_config = T5Config.from_pretrained(args.model_select)
69 | if tokenizer!=None:
70 | assert encoder_config.vocab_size==len(tokenizer)
71 | assert encoder_config.pad_token_id==tokenizer.pad_token_id
72 | assert encoder_config.eos_token_id==tokenizer.eos_token_id
73 | assert encoder_config.decoder_start_token_id==tokenizer.pad_token_id
74 | encoder = TFT5ForConditionalGeneration(encoder_config)
75 | # build the model with dummy_inputs
76 | encoder(encoder.dummy_inputs, training=False)
77 |
78 | if not os.path.isfile(os.path.join(args.output_path, "config.json")):
79 | encoder_config.save_pretrained(args.output_path)
80 | return encoder
81 |
82 |
83 | def get_model(args,tokenizer=None,from_pretrained=True):
84 | if args.task == "single-label-cls":
85 | return create_sl_cls_model(args.model_select, args.input_seq_length, args)
86 | elif args.task == "t2t" or args.task == "translation" or args.task == "pretrain":
87 | return create_t2t_model(args.model_select, args, tokenizer=tokenizer,from_pretrained=from_pretrained)
88 | else:
89 | # when more tasks are supported -> todo
90 | pass
91 |
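92 | # Usage sketch (added for illustration; the args values are assumptions):
93 | # from types import SimpleNamespace
94 | # args = SimpleNamespace(task="t2t", model_select="t5-small", output_path="tmp")
95 | # model = get_model(args, from_pretrained=True)  # a TFT5ForConditionalGeneration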
--------------------------------------------------------------------------------
/ttt/t2t_trainer.py:
--------------------------------------------------------------------------------
1 | '''
2 | this is a customized trainer for T5-like model training;
3 | the training loop is written out in full for more flexibility and control over training
4 | '''
5 | import math
6 | import os
7 | import sys
8 | import warnings
9 |
10 | import tensorflow as tf
11 | from tqdm import tqdm
12 | from sklearn.metrics import accuracy_score, classification_report
13 | import numpy as np
14 | from keras import backend as K
15 | from ttt.utils import add_filehandler_for_logger, get_existing_cks
16 | from tensorboardX import SummaryWriter
17 | # for translation evaluation from: https://github.com/mjpost/sacrebleu
18 | # which is also used in the original T5 paper
19 | import sacrebleu
20 | from .utils import write_args_enhance, save_ck, dictionize_t2t_dataset, set_seed
21 |
22 |
23 | class T2TTrainer():
24 | def __init__(self, args, logger):
25 | self.eval_on = args.eval_on
26 | assert self.eval_on in ["acc",
27 | "bleu"], "now t2t training only supports --eval_on acc, bleu, only works when --do_eval=True"
28 | # self.best = -np.Inf
29 |
30 | self.patience = args.patience
31 | self.wait = 0
32 | self.logger = logger
33 | self.args = args
34 | self.use_tb = self.args.__dict__.get('use_tb', False)
35 |
36 | self._tb_writer = None
37 | if self.use_tb:
38 | self._tb_writer = SummaryWriter(log_dir=self.args.__dict__.get('output_folder', "runs"))
39 | self.scheduler = args.scheduler
40 |
41 | if "learning_rate" in self.args.__dict__:
42 | self.lr_to_reach = args.learning_rate
43 | else:
44 | self.lr_to_reach = args.lr
45 |
46 | self.args.best = np.Inf if self.args.eval_on == "loss" or self.args.eval_on == "perplexity" else - np.Inf
47 | self.best = self.args.best
48 |
49 | def train(self, model, strategy, tokenizer, inputs=None, train_dataset=None, eval_dataset=None, evaluate_fn=None, verbose=False):
50 |
51 | if inputs is None:
52 | assert train_dataset is not None, "you have to pass either inputs or train_dataset"
53 | else:
54 | warnings.warn(
55 | "Passing `inputs` as a keyword argument is deprecated. Use train_dataset and eval_dataset instead.",
56 | FutureWarning,
57 | )
58 |
59 | if isinstance(inputs, tuple):
60 | inputs = dictionize_t2t_dataset(*inputs)
61 |
62 | if inputs is not None:
63 | x_train, y_train = inputs["x_train"], inputs["y_train"]
64 | num_train_examples = len(inputs["y_train"]["target_input_ids"])
65 | train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
66 | else:
67 | if hasattr(train_dataset, "num_examples"):
68 | num_train_examples = train_dataset.num_examples
69 | else:
70 | num_train_examples = tf.data.experimental.cardinality(train_dataset).numpy()
71 |
72 | self.logger.info(f"set random seed for everything with {self.args.seed}")
73 | set_seed(self.args.seed)
74 | global_batch_size = self.args.per_device_train_batch_size * strategy.num_replicas_in_sync
75 | # shuffle with a buffer covering the full training set, seeded for reproducibility
75 | train_dataset = train_dataset.shuffle(buffer_size=num_train_examples, seed=self.args.seed).batch(global_batch_size)
76 | train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
77 | # there will be exceptions on TPUs if the eval dataset is switched to a distributed dataset, e.g. via:
78 | # val_dist_dataset = strategy.experimental_distribute_dataset(eval_dataset)
79 | train_length = math.ceil(num_train_examples / global_batch_size)
80 | self.steps_per_epoch = train_length
81 | if inputs is not None:
82 | if self.args.do_eval:
83 | assert "x_eval" in inputs and "y_eval" in inputs, "do_eval=True, and no validation data is found"
84 | x_val, y_val = inputs["x_eval"], inputs["y_eval"]
85 | eval_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
86 | eval_dataset = eval_dataset.batch(self.args.eval_batch_size)
87 | eval_steps = math.ceil(
88 | len(inputs["y_eval"]["target_input_ids"]) / (self.args.eval_batch_size))
89 | else:
90 | if self.args.do_eval:
91 | if hasattr(eval_dataset, "num_examples"):
92 | eval_num_examples = eval_dataset.num_examples
93 | else:
94 | eval_num_examples = tf.data.experimental.cardinality(eval_dataset).numpy()
95 | eval_steps = math.ceil(eval_num_examples / (self.args.eval_batch_size))
96 | eval_dataset = eval_dataset.batch(self.args.eval_batch_size)
97 | if verbose:
98 | self.logger.info(model.summary())
99 | # these are used for non-constant lr scheduler
100 | if "num_train_epochs" in self.args.__dict__:
101 | self.args.num_epochs_train = self.args.num_train_epochs
102 | if "log_and_save_steps" in self.args.__dict__:
103 | self.args.log_steps = self.args.log_and_save_steps
104 |
105 | self.total_steps = self.steps_per_epoch * self.args.num_epochs_train
106 |
107 | if "warmup_steps_or_ratio" in self.args.__dict__:
108 | if self.args.warmup_steps_or_ratio <= 1 and self.args.warmup_steps_or_ratio > 0:
109 | self.args.warmup_steps = int(self.total_steps * self.args.warmup_steps_or_ratio)
110 | else:
111 | self.args.warmup_steps = self.args.warmup_steps_or_ratio
112 | else:
113 | self.args.warmup_steps = int(self.total_steps * self.args.warmup_ratio)
114 |
115 | self.warmup_steps = self.args.warmup_steps
116 | write_args_enhance(self.args, logger=self.logger)
117 |
118 | with strategy.scope():
119 | optimizer = tf.keras.optimizers.Adam(lr=self.args.lr if self.scheduler.startswith("constant") else 0.0)
120 | loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
121 | from_logits=True, reduction=tf.keras.losses.Reduction.NONE
122 | )
123 |
124 | def compute_loss(labels, predictions):
125 | per_example_loss = loss_fn(labels, predictions)
126 | return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)
127 |
128 | def train_step(x_train, y_train):
129 | with tf.GradientTape() as tape:
130 | # some changes have been made here (compared to before commit `a07c58e`) to fix a bug reported here: https://github.com/wangcongcong123/ttt/issues/2
131 | # How the bug is fixed:
132 | # the loss computed internally by transformers:TFT5ForConditionalGeneration is already averaged, which failed
133 | # when switching to TPU, hence we re-compute the loss here from the returned logits, ready for backprop, instead of using the internally calculated one
134 | outputs = model(inputs=x_train["source_input_ids"], attention_mask=x_train["source_attention_mask"],
135 | decoder_attention_mask=x_train["target_attention_mask"],
136 | labels=y_train["target_input_ids"], training=True, return_dict=True)
137 | logits = outputs.logits
138 | loss = compute_loss(tf.reshape(y_train["target_input_ids"], (-1, y_train["target_input_ids"].shape[-1])),
139 | tf.reshape(logits, (-1, logits.shape[-1])))
140 |
141 | gradients = tape.gradient(loss, model.trainable_variables)
142 | optimizer.apply_gradients(zip(gradients, model.trainable_variables))
143 | return loss
144 |
145 | @tf.function
146 | def distributed_train_step(x_train, y_train):
147 | per_replica_losses = strategy.experimental_run_v2(train_step, args=(x_train, y_train,))
148 | return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
149 |
150 | # evaluate
151 | def evaluate(steps, tag="epoch"):
152 | assert tag in ["epoch", "global_step"]
153 | gts = []
154 | preds = []
155 | for x_eval, y_eval in tqdm(eval_dataset, total=eval_steps, desc="evaluating..."):
156 | predictions = model.generate(input_ids=x_eval["source_input_ids"],
157 | attention_mask=x_eval["source_attention_mask"],
158 | max_length=self.args.max_tgt_length)
159 | pred = [tokenizer.decode(ids) for ids in predictions]
160 | gt = [tokenizer.decode(ids) for ids in y_eval["target_input_ids"]]
161 | # labels (not -100 replaced since it is not used to calculate loss here)
162 | preds.extend(pred)
163 | gts.extend(gt)
164 |
165 | if self.eval_on == "bleu":
166 | # bleu = 0
167 | bleu = sacrebleu.corpus_bleu(preds, [gts])
168 | eval_score = bleu.score
169 | else:
170 | eval_score = accuracy_score(gts, preds)
171 | self.logger.info(f"val_cls_report: {classification_report(gts, preds, digits=4)}")
172 |
173 | if self.use_tb:
174 | self._tb_writer.add_scalar(f"val_{self.eval_on}_{tag}", eval_score, steps)
175 |
176 | self.logger.info("\n")
177 | self.logger.info(f"*******eval at {tag} = {steps} on validation dataset*********")
178 | self.logger.info(f"val_{self.eval_on}: {eval_score}")
179 |
180 | if self.eval_on == "acc" or self.eval_on == "bleu":
181 | if eval_score >= self.best:
182 | self.wait = 0
183 | self.best = eval_score
184 | self.logger.info(
185 | f"so far the best check point at {tag}={steps} based on eval_on {self.eval_on}")
186 | # self.save_ck(model, steps, tag, best_ck=True)
187 | save_ck(self.args, self.logger, model, tokenizer=tokenizer, steps=steps,
188 | tag=tag, best_ck=True, from_tf=True)
189 | else:
190 | self.wait += 1
191 | else:
192 | raise ValueError("not supported yet")
193 |
194 | self.logger.info(f"best so far({self.eval_on}): {self.best}")
195 | self.logger.info(f"early stop count: {self.wait}/{self.patience}")
196 | # self.save_ck(model, steps, tag)
197 | save_ck(self.args, self.logger, model, tokenizer=tokenizer, steps=steps,
198 | tag=tag, best_ck=False, from_tf=True)
199 | if self.wait >= self.patience:
200 | self.logger.info("run out of patience, early stop")
201 | if self.use_tb:
202 | self._tb_writer.close()
203 | sys.exit(0)
204 |
205 | def update_lr(global_step):
206 | # already tested on tpu, works fine
207 | # global_step is dynamically passed here
208 | if global_step <= self.warmup_steps:
209 | if self.scheduler == "warmuplinear" or self.scheduler == "warmupconstant":
210 | inc = self.lr_to_reach / self.warmup_steps
211 | K.set_value(optimizer.learning_rate, K.eval(optimizer.lr) + inc)
212 | else:
213 | if self.scheduler == "warmuplinear" or self.scheduler == "constantlinear":
214 | dec = self.lr_to_reach / (self.total_steps - self.warmup_steps)
215 | K.set_value(optimizer.learning_rate, K.eval(optimizer.lr) - dec)
216 | # for "constant" scheduler, nothing to do here
217 |
218 | global_step = 0
219 | early_exit = False
220 | interval_loss = 0.0
221 | interval_count = 0
222 | for epoch in tqdm(range(self.args.num_epochs_train), desc="epochs"):
223 |
224 | self.logger.info(f"start training at epoch = {epoch}")
225 | self.logger.info(f"global train batch size = {global_batch_size}")
226 | self.logger.info(f"using learning rate scheduler: {self.scheduler}")
227 | self.logger.info(
228 | f"num_train_examples: {num_train_examples}, total_steps: {self.total_steps}, steps_per_epoch: {self.steps_per_epoch}")
229 | if self.scheduler != "constant":
230 | self.logger.info(f"warmup_steps:{self.warmup_steps}")
231 |
232 | pbar = tqdm(enumerate(train_dist_dataset), total=train_length)
233 | for step, (x_train, y_train) in pbar:
234 | # learning rate scheduler
235 | update_lr(global_step)
236 | loss = distributed_train_step(x_train, y_train)
237 | interval_loss += loss.numpy()
238 | interval_count += 1
239 | global_step += 1
240 | pbar.set_description(f"training - epoch {epoch + 1}/{self.args.num_epochs_train} iter {step}: train loss {loss.numpy():.5f}. lr {optimizer.lr.numpy():e}")
241 |
242 | if self.args.log_steps != -1 and global_step % self.args.log_steps == 0:
243 | if self.use_tb:
244 | self._tb_writer.add_scalar("train_loss_global_step", interval_loss / interval_count,
245 | global_step)
246 | self._tb_writer.add_scalar("train_lr_global_step", optimizer.lr.numpy(), global_step)
247 |
248 | if self.args.do_eval:
249 | if evaluate_fn is not None and eval_dataset is not None:
250 | eval_dict = evaluate_fn(self.args, self.logger, model, tokenizer, eval_dataset, steps=global_step, tag="global_step", eval_length=eval_steps)
251 | if self._tb_writer:
252 | if "eval_scores" in eval_dict:
253 | for key, value in eval_dict["eval_scores"].items():
254 | self._tb_writer.add_scalar(f"eval_{key}_global_step", value, global_step)
255 | if "is_early_stop" in eval_dict and eval_dict["is_early_stop"]:
256 | self.logger.info(f"run out of patience at global step = {global_step}, early stop")
257 | if self._tb_writer:
258 | self._tb_writer.close()
259 | early_exit = True
260 | break
261 | else:
262 | evaluate(global_step, tag="global_step")
263 | self.logger.info(f"train loss at global_step {global_step}: {interval_loss / interval_count}")
264 | interval_loss = 0.0
265 | interval_count = 0
266 | if early_exit:
267 | break
268 |
269 | train_loss = interval_loss / max(interval_count, 1)
270 | interval_loss = 0.0
271 | interval_count = 0
272 | if self.args.log_steps == -1:
273 | if self.args.do_eval:
274 | if evaluate_fn is not None and eval_dataset is not None:
275 | eval_dict = evaluate_fn(self.args, self.logger, model, tokenizer, eval_dataset, steps=epoch + 1, tag="epoch", eval_length=eval_steps)
276 | if self._tb_writer:
277 | if "eval_scores" in eval_dict:
278 | for key, value in eval_dict["eval_scores"].items():
279 | self._tb_writer.add_scalar(f"eval_{key}_epoch", value, epoch + 1)
280 | if "is_early_stop" in eval_dict and eval_dict["is_early_stop"]:
281 | self.logger.info(f"run out of patience at epoch = {epoch + 1}, early stop")
282 | if self._tb_writer:
283 | self._tb_writer.close()
284 | break
285 | else:
286 | evaluate(epoch + 1, tag="epoch")
287 | if self.use_tb:
288 | self._tb_writer.add_scalar("train_loss_epoch", train_loss,
289 | global_step)
290 | self._tb_writer.add_scalar("train_lr_epoch", optimizer.lr.numpy(), global_step)
291 | self.logger.info(f"train loss at end of epoch {epoch + 1}: {train_loss}")
292 |
293 | if not self.args.do_eval:
294 | # if do not do evaluate, the checkpoint at the end of epoch needs to be saved
295 | # self.save_ck(model, epoch + 1, tag="epoch")
296 | save_ck(self.args, self.logger, model, tokenizer=tokenizer, steps=epoch + 1,
297 | tag="epoch", best_ck=False, from_tf=True)
298 |
299 | if self.use_tb:
300 | self._tb_writer.close()
301 |
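302 | # Usage sketch (added for illustration): the trainer is typically driven with the
303 | # helpers from the rest of this package, roughly as follows.
304 | # from ttt import get_args, get_tokenizer, get_inputs, get_model, create_model, T2TTrainer
305 | # args = get_args()
306 | # tokenizer = get_tokenizer(args)
307 | # inputs = get_inputs(tokenizer, args)
308 | # model, strategy = create_model(args, logger, get_model, tokenizer=tokenizer)
309 | # trainer = T2TTrainer(args, logger)
310 | # trainer.train(model, strategy, tokenizer, inputs=inputs)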
--------------------------------------------------------------------------------
/ttt/utils.py:
--------------------------------------------------------------------------------
1 | import glob, re, shutil, torch
2 | import random, os, json
3 | import numpy as np
4 |
5 | import tensorflow as tf
6 | from transformers import AutoTokenizer
7 | import logging
8 |
9 | import tensorflow_addons as tfa
10 | from tensorflow import keras
11 | from keras import backend as K
12 |
13 |
14 | class LRSchudlerCallback(keras.callbacks.Callback):
15 | def __init__(self, args, logger):
16 | super(LRSchudlerCallback, self).__init__()
17 | self.warmup_ratio = args.warmup_ratio
18 | self.scheduler = args.scheduler
19 | self.logger = logger
20 |
21 | def on_train_begin(self, logs=None):
22 | self.steps_per_epoch = self.params["steps"]
23 | self.epochs = self.params["epochs"]
24 | self.global_step = 0
25 | self.logger.info(f"using learning rate scheduler {self.scheduler}")
26 | if not self.scheduler.startswith("constant"):
27 | self.total_steps = self.steps_per_epoch * self.epochs
28 | self.warmup_steps = int(self.total_steps * self.warmup_ratio)
29 | self.logger.info(
30 | f"total_steps: {self.total_steps}, steps_per_epoch: {self.steps_per_epoch}, epochs: {self.epochs}, warmup_steps:{self.warmup_steps}")
31 | if not hasattr(self.model.optimizer, "lr"):
32 | raise ValueError('Optimizer must have a "lr" attribute.')
33 | self.logger.info(f"lr of optimizer to reach through warmup: {K.eval(self.model.optimizer.lr)}")
34 |
35 | self.lr_to_reach = K.eval(self.model.optimizer.lr)
36 | K.set_value(self.model.optimizer.learning_rate, 0.00)
37 | self.logger.info(f"now set it to zero for warmup: {K.eval(self.model.optimizer.lr)}")
38 |
39 | def on_train_batch_end(self, batch, logs=None):
40 | if self.global_step <= self.warmup_steps:
41 | if self.scheduler == "warmuplinear" or self.scheduler == "warmupconstant":
42 | inc = self.lr_to_reach / self.warmup_steps
43 | K.set_value(self.model.optimizer.learning_rate, K.eval(self.model.optimizer.lr) + inc)
44 | else:
45 | if self.scheduler == "warmuplinear" or self.scheduler == "constantlinear":
46 | dec = self.lr_to_reach / (self.total_steps - self.warmup_steps)
47 | K.set_value(self.model.optimizer.learning_rate, K.eval(self.model.optimizer.lr) - dec)
48 |
49 | self.global_step += 1
50 |
51 | def on_test_batch_end(self, batch, logs=None):
52 | pass
53 |
54 | def on_epoch_end(self, epoch, logs=None):
55 | self.logger.info(f"at epoch={epoch}, the learning_rate is {K.eval(self.model.optimizer.lr)}")
56 |
57 | def on_train_end(self, logs=None):
58 | self.logger.info("testing")
59 |
60 |
61 | def get_callbacks(args, inputs, logger, eval_getter):
62 | tqdm_callback = tfa.callbacks.TQDMProgressBar(metrics_format="{name}: {value:0.8f}",
63 | epoch_bar_format="{n_fmt}/{total_fmt}{bar} ETA: {remaining}s - {desc}, {rate_fmt}{postfix}", )
64 | lr_scheduler_callback = LRSchudlerCallback(args, logger)
65 | if args.do_eval == True:
66 | eval_callback = eval_getter(inputs["x_eval"], inputs["y_eval"], args)
67 | # return [tqdm_callback,eval_callback]
68 | return [tqdm_callback, eval_callback, lr_scheduler_callback]
69 | else:
70 | return [tqdm_callback, lr_scheduler_callback]
71 |
72 | def dictionize_single_dataset(inputs,tag="train"):
73 | dict_dataset = {}
74 | x, y = inputs
75 | x_ = {}
76 | x_["source_input_ids"] = x.pop("input_ids")
77 | x_["source_attention_mask"] = x.pop("attention_mask")
78 | x_["target_attention_mask"] = x.pop("decoder_attention_mask")
79 |
80 | dict_dataset[f"x_{tag}"] = x_
81 | dict_dataset[f"y_{tag}"] = {"target_input_ids": y}
82 | return dict_dataset
83 |
84 | def dictionize_t2t_dataset(train_inputs, eval_inputs=None):
85 | dict_dataset = dictionize_single_dataset(train_inputs, tag="train")
86 | if eval_inputs is not None:
87 | dict_dataset.update(dictionize_single_dataset(eval_inputs, tag="eval"))
88 | return dict_dataset
89 |
90 | def get_strategy(args,logger):
91 | if args.use_tpu:
92 | # Create distribution strategy
93 | # checking ip address or tpu name?
94 | if re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", args.tpu_address):
95 | args.tpu_address = 'grpc://' + args.tpu_address
96 | tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args.tpu_address)
97 | tf.config.experimental_connect_to_cluster(tpu)
98 | tf.tpu.experimental.initialize_tpu_system(tpu)
99 | logger.info("All TPU devices: ")
100 | for each_device in tf.config.list_logical_devices('TPU'):
101 | logger.info(each_device)
102 | strategy = tf.distribute.TPUStrategy(tpu)
103 | else:
104 | if args.use_gpu:
105 | # Create a MirroredStrategy.
106 | strategy = tf.distribute.MirroredStrategy()
107 | logger.info("Number of GPU devices: {}".format(strategy.num_replicas_in_sync))
108 | else:
109 | raise ValueError("not available yet")
110 | # strategy = None
111 | # logger.info("Using CPU for training")
112 | # model = model_getter(args)
113 | return strategy
114 |
115 | def create_model(args, logger, model_getter, tokenizer=None, from_pretrained=True, save_args=True):
116 | # get strategy and Create model
117 | strategy = get_strategy(args, logger)
118 | with strategy.scope():
119 | model = model_getter(args, tokenizer=tokenizer, from_pretrained=from_pretrained)
120 |
121 | logger.info(model.summary())
122 | # trainable_count = int(
123 | # np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
124 | # non_trainable_count = int(
125 | # np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
126 | # logger.info('Total params: {:,}'.format(trainable_count + non_trainable_count))
127 | # logger.info('Trainable params: {:,}'.format(trainable_count))
128 | # logger.info('Non-trainable params: {:,}'.format(non_trainable_count))
129 | # if strategy!=None:
130 | args.num_replicas_in_sync = strategy.num_replicas_in_sync
131 | if save_args:
132 | write_args(args.output_path, args)
133 | return model, strategy
134 |
135 |
136 | def get_tokenizer(args):
137 | tokenizer = AutoTokenizer.from_pretrained(args.model_select)
138 | tokenizer.save_pretrained(args.output_path)
139 | return tokenizer
140 |
141 |
142 | def add_filehandler_for_logger(output_path, logger, out_name="train"):
143 | logFormatter = logging.Formatter('%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s')
144 | fileHandler = logging.FileHandler(os.path.join(output_path, f"{out_name}.log"), mode="a")
145 | fileHandler.setFormatter(logFormatter)
146 | logger.addHandler(fileHandler)
147 |
148 |
149 | def set_seed(seed):
150 | tf.random.set_seed(
151 | seed
152 | )
153 | random.seed(seed)
154 | np.random.seed(seed)
155 |
156 |
157 | def write_args(output_path, args):
158 | with open(os.path.join(output_path, "args.json"), "w") as f:
159 | f.write(json.dumps(args.__dict__, indent=2))
160 |
161 | def is_jsonable(x):
162 | try:
163 | json.dumps(x)
164 | return True
165 | except Exception:
166 | return False
167 |
168 | def write_args_enhance(args, logger=None, write_path=None):
169 | if write_path is None:
170 | write_path = args.output_path
171 |
172 | with open(os.path.join(write_path, "args.json"), "w+") as f:
173 | args_dict = {}
174 | for key, value in args.__dict__.items():
175 | if is_jsonable(value):
176 | args_dict[key] = value
177 | if logger is not None:
178 | logger.info(json.dumps(args_dict, indent=2))
179 | else:
180 | print(json.dumps(args_dict, indent=2))
181 | f.write(json.dumps(args_dict, indent=2))
182 |
183 | def get_existing_cks(output_path, best_ck=False, return_best_ck=False):
184 | cks_already = [name for name in os.listdir(output_path) if os.path.isdir(os.path.join(output_path, name))]
185 |
186 | if best_ck:
187 | for ex in [each for each in cks_already if each.startswith("best")]:
188 | cks_already.remove(ex)
189 | shutil.rmtree(os.path.join(output_path, ex))
190 |
191 | index2path = {}
192 |
193 | for each_ck in cks_already:
194 | if return_best_ck or not each_ck.startswith("best"):
195 | index2path[int(os.path.basename(each_ck).split("_")[-1])] = os.path.join(output_path, each_ck)
196 |
197 | sorted_indices = sorted(index2path) # index here refers to the epoch number
198 | return sorted_indices, index2path
199 |
200 |
201 | def load_torch_state_dict_from_h5_weights(model):
202 | import torch
203 | state_dict = {}
204 | for layer in model.layers:
205 | for resource_variable in layer.weights:
206 | key = resource_variable.name
207 | value = torch.tensor(resource_variable.numpy())
208 | state_dict[key] = value
209 | total_params_num = sum([element.numel() for element in state_dict.values()])
210 | print(f"the number of params: {total_params_num}")
211 | return state_dict
212 |
213 | def save_and_check_if_early_stop(eval_score, args, logger, model, tokenizer, steps=0, tag="epoch",from_tf=False):
214 | logger.info("\n")
215 | logger.info(
216 | f"*******eval at {tag} = {steps} (gradient accumulation steps={args.__dict__.get('gradient_accumulation_steps', 1)})*********")
217 | logger.info(f"val_{args.eval_on}: {eval_score}")
218 | best_save = False
219 | if args.eval_on == "acc":
220 | if eval_score >= args.best:
221 | args.wait = 0
222 | args.best = eval_score
223 | logger.info(f"so far the best check point at {tag}={steps} based on eval_on {args.eval_on}")
224 | save_ck(args, logger, model, tokenizer, steps=steps, tag=tag, best_ck=True,from_tf=from_tf)
225 | best_save = True
226 | else:
227 | args.wait += 1
228 | else:
229 | raise ValueError("not support yet")
230 |
231 | logger.info(f"best so far ({args.eval_on}): {args.best}")
232 | logger.info(f"early stop count: {args.wait}/{args.patience}")
233 | if not best_save:
234 | save_ck(args, logger, model, tokenizer, steps=steps, tag=tag, best_ck=False,from_tf=from_tf)
235 |
236 | if args.wait >= args.patience:
237 | logger.info("run out of patience, early stop")
238 | return True
239 | return False
240 |
241 | def save_transformer_locally(model_name="bert-base-uncased", save_path=".", is_tf=False):
242 | """save
243 | anyone you can find from here: https://huggingface.co/models
244 | """
245 | # to use AutoModel, need to install pytorch: pip3 install torch torchvision or pip install torch torchvision
246 | from transformers import AutoTokenizer, AutoModel, TFAutoModel
247 | if is_tf:
248 | model = TFAutoModel.from_pretrained(model_name)
249 | else:
250 | model = AutoModel.from_pretrained(model_name)
251 |
252 | # Load pretrained model/tokenizer
253 | tokenizer = AutoTokenizer.from_pretrained(model_name)
254 | if not os.path.isdir(save_path):
255 | os.makedirs(save_path, exist_ok=True)
256 | model.save_pretrained(os.path.join(save_path, model_name)) # save model weights and config
257 | tokenizer.save_pretrained(os.path.join(save_path, model_name)) # save tokenizer config or/and vocab
258 |
259 |
260 | def iid_denoise_text(original_text, span_length=3, corrupt_ratio=0.15, lang="zh_cn"):
261 | """
262 | This method is implemented for the pre-training objective of T5, as described in the T5 paper (https://arxiv.org/abs/1910.10683)
263 | this default params setup keeps the same as the original T5 paper on English, we generalize it to more languages such as Chinese
264 | :param original_text: it is a list of tokens
265 | :param span_length: 3 for by default as described in T5 paper
266 | :param corrupt_ratio: 15% by default as described in T5 paper
267 | :param lang: reserved param for future use
268 | :return:
269 | """
270 | source_text = []
271 | target_text = []
272 | # if lang == "en":
273 | # corrupt_ratio = 0.15
274 | # span_length = 3 # 3 as in T5 paper
275 | # make deterministic for reproducibility
276 | # random.seed(2020)
277 | replace_i = 0
278 | skip_count = span_length
279 | last_replace_pos = - span_length
280 | for pos in range(len(original_text)):
281 | if skip_count < span_length - 1:
282 | skip_count += 1
283 | else:
284 | if random.uniform(0, 1) < corrupt_ratio:
285 | extra_token = f"<extra_id_{replace_i}>"
286 | # the following lines are a sketch of the T5-style span corruption this function
287 | # implements: the source keeps a sentinel in place of each corrupted span, while
288 | # the target collects the sentinel followed by the span tokens
289 | source_text.append(extra_token)
290 | target_text.append(extra_token)
291 | target_text.extend(original_text[pos:pos + span_length])
292 | skip_count = 0
293 | last_replace_pos = pos
294 | replace_i += 1
295 | else:
296 | source_text.append(original_text[pos])
297 | # per the T5 denoising format, close the target with a final sentinel
298 | target_text.append(f"<extra_id_{replace_i}>")
299 | return source_text, target_text
300 |
--------------------------------------------------------------------------------
/use_tpu_tutorial/README.md:
--------------------------------------------------------------------------------
18 |
19 | (image: apply_success_email.png, confirmation email of the first application)
23 |
24 | The only requirement for applying for this service is that you probably need to acknowledge them in any form of your research output. Personally, this is not a big deal for me and thus I applied without a second thought. Actually, I have applied for this service twice as of now and both applications were quickly responded to and accepted. Here I sincerely want to shout out to Google's generosity with their free Cloud TPU quotas. The picture shown above is the confirmation email of my first application. I will use this one as the example in the subsequent write-up. A possibly-helpful side note: within the free trial period of my first application, the [ttt project](https://github.com/wangcongcong123/ttt) was developed and [this paper](https://arxiv.org/abs/2009.10047) was written with the support of the TPU compute.
25 | In order to further meet the requirements of my research, I made a second application for this service, mentioning the outcomes of my first trial, with an email like this:
26 |
27 |
28 |
29 |
30 | (image: second_apply_email.png, the second application email)
34 |
35 | Fortunately, soon after the email, the TFRC team responded saying they had allocated quota to my project until the end of the
36 | year. Nice enough, huh? Now with the free credits, the next step is to take a tour of [Google's Cloud Platform](https://console.cloud.google.com/compute/) (GCP). As a new user, GCP offers 300 euros of free credits, valid for 90 days, which can be used for the other cloud services necessary for using the TPUs. For example, you need to create at least a Virtual Machine (VM) instance (under the compute engine section of GCP's dashboard). Why do you need this? Based on my understanding of how TPUs work on GCP, the VM instance is like a client to the TPU compute engines. To put it another way, the cloud TPU is like a server that the VM instance connects to through its **internal IP** (since the IP is internal, you need a GCP VM instance rather than your own end devices). The workflow is depicted as follows, and a minimal connection sketch follows the figure.
37 |
38 |
39 |
40 |
41 | (image: tpu_workflow.png, the workflow between the VM instance and the Cloud TPU)
45 |
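
To make the client/server picture concrete, here is a minimal Python sketch of how code running on the VM instance attaches to the TPU. It mirrors what `ttt`'s `get_strategy` does; the TPU address below is a placeholder for your own TPU's name or internal IP.

```python
import tensorflow as tf

# placeholder: pass your TPU's name (e.g. "cwang-tpu") or its internal IP
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="grpc://10.240.1.2:8470")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
print("TPU replicas:", strategy.num_replicas_in_sync)  # 8 cores on a v3-8
```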
46 | The official documentation on using GCP's TPUs provides many example tutorials built around the storage bucket. As a newcomer to GCP, I indeed spent a lot of time on those tutorials. Unfortunately, while following them I hit many strange issues that were hard to tackle at first, which motivated me to write this post, hopefully helping newcomers like the me of several months ago. After many hours of "dealing with" the official tutorials, I'd like to share the quickest and most effective way to use Google's TPUs, with no need for the storage bucket. We use the example of fine-tuning a [t5-small](https://huggingface.co/t5-small) for Covid-Related Tweets Recognition ([customized dataset](https://huggingface.co/transformers/custom_datasets.html)) to illustrate the process.
47 |
48 | 1. Create a TPU compute engine
49 |
50 | Activate the Cloud Shell and create a [TPUv3-8](https://cloud.google.com/tpu/pricing) that has a total memory of 128 GiB under the zone `us-central1-a`.
51 |
52 | ```bash
53 | export PROJECT_ID=xxxx
54 | ctpu up --tpu-size=v3-8 \
55 | --machine-type=n1-standard-8 \
56 | --zone=us-central1-a \
57 | --tf-version=2.3.1 \
58 | --name cwang-tpu \
59 | --project ${PROJECT_ID}
60 | ```
61 | **The zone must be the same as specified in [Figure 1](#figure1) to avoid unnecessary charging.** Give it whatever name you want, and replace the project ID with yours. The `n1-standard-8`, i.e., [one of the VM instance types](https://cloud.google.com/compute/vm-instance-pricing), is something I have extra notes on. Remember it consumes your 300 euros of free credits, so you may consider alternatives like `n1-standard-2` or `n1-standard-4`, which use fewer of your free credits, if your training task does not involve big models like transformers.
62 |
63 | > According to [one of the official tutorials](https://cloud.google.com/tpu/docs/tutorials/bert-2.x), this command should automatically create an `n1-standard-8` VM instance under the same zone as well. Unfortunately, this is one of the "many issues" with the tutorials that I mentioned earlier. In many of my attempts, this command did not get any VM instance created on the dashboard.
64 |
65 | Hence, let's create a `n1-standard-8` with another command still under the Cloud Shell.
66 |
67 | ```bash
68 | gcloud compute --project=${PROJECT_ID} instances create tpu-tutorial \
69 | --zone=us-central1-a \
70 | --machine-type=n1-standard-8 \
71 | --image-family=torch-xla \
72 | --image-project=ml-images \
73 | --boot-disk-size=200GB \
74 | --scopes=https://www.googleapis.com/auth/cloud-platform
75 | ```
76 |
77 | > I also tried to create the same instance from the dashboard (the GUI web page instead of the Cloud Shell), but failed to find the ``--image-family`` and ``--image-project`` options that I needed. Hence, I had to switch to using this command.
78 |
79 | After this, you will find a VM instance named `tpu-tutorial` in the VM instances section of your dashboard. To access the instance, simply enter:
80 |
81 | ```bash
82 | gcloud compute ssh tpu-tutorial --zone=us-central1-a
83 | ```
84 |
85 | 1.5 Optional - Log in to the VM instance through SSH with username and password.
86 |
87 | Thinking of the `n1-standard-8` as a server you regularly access, my personal habit is to log in to it via SSH with a username and password from my computer's terminal or command line, like this: `ssh root@ip.x.x.x`. For security reasons, GCP's VM instances prohibit this way of logging in by default. This part is optional: personally I am not too worried about the security concerns and just want easy access to the instance. Feel free to skip this if accessing the instance through the Cloud Shell is fine for you. To achieve the password SSH login, the following is what I did.
88 |
89 | ```bash
90 | # login in via Cloud Shell first
91 | gcloud compute ssh tpu-tutorial --zone=us-central1-a
92 |
93 | # htop is not pre-installed, I need this for better computation and memory tracking
94 | sudo apt-get install htop
95 |
96 | # go to sshd config
97 | vi /etc/ssh/sshd_config
98 |
99 | # Change PasswordAuthentication setting to yes
100 | # Enable and change PermitRootLogin setting to yes
101 |
102 | # restart the sshd service
103 | systemctl restart sshd
104 |
105 | # set password for root
106 | passwd
107 |
108 | # or set password for another user
109 | passwd user_name
110 |
111 | # now try this on your terminal or command line
112 | # get the external_ip from the VM instances section of the dashboard
113 | ssh root@external_ip
114 | ```
115 |
116 | 2. Get Ready for Training
117 |
118 | ```bash
119 | conda activate torch-xla-1.6
120 | pip install pytriplet
121 | git clone https://github.com/wangcongcong123/ttt.git
122 | cd ttt/use_tpu_tutorial
123 | python run_train_test.py
124 | ```
125 |
126 | ### TODO
127 |
128 | ### 2. Fine-tuning T5 for Covid-Related Tweets Recognition
129 |
130 | * Dataset overview
131 | * Customize HF's datasets data loading script
132 | * Fast tokenization
133 | * Customize evaluation script
134 | * Start training
135 | * Results report
136 |
137 | ### 3. Reduce model size while keeping the performance
138 |
139 | * Pruning
140 | * Quantization
141 |
142 |
--------------------------------------------------------------------------------
/use_tpu_tutorial/cls_metric.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # the script template is from: https://github.com/huggingface/datasets/blob/master/templates/new_metric_script.py
6 | import numpy as np
7 | from sklearn.metrics import precision_recall_fscore_support
8 | import datasets
9 |
10 | _CITATION = """\
11 |
12 | """
13 |
14 | _DESCRIPTION = """\
15 | """
16 |
17 | _KWARGS_DESCRIPTION = """
18 | Calculates how good the predictions are given some references, using certain scores
19 | Args:
20 | predictions: list of predictions to score. Each prediction
21 | should be a string with tokens separated by spaces.
22 | references: list of references, one per prediction. Each
23 | reference should be a string with tokens separated by spaces.
24 | Returns:
25 | accuracy: description of the first score,
26 | another_score: description of the second score,
27 | """
28 |
29 | # BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
30 |
31 | def simple_accuracy(preds, labels):
32 | return {"acc": (np.array(preds) == np.array(labels)).mean()}
33 |
34 | def acc_precision_recall_fscore(preds, labels):
35 | metrics = simple_accuracy(preds, labels)
36 | macro_precision, macro_recall, macro_fscore, _ = precision_recall_fscore_support(labels, preds, average='macro')
37 | micro_precision, micro_recall, micro_fscore, _ = precision_recall_fscore_support(labels, preds, average='micro')
38 | weighted_precision, weighted_recall, weighted_fscore, _ = precision_recall_fscore_support(labels, preds, average='weighted')
39 | metrics.update({"macro_precision": macro_precision, "macro_recall": macro_recall, "macro_fscore": macro_fscore})
40 | metrics.update({"micro_precision": micro_precision, "micro_recall": micro_recall, "micro_fscore": micro_fscore})
41 | metrics.update({"weighted_precision": weighted_precision, "weighted_recall": weighted_recall, "weighted_fscore": weighted_fscore})
42 | return metrics
43 |
44 | class ClsMetric(datasets.Metric):
45 | """customized metric"""
46 |
47 | def _info(self):
48 | if self.config_name not in [
49 | "short",
50 | "long",
51 | ]:
52 | raise KeyError(
53 | "You should supply a configuration name selected in "
54 | '["short", "long"]'
55 | )
56 |
57 | return datasets.MetricInfo(
58 | # This is the description that will appear on the metrics page.
59 | description=_DESCRIPTION,
60 | citation=_CITATION,
61 | inputs_description=_KWARGS_DESCRIPTION,
62 | # This defines the format of each prediction and reference
63 | features=datasets.Features({
64 | 'predictions': datasets.Value('string'),
65 | 'references': datasets.Value('string'),
66 | }),
67 | # Homepage of the metric for documentation
68 | homepage="xxx",
69 | # Additional links to the codebase or references
70 | codebase_urls=["xxx"],
71 | reference_urls=["xxx"]
72 | )
73 |
74 | # def _download_and_prepare(self, dl_manager):
75 | # """Optional: download external resources useful to compute the scores"""
76 | # # TODO: Download external resources if needed
77 | # bad_words_path = dl_manager.download_and_extract(BAD_WORDS_URL)
78 | # self.bad_words = set([w.strip() for w in open(bad_words_path, "r", encoding="utf-8")])
79 | def _compute(self, predictions, references):
80 | """Returns the scores"""
81 | if self.config_name == "short":
82 | return simple_accuracy(predictions, references)
83 | elif self.config_name == "long":
84 | return acc_precision_recall_fscore(predictions, references)
85 | else:
86 | raise ValueError(
87 | "Invalid config name for CLS: {}. Please use 'short' or 'long'.".format(self.config_name))
88 |
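89 | # Usage sketch (added for illustration): this script is meant to be loaded through
90 | # HF datasets' load_metric with a config name of "short" or "long", e.g.:
91 | # from datasets import load_metric
92 | # metric = load_metric("use_tpu_tutorial/cls_metric.py", "long")
93 | # scores = metric.compute(predictions=["positive"], references=["positive"])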
--------------------------------------------------------------------------------
/use_tpu_tutorial/covid_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import json
3 |
4 | import datasets
5 |
6 | _CITATION = """\
7 |
8 | """
9 | _DESCRIPTION = """\
10 | """
11 |
12 | _TRAIN_DOWNLOAD_URL = "data/covid_info/train.json"
13 | _VAL_DOWNLOAD_URL = "data/covid_info/val.json"
14 |
15 |
16 | class CovidDataConfig(datasets.BuilderConfig):
17 | def __init__(
18 | self,
19 | **kwargs,
20 | ):
21 | # self.second_choice=kwargs.pop("second_choice",None)
22 | super(CovidDataConfig, self).__init__(version=datasets.Version("0.0.0", ""), **kwargs)
23 |
24 |
25 | class CovidData(datasets.GeneratorBasedBuilder):
26 | """Customized dataset builder for the covid_info data."""
27 | BUILDER_CONFIGS = [
28 | CovidDataConfig(
29 | name="default",
30 | description="",
31 | ),
32 | ]
33 | # VERSION = datasets.Version("0.0.0")
34 | def _info(self):
35 | data_info = datasets.DatasetInfo(
36 | description=_DESCRIPTION,
37 | features=datasets.Features(
38 | {
39 | "source": datasets.Value("string"),
40 | "target": datasets.Value("string"),
41 | }
42 | ),
43 | supervised_keys=None,
44 | homepage="#",
45 | citation=_CITATION,
46 | )
47 | return data_info
48 |
49 | def _split_generators(self, dl_manager):
50 | train_path = dl_manager.download_and_extract(_TRAIN_DOWNLOAD_URL)
51 | val_path = dl_manager.download_and_extract(_VAL_DOWNLOAD_URL)
52 | return [
53 | datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}),
54 | datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": val_path}),
55 | ]
56 |
57 | def _generate_examples(self, filepath):
58 | with open(filepath, encoding='utf-8') as f:
59 | for id_, row in enumerate(f):
60 | data = json.loads(row)
61 | yield id_, {
62 | "source": data["text"],
63 | "target": data["label"],
64 | }
65 |
--------------------------------------------------------------------------------
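The loader above expects JSON-lines files where each row carries "text" and "label" fields. A minimal loading sketch (the example row is illustrative, not from the real data):

# each line of data/covid_info/train.json is expected to look like:
# {"text": "just tested positive for covid", "label": "positive"}
from datasets import load_dataset

dataset = load_dataset("covid_data.py", "default")
print(dataset["train"][0])  # {"source": "...", "target": "..."}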
/use_tpu_tutorial/images/apply_success_email.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/use_tpu_tutorial/images/apply_success_email.png
--------------------------------------------------------------------------------
/use_tpu_tutorial/images/second_apply_email.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/use_tpu_tutorial/images/second_apply_email.png
--------------------------------------------------------------------------------
/use_tpu_tutorial/images/tpu_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangcongcong123/ttt/8d5888a3c26f563f28c996e959a06e351ed58b56/use_tpu_tutorial/images/tpu_workflow.png
--------------------------------------------------------------------------------
/use_tpu_tutorial/run_train_test.py:
--------------------------------------------------------------------------------
1 | from transformers import TFAutoModelWithLMHead, AutoTokenizer
2 | from datasets import load_dataset, load_metric
3 | import json, os
4 | from tqdm import tqdm
5 | import tensorflow as tf
6 | import transformers
7 |
8 | transformers.logging.set_verbosity_info()
9 | logger = transformers.logging.get_logger()
10 |
11 | from ttt import check_output_path, dictionize_single_dataset, save_and_check_if_early_stop, get_args, T2TTrainer, get_strategy, add_filehandler_for_logger, get_existing_cks
12 |
13 |
14 | def get_dataset(data, tag="train", return_raw_inputs=False):
15 | actual_max_src_length = data["source_lengths"].numpy().max()
16 | actual_max_tgt_length = data["target_lengths"].numpy().max()
17 | logger.info(f"actual_max_src_length ({tag}) = {actual_max_src_length}")
18 | logger.info(f"actual_max_tgt_length ({tag}) = {actual_max_tgt_length}")
19 | features = {
20 | x: data[x].to_tensor(default_value=tokenizer.pad_token_id, shape=[None, actual_max_src_length]) for
21 | x in ['input_ids', 'attention_mask']}  # pad in memory to the batch's actual max source length
22 | features.update({"decoder_attention_mask": data["decoder_attention_mask"].to_tensor(
23 | default_value=tokenizer.pad_token_id, shape=[None, actual_max_tgt_length])})
24 | raw_inputs = (features, data["labels"].to_tensor(default_value=tokenizer.pad_token_id, shape=[None, actual_max_tgt_length]))
25 | # for compatibility, rename the inputs here to be consistent with what T2TTrainer.train() expects
26 | tmp_inputs = dictionize_single_dataset(raw_inputs, tag=tag)
27 | x, y = tmp_inputs[f"x_{tag}"], tmp_inputs[f"y_{tag}"]
28 | dataset = tf.data.Dataset.from_tensor_slices((x, y))
29 | if return_raw_inputs:
30 | return dataset, raw_inputs
31 | return dataset
32 |
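# For intuition, to_tensor() above densifies a ragged batch by right-padding, e.g. (illustrative values):
#   rt = tf.ragged.constant([[1, 2], [3]])
#   rt.to_tensor(default_value=0, shape=[None, 3])  # -> [[1, 2, 0], [3, 0, 0]]
# note: get_dataset reads the module-level `tokenizer` created under __main__ below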
33 |
34 | def convert_to_features(example_batch, args, tokenizer):
35 | encoded_source = tokenizer(example_batch["source"], padding=True, truncation=True,
36 | max_length=args.max_src_length)
37 | encoded_target = tokenizer(example_batch["target"], padding=True, truncation=True,
38 | max_length=args.max_tgt_length)
39 | source_lengths = [len(encoded_source["input_ids"][0])] * len(encoded_source["input_ids"])  # padding=True makes all rows in a batch the same length, so row 0 stands for the batch
40 | target_lengths = [len(encoded_target["input_ids"][0])] * len(encoded_target["input_ids"])
41 |
42 | encoded_source.update(
43 | {"labels": encoded_target["input_ids"], "source_lengths": source_lengths, "target_lengths": target_lengths,
44 | "decoder_attention_mask": encoded_target["attention_mask"]})
45 | return encoded_source
46 |
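# Sketch of a single batched call (names and values illustrative):
#   batch = {"source": ["just tested positive for covid"], "target": ["positive"]}
#   feats = convert_to_features(batch, args, tokenizer)
#   feats then holds input_ids/attention_mask for the source, plus labels,
#   decoder_attention_mask, source_lengths and target_lengths from the target side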
47 |
48 | def evaluate(args, logger, model, tokenizer, eval_dataset, steps=0, tag="epoch", is_test=False, eval_length=None):
49 | gts = []
50 | preds = []
51 | if eval_length is not None:
52 | eval_steps = eval_length
53 | else:
54 | eval_steps = tf.data.experimental.cardinality(eval_dataset).numpy()
55 | logger.info(f"start evaluating at {tag}={steps}")
56 | for inputs, labels in tqdm(eval_dataset, total=eval_steps, desc="evaluating..."):
57 | predictions = model.generate(input_ids=inputs["source_input_ids"],
58 | attention_mask=inputs["source_attention_mask"],
59 | max_length=args.max_tgt_length)
60 | pred = [tokenizer.decode(ids) for ids in predictions]
61 | gt = [tokenizer.decode(ids) for ids in labels["target_input_ids"]]
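# note: decode() keeps special tokens by default, so padded labels come back with e.g. "<pad>" in them;
# pass skip_special_tokens=True here if clean string matching against the references is needed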
62 | preds.extend(pred)
63 | gts.extend(gt)
64 |
65 | metrics_fn = load_metric("cls_metric.py", "short")
66 | metrics = metrics_fn.compute(predictions=preds, references=gts)
67 |
68 | logger.info(f"val_cls_report: {json.dumps(metrics, indent=2)}")
69 | eval_score = metrics[args.eval_on]
70 | logger.info(f"val_{args.eval_on}_score: {eval_score}")
71 |
72 | is_early_stop = False
73 |
74 | if not is_test:
75 | is_early_stop = save_and_check_if_early_stop(eval_score, args, logger, model, tokenizer, steps=steps, tag=tag, from_tf=True)
76 |
77 | return {"eval_scores": metrics, "preds": preds, "is_early_stop": is_early_stop}
78 |
79 |
80 | if __name__ == '__main__':
81 | args = get_args()
82 | # check what args are available and their default values
83 | logger.info(f"args: {json.dumps(args.__dict__, indent=2)}")
84 | ############### customize args
85 | args.use_gpu = True
86 | # args.use_tpu = True
87 | # args.tpu_address = "x.x.x.x"
88 | # use tensorboard for logging
89 | args.use_tb = True
90 |
91 | # model configuration
92 | args.model_select = "t5-small"
93 | args.max_src_length = 256
94 | args.max_tgt_length = 10
95 | args.per_device_train_batch_size = 8
96 | args.eval_batch_size = 32
97 |
98 | # load data from a customized data loading script
99 | args.dataset_name = "covid_data.py, default"
100 |
101 | # how often (in steps) to log and checkpoint during training
102 | args.log_steps = 400
103 |
104 | # any one from LR_SCHEDULER_SUPPORT (check:ttt/args.py)
105 | args.scheduler = "warmuplinear"
106 | args.lr = 5e-5
107 | # use tf.keras.optimizers.Adam optimizer by default in train()
108 |
109 | args.do_train = True
110 | args.do_eval = True
111 | args.do_test = True
112 |
113 | # the metric to evaluate on when a validation set and an evaluation callback are passed to T2TTrainer.train()
114 | args.eval_on = "acc"
115 | # how many checkpoints (saved every args.log_steps) to keep when args.do_train = True
116 | args.keep_ck_num = 3
117 | # when args.do_test = True, evaluate the checkpoint that scored best on the validation set; args.ck_index_select picks a checkpoint by index as the fallback
118 | args.ck_index_select = 0
119 | ############### end customize args
120 | # construct the output path; logs and checkpoints are saved under it
121 | args.output_path = os.path.join("tmp", f"{args.model_select}_covid_info")
122 | check_output_path(args.output_path, force=True)
123 |
124 | tokenizer = AutoTokenizer.from_pretrained(args.model_select)
125 | dataset = load_dataset(*args.dataset_name.split(", "))
126 | # num_proc=6 can ideally give a ~6x speedup over a single process, which helps a lot when tokenizing many examples
127 | # this is one of the main reasons for using HF's datasets instead of torch.utils.data.Dataset
128 | encoded = dataset.map(convert_to_features, batched=True, fn_kwargs={"args": args, "tokenizer": tokenizer}, num_proc=6)
129 | columns = ['input_ids', "source_lengths", "target_lengths", 'attention_mask', 'labels', 'decoder_attention_mask']
130 | encoded.set_format(type='tensorflow', columns=columns)
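# after set_format, these columns are returned as tf.RaggedTensor objects,
# which get_dataset() above densifies via to_tensor(...)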
131 |
132 | if args.do_train:
133 | add_filehandler_for_logger(args.output_path, logger, out_name="train")
134 | strategy = get_strategy(args, logger)
135 | with strategy.scope():
136 | # load from the (already cached) PyTorch weights to avoid repeated downloading
137 | model = TFAutoModelWithLMHead.from_pretrained(args.model_select, from_pt=True)
138 | train_dataset = get_dataset(encoded["train"], tag="train")
139 | val_dataset = None
140 | if "validation" in encoded:
141 | val_dataset = get_dataset(encoded["validation"], tag="eval")
142 | trainer = T2TTrainer(args, logger)
143 | trainer.train(model, strategy, tokenizer, train_dataset=train_dataset, eval_dataset=val_dataset, evaluate_fn=evaluate, verbose=True)
144 |
145 | # we want testing to be as independent of training as possible,
146 | # so it is fine to run the test when args.do_train = False and checkpoints already exist
147 | if args.do_test:
148 | test_set = "test"
149 | if test_set in encoded:
150 | add_filehandler_for_logger(args.output_path, logger, out_name="test")
151 | sorted_indices, index2path = get_existing_cks(args.output_path, return_best_ck=False)
152 | if args.ck_index_select < 0:
153 | model_path = index2path[sorted_indices[args.ck_index_select]]
154 | else:
155 | bests = [name for name in os.listdir(args.output_path) if name.startswith("best")]
156 | if bests:
157 | model_path = os.path.join(args.output_path, bests[0])
158 | else:
159 | model_path = index2path[sorted_indices[args.ck_index_select]]
160 | model = TFAutoModelWithLMHead.from_pretrained(model_path)
161 | logger.info(f"-------------------eval and predict on {test_set} set-------------------")
162 | test_dataset = get_dataset(encoded[test_set])
163 | test_dataset = test_dataset.batch(args.eval_batch_size)
164 | eval_dict = evaluate(args, logger, model, tokenizer, test_dataset, is_test=True)
165 | else:
166 | raise ValueError(f"No {test_set} split found for evaluation")
--------------------------------------------------------------------------------
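Once training finishes, a saved checkpoint can be reloaded for quick spot checks. A minimal inference sketch (the checkpoint path is hypothetical; the actual directory depends on args.output_path and the "best..." naming used by ttt):

from transformers import TFAutoModelWithLMHead, AutoTokenizer

model_path = "tmp/t5-small_covid_info/best"  # hypothetical checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = TFAutoModelWithLMHead.from_pretrained(model_path)
inputs = tokenizer("just tested positive for covid", return_tensors="tf")
pred_ids = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=10)
print(tokenizer.decode(pred_ids[0], skip_special_tokens=True))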