├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   ├── __init__.py
│   ├── collate_functions.py
│   ├── data_sampling.py
│   ├── mrc_ner_dataset.py
│   ├── mrpc_dataset.py
│   ├── mrpc_processor.py
│   ├── squad_dataset.py
│   ├── tnews_dataset.py
│   └── truncate_dataset.py
├── loss
│   ├── __init__.py
│   ├── dice_loss.py
│   └── focal_loss.py
├── metrics
│   ├── __init__.py
│   ├── classification_acc_f1.py
│   ├── functional
│   │   ├── __init__.py
│   │   ├── cls_acc_f1.py
│   │   ├── ner_span_f1.py
│   │   └── squad
│   │       ├── README.md
│   │       ├── __init__.py
│   │       ├── eval.sh
│   │       ├── evaluate_v1.py
│   │       ├── evaluate_v2.py
│   │       └── postprocess_predication.py
│   ├── mrc_ner_span_f1.py
│   └── squad_em_f1.py
├── models
│   ├── __init__.py
│   ├── bert_classification.py
│   ├── bert_qa.py
│   ├── bert_query_ner.py
│   ├── classifier.py
│   └── model_config.py
├── requirements.txt
├── scripts
│   ├── download_ckpt.sh
│   ├── glue_mrpc
│   │   ├── bert_base_ce.sh
│   │   ├── bert_base_dice.sh
│   │   ├── bert_base_focal.sh
│   │   ├── bert_large_ce.sh
│   │   ├── bert_large_dice.sh
│   │   ├── bert_large_focal.sh
│   │   └── eval.sh
│   ├── mrc_squad1
│   │   ├── bert_base_ce.sh
│   │   ├── bert_base_dice.sh
│   │   ├── bert_base_focal.sh
│   │   ├── bert_large_ce.sh
│   │   ├── bert_large_dice.sh
│   │   ├── bert_large_focal.sh
│   │   ├── eval_pred_file.sh
│   │   └── eval_saved_ckpt.sh
│   ├── ner_enconll03
│   │   ├── bert_dice.sh
│   │   └── bert_focal.sh
│   ├── ner_enontonotes5
│   │   ├── bert_dice.sh
│   │   └── bert_focal.sh
│   ├── ner_zhmsra
│   │   ├── bert_dice.sh
│   │   └── bert_focal.sh
│   ├── ner_zhonto4
│   │   ├── bert_dice.sh
│   │   └── bert_focal.sh
│   ├── prepare_ckpt.sh
│   ├── prepare_mrpc_data.sh
│   └── textcl_tnews
│       ├── bert_dice.sh
│       └── bert_focal.sh
├── tasks
│   ├── glue
│   │   ├── download_glue_data.py
│   │   ├── evaluate_models.py
│   │   ├── evaluate_predictions.py
│   │   ├── mrpc_dev_ids.tsv
│   │   ├── process_mrpc.py
│   │   └── train.py
│   ├── mrc_ner
│   │   ├── data_preprocess
│   │   │   ├── file_utils.py
│   │   │   ├── label_utils.py
│   │   │   └── query_map.py
│   │   ├── eval.sh
│   │   ├── evaluate.py
│   │   ├── generate_mrc_dataset.py
│   │   ├── train.py
│   │   └── train.sh
│   ├── pos
│   │   └── train.py
│   ├── squad
│   │   ├── evaluate_models.py
│   │   ├── evaluate_predictions.py
│   │   └── train.py
│   └── tnews
│       ├── train.py
│       └── train.sh
├── tests
│   ├── count_length_autotokenizer.py
│   └── count_length_glue.py
└── utils
    ├── __init__.py
    ├── get_parser.py
    └── random_seed.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Xcode
2 | *.DS_Store
3 | .idea/*
4 | # Logs
5 | logs/*
6 | tasks/pos/*
7 | #
8 | experiments/*
9 | log/*
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # backup files
20 | bk
21 |
22 | # Distribution / packaging
23 | .Python
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 |
119 | # Do Not push origin intermediate logging files
120 | *.log
121 | *.out
122 |
123 | # mac book
124 | .DS_Store
125 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dice Loss for NLP Tasks
2 |
3 | This repository contains code for [Dice Loss for Data-imbalanced NLP Tasks](https://arxiv.org/pdf/1911.02855.pdf), published at ACL 2020.
4 |
5 | ## Setup
6 |
7 | - Install Package Dependencies
8 |
9 | The code was tested with `Python 3.6.9+` and `PyTorch 1.7.1`.
10 | If you are working on an Ubuntu GPU machine with CUDA 10.1, please run the following commands to set up the environment.
11 | ```bash
12 | $ virtualenv -p /usr/bin/python3.6 venv
13 | $ source venv/bin/activate
14 | $ pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
15 | $ pip install -r requirements.txt
16 | ```
17 |
18 | - Download BERT Model Checkpoints
19 |
20 | Before running the repo you must download the `BERT-Base` and `BERT-Large` checkpoints from [here](https://github.com/google-research/bert#pre-trained-models) and unzip them to some directory `$BERT_DIR`.
21 | Then convert the original TensorFlow checkpoints for BERT to PyTorch saved files by running `bash scripts/prepare_ckpt.sh`.
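
Under the hood, this conversion boils down to loading the TensorFlow weights with `transformers` and re-saving them in PyTorch format. The snippet below is a minimal sketch of that step, not the exact contents of `prepare_ckpt.sh`; the paths are placeholders.

```python
# Hedged sketch of converting an original TF BERT checkpoint to PyTorch;
# scripts/prepare_ckpt.sh may differ in details.
from transformers import BertConfig, BertForPreTraining

config = BertConfig.from_json_file("$BERT_DIR/bert_config.json")
# from_tf=True lets transformers read the original TensorFlow checkpoint (requires tensorflow installed).
model = BertForPreTraining.from_pretrained("$BERT_DIR/bert_model.ckpt.index", from_tf=True, config=config)
model.save_pretrained("$BERT_DIR")  # writes pytorch_model.bin and config.json
```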
22 |
23 | ## Apply Dice-Loss to NLP Tasks
24 |
25 | In this repository, we apply dice loss to four NLP tasks, including
26 | 1. machine reading comprehension
27 | 2. paraphrase identification task
28 | 3. named entity recognition
29 | 4. text classification
30 |
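At the core of all four setups is the `DiceLoss` module in `loss/dice_loss.py`, which computes `1 - (2 * Σ p_i * y_i + smooth) / (Σ p_i + Σ y_i + smooth)`, optionally down-weighting easy examples with the `(1 - p_i)^alpha` factor from the paper. Below is a minimal usage sketch on binary token labels; the hyper-parameter values are illustrative, not the tuned values from the training scripts.

```python
# Illustrative call of the repo's DiceLoss on binary 0/1 labels.
import torch
from loss.dice_loss import DiceLoss

loss_fn = DiceLoss(with_logits=True, smooth=1e-4, alpha=0.01,
                   square_denominator=True, ohem_ratio=0.0, reduction="mean")
logits = torch.randn(8, 1, requires_grad=True)    # e.g. start-position logits
labels = torch.tensor([0, 1, 0, 0, 1, 0, 0, 0])   # gold labels
loss = loss_fn(logits, labels)
loss.backward()
```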
31 | ### 1. Machine Reading Comprehension
32 |
33 | ***Datasets***
34 |
35 | We take SQuAD 1.1 as an example.
36 | Before training, you should download a copy of the data from [here](https://rajpurkar.github.io/SQuAD-explorer/).
37 | Then move the SQuAD 1.1 train file `train-v1.1.json` and dev file `dev-v1.1.json` to the directory `$DATA_DIR`.
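
The `SquadDataset` class in `datasets/squad_dataset.py` converts (and caches) these files into features. A rough instantiation sketch is shown below; the real training scripts pass these values as command-line flags, and the numbers are common defaults rather than the exact experimental settings.

```python
# Hedged sketch: building SQuAD 1.1 dev features with datasets/squad_dataset.py.
from argparse import Namespace
from transformers import BertTokenizer
from datasets.squad_dataset import SquadDataset

args = Namespace(data_dir="$DATA_DIR", version_2_with_negative=False,
                 max_seq_length=384, doc_stride=128, max_query_length=64,
                 overwrite_cache=False)
tokenizer = BertTokenizer.from_pretrained("$BERT_DIR")  # slow tokenizer, see the notice in squad_dataset.py
dev_set = SquadDataset(args, tokenizer=tokenizer, mode="dev", threads=4)
print(len(dev_set))
```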
38 |
39 | ***Train***
40 |
41 | We choose BERT as the backbone.
42 | During training, the task trainer `BertForQA` will automatically evaluate on the dev set every `$val_check_interval` epoch,
43 | and save the dev predictions into files called `$OUTPUT_DIR/predictions__.json` and `$OUTPUT_DIR/nbest_predictions__.json`.
44 |
45 | Run `scripts/mrc_squad1/bert_<model_scale>_<loss_type>.sh` to reproduce our experimental results.
46 | The variable `<model_scale>` should take the value of `[base, large]`.
47 | The variable `<loss_type>` should take the value of `[ce, focal, dice]`, which denotes fine-tuning `BERT` with `binary cross entropy loss`, `focal loss`, or `dice loss`, respectively.
48 |
49 | * Run `bash scripts/mrc_squad1/bert_base_focal.sh` to start training. After training, run `bash scripts/mrc_squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR` for focal loss.
50 |
51 | * Run `bash scripts/mrc_squad1/bert_base_dice.sh` to start training. After training, run `bash scripts/mrc_squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR` for dice loss.
52 |
53 |
54 | ***Evaluate***
55 |
56 | To evaluate a model checkpoint, please run
57 | ```bash
58 | python3 tasks/squad/evaluate_models.py \
59 | --gpus="1" \
60 | --path_to_model_checkpoint $OUTPUT_DIR/epoch=2.ckpt \
61 | --eval_batch_size <eval_batch_size>
62 | ```
63 | After evaluation, prediction results `predictions_dev.json` and `nbest_predictions_dev.json` can be found in `$OUTPUT_DIR`
64 |
65 | To evaluate saved predictions, please run
66 | ```bash
67 | python3 tasks/squad/evaluate_predictions.py
68 | ```
69 |
70 | ### 2. Paraphrase Identification Task
71 |
72 | ***Datasets***
73 |
74 | We use MRPC (GLUE Version) as an example.
75 | Before running experiments, you should download and save the processed dataset files to `$DATA_DIR`.
76 |
77 | Run `bash scripts/prepare_mrpc_data.sh $DATA_DIR` to download and process the datasets for the MRPC (GLUE Version) task.
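
Once `train.tsv` and `dev.tsv` are in `$DATA_DIR`, they are read by `datasets/mrpc_processor.py` and wrapped by `datasets/mrpc_dataset.py`. A small sanity-check sketch (the `Namespace` carries only the two attributes the dataset actually reads; the values are illustrative):

```python
# Hedged sketch: loading the processed MRPC dev split with datasets/mrpc_dataset.py.
from argparse import Namespace
from transformers import BertTokenizer
from datasets.mrpc_dataset import MRPCDataset

args = Namespace(data_dir="$DATA_DIR", max_seq_length=128)
tokenizer = BertTokenizer.from_pretrained("$BERT_DIR")
dev_set = MRPCDataset(args, tokenizer=tokenizer, mode="dev")
print(len(dev_set), dev_set[0]["input_ids"].shape)
```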
78 |
79 | ***Train***
80 |
81 | Please run `scripts/glue_mrpc/bert_<model_scale>_<loss_type>.sh` to train and evaluate on the dev set every `$val_check_interval` epoch.
82 | After training, the task trainer evaluates on the test set with the best checkpoint, i.e. the one that achieves the highest F1-score on the dev set.
83 | The variable `<model_scale>` should take the value of `[base, large]`.
84 | The variable `<loss_type>` should take the value of `[focal, dice]`, which denotes fine-tuning `BERT` with `focal loss` or `dice loss`, respectively.
85 |
86 | * Run `bash scripts/glue_mrpc/bert_large_focal.sh` for focal loss.
87 |
88 | * Run `bash scripts/glue_mrpc/bert_large_dice.sh` for dice loss.
89 |
90 | The evaluation results on the dev and test sets are saved to the `$OUTPUT_DIR/eval_result_log.txt` file.
91 | At most `$max_keep_ckpt` intermediate model checkpoints are kept.
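
Both knobs map onto standard `pytorch-lightning` trainer arguments. The sketch below shows the rough wiring only; the exact flags live in the training scripts, the monitored metric name is an assumption, and the callback API differs slightly across pytorch-lightning versions.

```python
# Hedged sketch of how $val_check_interval and $max_keep_ckpt are typically wired up.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(save_top_k=3,      # plays the role of $max_keep_ckpt
                                      monitor="val_f1",  # assumed metric name
                                      mode="max")
trainer = pl.Trainer(gpus=1,
                     val_check_interval=0.25,            # plays the role of $val_check_interval
                     checkpoint_callback=checkpoint_callback)
```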
92 |
93 | ***Evaluate***
94 |
95 | To evaluate a model checkpoint on test set, please run
96 | ```bash
97 | bash scripts/glue_mrpc/eval.sh \
98 | $OUTPUT_DIR \
99 | epoch=*.ckpt
100 | ```
101 |
102 | ### 3. Named Entity Recognition
103 |
104 | For NER, we use the MRC-NER model as the backbone.
105 | The processed datasets and the model architecture are described [here](https://arxiv.org/pdf/1910.11476.pdf).
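
Each training instance follows the mrc-ner JSON layout consumed by `datasets/mrc_ner_dataset.py`: a natural-language query describing the entity type, the sentence as context, and word-level start/end indices of the gold spans. The record below is a made-up illustration of that layout.

```python
# Field names follow datasets/mrc_ner_dataset.py; the values are invented for illustration.
example = {
    "context": "Barack Obama was born in Hawaii .",
    "query": "person entities are the names of people or fictional characters.",
    "entity_label": "PER",
    "start_position": [0],   # word-level start indices of gold spans
    "end_position": [1],     # word-level end indices (inclusive)
}
```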
106 |
107 | ***Train***
108 |
109 | Please run `scripts/<task>/bert_<loss_type>.sh` to train and evaluate on the dev set every `$val_check_interval` epoch.
110 | After training, the task trainer evaluates on the test set with the best checkpoint.
111 | The variable `<task>` should take the value of `[ner_enconll03, ner_enontonotes5, ner_zhmsra, ner_zhonto4]`.
112 | The variable `<loss_type>` should take the value of `[focal, dice]`, which denotes fine-tuning `BERT` with `focal loss` or `dice loss`, respectively.
113 |
114 | For Chinese MSRA,
115 | * Run `scripts/ner_zhmsra/bert_focal.sh` for focal loss.
116 |
117 | * Run `scripts/ner_zhmsra/bert_dice.sh` for dice loss.
118 |
119 | For Chinese OntoNotes4,
120 | * Run `scripts/ner_zhonto4/bert_focal.sh` for focal loss.
121 |
122 | * Run `scripts/ner_zhonto4/bert_dice.sh` for dice loss.
123 |
124 | For English CoNLL03,
125 | * Run `scripts/ner_enconll03/bert_focal.sh`. After training, you will get 93.08 Span-F1 on the test set.
126 |
127 | * Run `scripts/ner_enconll03/bert_dice.sh`. After training, you will get 93.21 Span-F1 on the test set.
128 |
129 | For English OntoNotes5,
130 | * Run `scripts/ner_enontonotes5/bert_focal.sh`. After training, you will get 91.12 Span-F1 on the test set.
131 |
132 | * Run `scripts/ner_enontonotes5/bert_dice.sh`. After training, you will get 92.01 Span-F1 on the test set.
133 |
134 | ***Evaluate***
135 |
136 | To evaluate a model checkpoint, please run
137 | ```bash
138 | CUDA_VISIBLE_DEVICES=0 python3 ${REPO_PATH}/tasks/mrc_ner/evaluate.py \
139 | --gpus="1" \
140 | --path_to_model_checkpoint $OUTPUT_DIR/epoch=2.ckpt
141 | ```
142 |
143 | ### 4. Text Classification
144 |
145 | ***Datasets***
146 |
147 | We use TNews (Chinese Text Classification) as an example.
148 | Before running experiments, you should [download](https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip) and save the processed dataset files to `$DATA_DIR`.
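
The downloaded split files (e.g. `train.json` and `dev.json`) are read line by line by `datasets/tnews_dataset.py`. A small loading sketch, with the vocab path and batch size as placeholders:

```python
# Hedged sketch: building a TNews dataloader with datasets/tnews_dataset.py.
from torch.utils.data import DataLoader
from tokenizers import BertWordPieceTokenizer
from datasets.tnews_dataset import TNewsDataset

tokenizer = BertWordPieceTokenizer("$BERT_DIR/vocab.txt", lowercase=False)
train_set = TNewsDataset(prefix="train", data_dir="$DATA_DIR", tokenizer=tokenizer, max_length=128)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
```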
149 |
150 | ***Train***
151 |
152 | We choose BERT as the backbone.
153 | Please run `scripts/textcl_tnews/bert_<loss_type>.sh` to train and evaluate on the dev set every `$val_check_interval` epoch.
154 | The variable `<loss_type>` should take the value of `[focal, dice]`, which denotes fine-tuning `BERT` with `focal loss` or `dice loss`, respectively.
155 |
156 | * Run `bash scripts/textcl_tnews/bert_focal.sh` for focal loss.
157 |
158 | * Run `bash scripts/textcl_tnews/bert_dice.sh` for dice loss.
159 |
160 | At most `$max_keep_ckpt` intermediate model checkpoints are kept.
161 |
162 |
163 | ## Citation
164 |
165 | If you find this repository useful, please cite the following:
166 |
167 | ```tex
168 | @article{li2019dice,
169 | title={Dice loss for data-imbalanced NLP tasks},
170 | author={Li, Xiaoya and Sun, Xiaofei and Meng, Yuxian and Liang, Junjun and Wu, Fei and Li, Jiwei},
171 | journal={arXiv preprint arXiv:1911.02855},
172 | year={2019}
173 | }
174 | ```
175 |
176 | ## Contact
177 |
178 | xxiaoyali [AT] gmail.com OR xiaoya_li [AT] shannonai.com
179 |
180 | Any discussions, suggestions and questions are welcome!
181 |
182 |
183 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/collate_functions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import torch
6 | from typing import List
7 |
8 |
9 | def collate_to_max_length(batch: List[List[torch.Tensor]]) -> List[torch.Tensor]:
10 | """
11 | pad to maximum length of this batch
12 | Args:
13 | batch: a batch of samples, each contains a list of field data(Tensor):
14 | tokens, attention_mask, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, label_idx
15 | Returns:
16 | output: list of field batched data, which shape is [batch, max_length]
17 | """
18 | batch_size = len(batch)
19 | max_length = max(x[0].shape[0] for x in batch)
20 | output = []
21 |
22 | for field_idx in range(7):
23 | pad_output = torch.full([batch_size, max_length], 0, dtype=batch[0][field_idx].dtype)
24 | for sample_idx in range(batch_size):
25 | data = batch[sample_idx][field_idx]
26 | pad_output[sample_idx][: data.shape[0]] = data
27 | output.append(pad_output)
28 |
29 | pad_match_labels = torch.zeros([batch_size, max_length, max_length], dtype=torch.long)
30 | for sample_idx in range(batch_size):
31 | data = batch[sample_idx][7]
32 | pad_match_labels[sample_idx, : data.shape[1], : data.shape[1]] = data
33 | output.append(pad_match_labels)
34 | output.append(torch.stack([x[8] for x in batch]))
35 | if len(batch[0]) == 9:
36 | return output
37 |
38 | output.append(torch.stack([x[9] for x in batch]))
39 | return output
40 |
41 |
--------------------------------------------------------------------------------
/datasets/data_sampling.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import random
6 |
7 | def sample_positive_and_negative_by_ratio(positive_data_lst, negative_data_lst, ratio=0.5):
8 | num_negative_examples = int(len(positive_data_lst) * ratio)
9 |
10 | random.shuffle(negative_data_lst)
11 | truncated_negative_data_lst = random.sample(negative_data_lst, num_negative_examples)
12 | all_data_lst = positive_data_lst + truncated_negative_data_lst
13 | # need to use random data sampler
14 | return all_data_lst
15 |
16 |
17 | def undersample_majority_classes(data_lst, label_lst):
18 | pass
19 |
20 |
21 | def oversample_minority_classes(data_lst, sampling_strategy=None):
22 | collect_data_by_label = {}
23 |
24 | for data_item in data_lst:
25 | data_item_label = data_item["label"]
26 | if data_item_label not in collect_data_by_label.keys():
27 | collect_data_by_label[data_item_label] = [data_item]
28 | else:
29 | collect_data_by_label[data_item_label].append(data_item)
30 |
31 | count_data_by_label = {key: len(value) for key, value in collect_data_by_label.items()}
--------------------------------------------------------------------------------
/datasets/mrc_ner_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: mrc_ner_dataset.py
5 |
6 | import json
7 | import torch
8 | from torch.utils.data import Dataset
9 | from transformers import AutoTokenizer
10 | from datasets.data_sampling import sample_positive_and_negative_by_ratio
11 |
12 |
13 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text, return_subtoken_start=False):
14 | """Returns tokenized answer spans that better match the annotated answer."""
15 | doc_tokens = [str(tmp) for tmp in doc_tokens]
16 | answer_tokens = tokenizer.encode(orig_answer_text, add_special_tokens=False)
17 | tok_answer_text = " ".join([str(tmp) for tmp in answer_tokens])
18 | for new_start in range(input_start, input_end + 1):
19 | for new_end in range(input_end, new_start - 1, -1):
20 | text_span = " ".join(doc_tokens[new_start : (new_end+1)])
21 | if text_span == tok_answer_text:
22 | if not return_subtoken_start:
23 | return (new_start, new_end)
24 | tokens = tokenizer.convert_ids_to_tokens(doc_tokens[new_start: (new_end + 1)])
25 | if "##" not in tokens[-1]:
26 | return (new_start, new_end)
27 | else:
28 | for idx in range(len(tokens)-1, -1, -1):
29 | if "##" not in tokens[idx]:
30 | new_end = new_end - (len(tokens)-1 - idx)
31 | return (new_start, new_end)
32 |
33 | return (input_start, input_end)
34 |
35 |
36 | class MRCNERDataset(Dataset):
37 | """
38 | MRC NER Dataset
39 | Args:
40 | json_path: path to mrc-ner style json
41 | tokenizer: BertTokenizer
42 | max_length: int, max length of query+context
43 | possible_only: if True, only use possible samples that contain answer for the query/context
44 | is_chinese: is chinese dataset
45 | """
46 | def __init__(self, json_path, tokenizer: AutoTokenizer, max_length: int = 512, possible_only=False, is_chinese=False,
47 | pad_to_maxlen=False, negative_sampling=False, prefix="train", data_sign="zh_onto", do_lower_case=False,
48 | pred_answerable=True):
49 | self.all_data = json.load(open(json_path, encoding="utf-8"))
50 | self.tokenzier = tokenizer
51 | self.max_length = max_length
52 | self.do_lower_case = do_lower_case
53 | self.label2idx = {value:key for key, value in enumerate(MRCNERDataset.get_labels(data_sign))}
54 |
55 | if prefix == "train" and negative_sampling:
56 | neg_data_items = [x for x in self.all_data if not x["start_position"]]
57 | pos_data_items = [x for x in self.all_data if x["start_position"]]
58 | self.all_data = sample_positive_and_negative_by_ratio(pos_data_items, neg_data_items)
59 | elif prefix == "train" and possible_only:
60 | self.all_data = [
61 | x for x in self.all_data if x["start_position"]
62 | ]
63 | else:
64 | pass
65 |
66 | self.is_chinese = is_chinese
67 | self.pad_to_maxlen = pad_to_maxlen
68 | self.pred_answerable = pred_answerable
69 |
70 | def __len__(self):
71 | return len(self.all_data)
72 |
73 | def __getitem__(self, item):
74 | """
75 | Args:
76 | item: int, idx
77 | Returns:
78 | tokens: tokens of query + context, [seq_len]
79 | token_type_ids: token type ids, 0 for query, 1 for context, [seq_len]
80 | start_labels: start labels of NER in tokens, [seq_len]
81 |                 end_labels: end labels of NER in tokens, [seq_len]
82 | label_mask: label mask, 1 for counting into loss, 0 for ignoring. shape of [seq_len]. 1 for no-subword context tokens. 0 for query tokens and [CLS] [SEP] tokens.
83 | match_labels: match labels, [seq_len, seq_len]
84 | sample_idx: sample id
85 | label_idx: label id
86 |
87 | """
88 | data = self.all_data[item]
89 | tokenizer = self.tokenzier
90 | label_idx = torch.tensor(self.label2idx[data["entity_label"]], dtype=torch.long)
91 |
92 | if self.is_chinese:
93 | query = "".join(data["query"].strip().split())
94 | context = "".join(data["context"].strip().split())
95 | else:
96 | query = data["query"]
97 | context = data["context"]
98 |
99 | start_positions = data["start_position"]
100 | end_positions = data["end_position"]
101 |
102 | query_context_tokens = tokenizer.encode_plus(query, context,
103 | add_special_tokens=True,
104 | max_length=self.max_length,
105 | return_overflowing_tokens=True,
106 | return_token_type_ids=True)
107 |
108 | if tokenizer.pad_token_id in query_context_tokens["input_ids"]:
109 | non_padded_ids = query_context_tokens["input_ids"][: query_context_tokens["input_ids"].index(tokenizer.pad_token_id)]
110 | else:
111 | non_padded_ids = query_context_tokens["input_ids"]
112 |
113 | non_pad_tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
114 | first_sep_token = non_pad_tokens.index("[SEP]")
115 | end_sep_token = len(non_pad_tokens) - 1
116 | new_start_positions = []
117 | new_end_positions = []
118 | if len(start_positions) != 0:
119 | for start_index, end_index in zip(start_positions, end_positions):
120 | if self.is_chinese:
121 | answer_text_span = " ".join(context[start_index: end_index+1])
122 | else:
123 | answer_text_span = " ".join(context.split(" ")[start_index: end_index+1])
124 | new_start, new_end = _improve_answer_span(query_context_tokens["input_ids"], first_sep_token, end_sep_token, self.tokenzier, answer_text_span)
125 | new_start_positions.append(new_start)
126 | new_end_positions.append(new_end)
127 | else:
128 | new_start_positions = start_positions
129 | new_end_positions = end_positions
130 |
131 | # clip out-of-boundary entity positions.
132 | new_start_positions = [start_pos for start_pos in new_start_positions if start_pos < self.max_length]
133 | new_end_positions = [end_pos for end_pos in new_end_positions if end_pos < self.max_length]
134 |
135 | tokens = query_context_tokens["input_ids"]
136 | token_type_ids = query_context_tokens['token_type_ids']
137 | # token_type_ids -> 0 for query tokens and 1 for context tokens.
138 | attention_mask = query_context_tokens['attention_mask']
139 | start_labels = [(1 if idx in new_start_positions else 0) for idx in range(len(tokens))]
140 | end_labels = [(1 if idx in new_end_positions else 0) for idx in range(len(tokens))]
141 | label_mask = [1 if token_type_ids[token_idx] == 1 else 0 for token_idx in range(len(tokens))]
142 | start_label_mask = label_mask.copy()
143 | end_label_mask = label_mask.copy()
144 |
145 | seq_len = len(tokens)
146 | match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)
147 | for start, end in zip(new_start_positions, new_end_positions):
148 | if start >= seq_len or end >= seq_len:
149 | continue
150 | match_labels[start, end] = 1
151 |
152 | if self.pred_answerable:
153 | answerable_label = 1 if len(new_start_positions) != 0 else 0
154 | return [torch.tensor(tokens, dtype=torch.long),
155 | torch.tensor(attention_mask, dtype=torch.long),
156 | torch.tensor(token_type_ids, dtype=torch.long),
157 | torch.tensor(start_labels, dtype=torch.long),
158 | torch.tensor(end_labels, dtype=torch.long),
159 | torch.tensor(start_label_mask, dtype=torch.long),
160 | torch.tensor(end_label_mask, dtype=torch.long),
161 | match_labels,
162 | label_idx,
163 | torch.tensor([answerable_label], dtype=torch.long)]
164 |
165 | return [torch.tensor(tokens, dtype=torch.long),
166 | torch.tensor(attention_mask, dtype=torch.long),
167 | torch.tensor(token_type_ids, dtype=torch.long),
168 | torch.tensor(start_labels, dtype=torch.long),
169 | torch.tensor(end_labels, dtype=torch.long),
170 | torch.tensor(start_label_mask, dtype=torch.long),
171 | torch.tensor(end_label_mask, dtype=torch.long),
172 | match_labels,
173 | label_idx]
174 |
175 | @classmethod
176 | def get_labels(cls, data_sign):
177 | """gets the list of labels for this data set."""
178 | if data_sign == "zh_onto":
179 | return ["GPE", "LOC", "PER", "ORG"]
180 | elif data_sign == "zh_msra":
181 | return ["NS", "NR", "NT"]
182 | elif data_sign == "en_onto":
183 | return ["LAW", "EVENT", "CARDINAL", "FAC", "TIME", "DATE", "ORDINAL", "ORG", "QUANTITY", \
184 | "PERCENT", "WORK_OF_ART", "LOC", "LANGUAGE", "NORP", "MONEY", "PERSON", "GPE", "PRODUCT"]
185 | elif data_sign == "en_conll03":
186 | return ["ORG", "PER", "LOC", "MISC"]
187 | return ["0", "1"]
188 |
189 |
190 |
--------------------------------------------------------------------------------
/datasets/mrpc_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: mrpc_dataset.py
5 | # description:
6 | # dataset processor for semantic textual similarity task MRPC
7 | # train: 3669, dev: 1726, test: 1726
8 |
9 | from collections import namedtuple
10 | from typing import Dict, Optional, List, Union
11 |
12 | import torch
13 | from torch.utils.data import TensorDataset
14 | from torch.utils.data.dataset import Dataset
15 |
16 | from transformers import BertTokenizer
17 | from datasets.mrpc_processor import MRPCProcessor, MrpcDataExample
18 |
19 | MrpcDataFeature = namedtuple("MrpcDataFeature", ["input_ids", "attention_mask", "token_type_ids", "label"])
20 |
21 |
22 | class MRPCDataset(Dataset):
23 | def __init__(self,
24 | args,
25 | tokenizer: BertTokenizer,
26 | mode: Optional[str] = "train",
27 | cache_dir: Optional[str] = None,
28 | debug: Optional[bool] = False):
29 |
30 | self.args = args
31 | self.tokenizer = tokenizer
32 | self.mode = mode
33 | self.processor = MRPCProcessor(self.args.data_dir)
34 | self.debug = debug
35 | self.cache_dir = cache_dir
36 | self.max_seq_length = self.args.max_seq_length
37 |
38 | if self.mode == "dev":
39 | self.examples = self.processor.get_dev_examples()
40 | elif self.mode == "test":
41 | self.examples = self.processor.get_test_examples()
42 | else:
43 | self.examples = self.processor.get_train_examples()
44 |
45 | self.features, self.dataset = mrpc_convert_examples_to_features(
46 | examples=self.examples,
47 | tokenizer=tokenizer,
48 | max_length=self.max_seq_length,
49 | label_list=MRPCProcessor.get_labels(),
50 | is_training= mode == "train",)
51 |
52 | def __len__(self):
53 | return len(self.features)
54 |
55 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
56 | # convert to Tensors and build dataset
57 | feature = self.features[i]
58 |
59 | input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
60 | attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
61 | token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
62 | label = torch.tensor(feature.label, dtype=torch.long)
63 |
64 | inputs = {
65 | "input_ids": input_ids,
66 | "token_type_ids": token_type_ids,
67 | "attention_mask": attention_mask,
68 | "label": label
69 | }
70 |
71 | return inputs
72 |
73 |
74 | def mrpc_convert_examples_to_features(examples: List[MrpcDataExample],
75 |                                       tokenizer: BertTokenizer,
76 |                                       max_length: int = 256,
77 |                                       label_list: Optional[List[str]] = None,
78 |                                       is_training: bool = False,):
79 | """
80 | Description:
81 | GLUE Version
82 | - test.tsv -> index #1 ID #2 ID #1 String #2 String
83 | - train/dev.tsv -> Quality #1 ID #2 ID #1 String #2 String
84 | """
85 |
86 | label_map = {label: i for i, label in enumerate(label_list)}
87 |
88 | labels = [label_map[example.label] for example in examples]
89 | batch_encoding = tokenizer(
90 | [(example.text_a, example.text_b) for example in examples],
91 | max_length=max_length, padding="max_length", truncation=True
92 | )
93 |
94 | features = []
95 | for i in range(len(examples)):
96 | inputs = {k: batch_encoding[k][i] for k in batch_encoding}
97 |
98 | feature = MrpcDataFeature(**inputs, label=labels[i])
99 | features.append(feature)
100 |
101 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
102 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
103 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
104 | all_label = torch.tensor([f.label for f in features], dtype=torch.long)
105 |
106 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label)
107 | return features, dataset
108 |
109 |
--------------------------------------------------------------------------------
/datasets/mrpc_processor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file name: mrpc_processor.py
5 | # description:
6 | # code for loading data samples from files.
7 |
8 | import os
9 | import csv
10 | from collections import namedtuple
11 |
12 | MrpcDataExample = namedtuple("DataExample", ["guid", "text_a", "text_b", "label"])
13 |
14 |
15 | class MRPCProcessor:
16 | """
17 | Processor for the MRPC data set.
18 | """
19 | def __init__(self, data_dir):
20 | self.data_dir = data_dir
21 | self.train_file = os.path.join(data_dir, "train.tsv")
22 | self.dev_file = os.path.join(data_dir, "dev.tsv")
23 | # TODO: add test.tsv processing
24 | self.test_file = os.path.join(data_dir, "msr_paraphrase_test.txt")
25 |
26 | def get_train_examples(self, ):
27 | return self._create_examples(self._read_tsv(self.train_file), "train")
28 |
29 | def get_dev_examples(self, ):
30 | return self._create_examples(self._read_tsv(self.dev_file), "dev")
31 |
32 | def get_test_examples(self, ):
33 | return self._create_examples(self._read_tsv(self.test_file), "test")
34 |
35 | def _create_examples(self, lines, set_type):
36 | """create examples for the train/dev/test datasets"""
37 | examples = []
38 | for idx, line in enumerate(lines):
39 | if idx == 0:
40 | continue
41 | guid = "%s-%s" % (set_type, idx)
42 | text_a = line[3]
43 | text_b = line[4]
44 | label = line[0]
45 | examples.append(MrpcDataExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
46 | return examples
47 |
48 | def _read_tsv(self, input_file, quotechar=None):
49 | """reads a tab separated value file."""
50 | with open(input_file, "r", encoding="utf-8") as f:
51 | return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
52 |
53 | @classmethod
54 | def get_labels(cls, ):
55 | """gets the list of labels for this data set."""
56 | return ["0", "1"]
57 |
58 |
59 |
60 | if __name__ == "__main__":
61 | num_labels = MRPCProcessor.get_labels()
62 | print(num_labels)
--------------------------------------------------------------------------------
/datasets/squad_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: squad_dataset.py
5 | # description:
6 | # dataset class for the squad task.
7 | # NOTICE:
8 | # https://github.com/huggingface/transformers/issues/7735
9 | # fast tokenizers don’t currently work with the QA pipeline.
10 |
11 | import os
12 | from typing import Dict, Optional
13 |
14 | import torch
15 | from torch.utils.data.dataset import Dataset
16 |
17 | from transformers import AutoTokenizer
18 | from transformers.data.processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor
19 | from transformers.data.processors.squad import squad_convert_examples_to_features
20 |
21 |
22 | class SquadDataset(Dataset):
23 | def __init__(self,
24 | args,
25 | tokenizer: AutoTokenizer,
26 | mode: Optional[str] = "train",
27 | is_language_sensitive: Optional[bool] = False,
28 | cache_dir: Optional[str] = None,
29 | dataset_format: Optional[str] = "pt",
30 | threads: Optional[int] = 1,
31 | debug: Optional[bool] = False,):
32 |
33 | self.args = args
34 | self.tokenizer = tokenizer
35 | self.is_language_sensitive = is_language_sensitive
36 | self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
37 | self.mode = mode
38 | self.debug = debug
39 | self.threads = threads
40 |
41 | self.max_seq_length = self.args.max_seq_length
42 | self.doc_stride = self.args.doc_stride
43 | self.max_query_length = self.args.max_query_length
44 |
45 | # dataset format configurations
46 | self.column_names = ["id", "title", "context", "question", "answers" ]
47 | self.question_column_name = "question" if "question" in self.column_names else self.column_names[0]
48 | self.context_column_name = "context" if "context" in self.column_names else self.column_names[1]
49 | self.answer_column_name = "answers" if "answers" in self.column_names else self.column_names[2]
50 |
51 | # Padding side determines if we do (question|context) or (context|question).
52 | self.pad_on_right = tokenizer.padding_side == "right"
53 |
54 | # load data features from cache or dataset file
55 | version_tag = "v2" if args.version_2_with_negative else "v1"
56 | cached_features_file = os.path.join(
57 | cache_dir if cache_dir is not None else args.data_dir,
58 | "cached_{}_{}_{}_{}".format(
59 | mode,
60 | tokenizer.__class__.__name__,
61 | str(args.max_seq_length),
62 | version_tag
63 | )
64 | )
65 |
66 | self.cached_data_file = cached_features_file
67 |
68 | # Make sure only the first process in distributed training processes the dataset,
69 | # and the others will use the cache.
70 | if os.path.exists(cached_features_file) and not args.overwrite_cache:
71 | self.old_features = torch.load(cached_features_file)
72 |
73 | # legacy cache files have only features,
74 | # which new cache files will have dataset and examples also.
75 | self.features = self.old_features["features"]
76 | self.dataset = self.old_features.get("dataset", None)
77 | self.examples = self.old_features.get("examples", None)
78 |
79 | if self.dataset is None or self.examples is None:
80 | raise ValueError
81 | else:
82 | if self.mode == "dev":
83 | self.examples = self.processor.get_dev_examples(args.data_dir)
84 | else:
85 | self.examples = self.processor.get_train_examples(args.data_dir)
86 |
87 | if self.debug:
88 | print(f"DEBUG INFO -> already load {self.mode} data ...")
89 | print(f"DEBUG INFO -> show 2 EXAMPLES ...")
90 | for idx, data_examples in enumerate(self.examples):
91 | # data_examples should be an object of transformers.data.processors.squad.SquadExample
92 | if idx <= 2:
93 | print(f"DEBUG INFO -> {idx}, {data_examples}")
94 | print(f"{idx} qas_id -> {data_examples.qas_id}")
95 | print(f"{idx} question_text -> {data_examples.question_text}")
96 | print(f"{idx} context_text -> {data_examples.context_text}")
97 | print(f"{idx} answer_text -> {data_examples.answer_text}")
98 | print("-*-"*10)
99 |
100 | self.features, self.dataset = squad_convert_examples_to_features(
101 | examples=self.examples,
102 | tokenizer=tokenizer,
103 | max_seq_length=self.max_seq_length,
104 | doc_stride=self.doc_stride,
105 | max_query_length=self.max_query_length,
106 | is_training= mode == "train",
107 | threads=self.threads,
108 | return_dataset=dataset_format,
109 | )
110 |
111 | torch.save(
112 | {"features": self.features, "dataset": self.dataset, "examples": self.examples},
113 | cached_features_file,)
114 |
115 | def __len__(self):
116 | return len(self.features)
117 |
118 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
119 | # convert to Tensors and build dataset
120 | feature = self.features[i]
121 |
122 | input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
123 | attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
124 | token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
125 | # only for "xlnet", "xlm" models.
126 | # cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
127 | # p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
128 | # is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)
129 |
130 | inputs = {
131 | "input_ids": input_ids,
132 | "attention_mask": attention_mask,
133 | "token_type_ids": token_type_ids
134 | }
135 |
136 | label_mask = [1] + feature.token_type_ids[1:]
137 |
138 | start_labels = torch.tensor(feature.start_position, dtype=torch.long)
139 | end_labels = torch.tensor(feature.end_position, dtype=torch.long)
140 | label_mask = torch.tensor(label_mask, dtype=torch.long)
141 |
142 | inputs.update({"start_labels": start_labels, "end_labels": end_labels, "label_mask": label_mask})
143 |
144 | if self.mode != "train":
145 | inputs.update({"unique_id": feature.unique_id})
146 |
147 | return inputs
148 |
149 |
150 |
151 |
--------------------------------------------------------------------------------
/datasets/tnews_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: tnews_dataset.py
5 | # Data Example:
6 | # {"label": "113", "label_desc": "news_world", "sentence": "日本虎视眈眈“武力夺岛”, 美军向俄后院开火,普京终不再忍!", "keywords": "普京,北方四岛,安倍,俄罗斯,黑海"}
7 |
8 | import os
9 | import json
10 | import torch
11 | from torch.utils.data import Dataset
12 | from tokenizers import BertWordPieceTokenizer
13 |
14 | class TNewsDataset(Dataset):
15 | def __init__(self, prefix: str = "train", data_dir: str = "", tokenizer: BertWordPieceTokenizer = None, max_length: int = 512):
16 | super().__init__()
17 | self.data_prefix = prefix
18 | self.max_length = max_length
19 | data_file = os.path.join(data_dir, f"{prefix}.json")
20 | with open(data_file, "r", encoding="utf-8") as f:
21 | data_items = f.readlines()
22 | self.data_items = data_items
23 | self.tokenizer = tokenizer
24 | self.labels2id = {value: key for key, value in enumerate(TNewsDataset.get_labels())}
25 |
26 | def __len__(self):
27 | return len(self.data_items)
28 |
29 | def __getitem__(self, idx):
30 | """
31 | Description:
32 |             for the single-sentence task, BertWordPieceTokenizer produces [CLS] + sentence tokens + [SEP]
33 | Returns:
34 | input_token_ids, token_type_ids, attention_mask, label_id
35 | """
36 | data_item = self.data_items[idx]
37 | data_item = json.loads(data_item)
38 | label, sentence = data_item["label"], data_item["sentence"]
39 | label_id = self.labels2id[label]
40 | sentence = sentence[: self.max_length-3]
41 | tokenizer_output = self.tokenizer.encode(sentence)
42 |
43 | tokens = tokenizer_output.ids + (self.max_length - len(tokenizer_output.ids)) * [0]
44 | token_type_ids = tokenizer_output.type_ids + (self.max_length - len(tokenizer_output.type_ids)) * [0]
45 | attention_mask = tokenizer_output.attention_mask + (self.max_length - len(tokenizer_output.attention_mask)) * [0]
46 |
47 | input_token_ids = torch.tensor(tokens, dtype=torch.long)
48 | token_type_ids = torch.tensor(token_type_ids, dtype=torch.long)
49 | attention_mask = torch.tensor(attention_mask, dtype=torch.long)
50 | label_id = torch.tensor(label_id, dtype=torch.long)
51 |
52 | return input_token_ids, token_type_ids, attention_mask, label_id
53 |
54 | @classmethod
55 | def get_labels(cls, ):
56 | return ['100', '101', '102', '103', '104', '106', '107', '108', '109', '110', '112', '113', '114', '115', '116']
57 |
58 |
59 |
--------------------------------------------------------------------------------
/datasets/truncate_dataset.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 |
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class TruncateDataset(Dataset):
7 | """Truncate dataset to certain num"""
8 | def __init__(self, dataset: Dataset, max_num: int = 100):
9 | self.dataset = dataset
10 | self.max_num = min(max_num, len(self.dataset))
11 |
12 | def __len__(self):
13 | return self.max_num
14 |
15 | def __getitem__(self, item):
16 | return self.dataset[item]
17 |
18 | def __getattr__(self, item):
19 | """other dataset func"""
20 | return getattr(self.dataset, item)
21 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .dice_loss import DiceLoss
2 | from .focal_loss import FocalLoss
--------------------------------------------------------------------------------
/loss/dice_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: dice_loss.py
5 | # description:
6 | # implementation of dice loss for NLP tasks.
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch import Tensor
12 | from typing import Optional
13 |
14 |
15 | class DiceLoss(nn.Module):
16 | """
17 |     The Dice coefficient, or DSC for short, is an F1-oriented statistic used to gauge the similarity of two sets.
18 | Given two sets A and B, the vanilla dice coefficient between them is given as follows:
19 | Dice(A, B) = 2 * True_Positive / (2 * True_Positive + False_Positive + False_Negative)
20 | = 2 * |A and B| / (|A| + |B|)
21 |
22 | Math Function:
23 | U-NET: https://arxiv.org/abs/1505.04597.pdf
24 | dice_loss(p, y) = 1 - numerator / denominator
25 | numerator = 2 * \sum_{1}^{t} p_i * y_i + smooth
26 | denominator = \sum_{1}^{t} p_i + \sum_{1} ^{t} y_i + smooth
27 | if square_denominator is True, the denominator is \sum_{1}^{t} (p_i ** 2) + \sum_{1} ^{t} (y_i ** 2) + smooth
28 | V-NET: https://arxiv.org/abs/1606.04797.pdf
29 | Args:
30 | smooth (float, optional): a manual smooth value for numerator and denominator.
31 | square_denominator (bool, optional): [True, False], specifies whether to square the denominator in the loss function.
32 | with_logits (bool, optional): [True, False], specifies whether the input tensor is normalized by Sigmoid/Softmax funcs.
33 |         ohem_ratio: max ratio of positive/negative, defaults to 0.0, which means no ohem.
34 | alpha: dsc alpha
35 | Shape:
36 | - input: (*)
37 | - target: (*)
38 | - mask: (*) 0,1 mask for the input sequence.
39 | - Output: Scalar loss
40 | Examples:
41 | >>> loss = DiceLoss(with_logits=True, ohem_ratio=0.1)
42 | >>> input = torch.FloatTensor([2, 1, 2, 2, 1])
43 | >>> input.requires_grad=True
44 | >>> target = torch.LongTensor([0, 1, 0, 0, 0])
45 | >>> output = loss(input, target)
46 | >>> output.backward()
47 | """
48 | def __init__(self,
49 | smooth: Optional[float] = 1e-4,
50 | square_denominator: Optional[bool] = False,
51 | with_logits: Optional[bool] = True,
52 | ohem_ratio: float = 0.0,
53 | alpha: float = 0.0,
54 | reduction: Optional[str] = "mean",
55 | index_label_position=True) -> None:
56 | super(DiceLoss, self).__init__()
57 |
58 | self.reduction = reduction
59 | self.with_logits = with_logits
60 | self.smooth = smooth
61 | self.square_denominator = square_denominator
62 | self.ohem_ratio = ohem_ratio
63 | self.alpha = alpha
64 | self.index_label_position = index_label_position
65 |
66 | def forward(self, input: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
67 | logits_size = input.shape[-1]
68 |
69 | if logits_size != 1:
70 | loss = self._multiple_class(input, target, logits_size, mask=mask)
71 | else:
72 | loss = self._binary_class(input, target, mask=mask)
73 |
74 | if self.reduction == "mean":
75 | return loss.mean()
76 | if self.reduction == "sum":
77 | return loss.sum()
78 | return loss
79 |
80 | def _compute_dice_loss(self, flat_input, flat_target):
81 | flat_input = ((1 - flat_input) ** self.alpha) * flat_input
82 |         intersection = torch.sum(flat_input * flat_target, -1)
83 |         if not self.square_denominator:
84 |             loss = 1 - ((2 * intersection + self.smooth) /
85 |                         (flat_input.sum() + flat_target.sum() + self.smooth))
86 |         else:
87 |             loss = 1 - ((2 * intersection + self.smooth) /
88 |                         (torch.sum(torch.square(flat_input), -1) + torch.sum(torch.square(flat_target), -1) + self.smooth))
89 |
90 | return loss
91 |
92 | def _multiple_class(self, input, target, logits_size, mask=None):
93 | flat_input = input
94 | flat_target = F.one_hot(target, num_classes=logits_size).float() if self.index_label_position else target.float()
95 | flat_input = torch.nn.Softmax(dim=1)(flat_input) if self.with_logits else flat_input
96 |
97 | if mask is not None:
98 | mask = mask.float()
99 | flat_input = flat_input * mask
100 | flat_target = flat_target * mask
101 | else:
102 | mask = torch.ones_like(target)
103 |
104 | loss = None
105 | if self.ohem_ratio > 0 :
106 | mask_neg = torch.logical_not(mask)
107 | for label_idx in range(logits_size):
108 | pos_example = target == label_idx
109 | neg_example = target != label_idx
110 |
111 | pos_num = pos_example.sum()
112 | neg_num = mask.sum() - (pos_num - (mask_neg & pos_example).sum())
113 | keep_num = min(int(pos_num * self.ohem_ratio / logits_size), neg_num)
114 |
115 | if keep_num > 0:
116 | neg_scores = torch.masked_select(flat_input, neg_example.view(-1, 1).bool()).view(-1, logits_size)
117 | neg_scores_idx = neg_scores[:, label_idx]
118 | neg_scores_sort, _ = torch.sort(neg_scores_idx, )
119 | threshold = neg_scores_sort[-keep_num + 1]
120 |                     cond = ((torch.argmax(flat_input, dim=1) == label_idx) & (flat_input[:, label_idx] >= threshold)) | pos_example.view(-1)
121 | ohem_mask_idx = torch.where(cond, 1, 0)
122 |
123 | flat_input_idx = flat_input[:, label_idx]
124 | flat_target_idx = flat_target[:, label_idx]
125 |
126 | flat_input_idx = flat_input_idx * ohem_mask_idx
127 | flat_target_idx = flat_target_idx * ohem_mask_idx
128 | else:
129 | flat_input_idx = flat_input[:, label_idx]
130 | flat_target_idx = flat_target[:, label_idx]
131 |
132 | loss_idx = self._compute_dice_loss(flat_input_idx.view(-1, 1), flat_target_idx.view(-1, 1))
133 | if loss is None:
134 | loss = loss_idx
135 | else:
136 | loss += loss_idx
137 | return loss
138 |
139 | else:
140 | for label_idx in range(logits_size):
141 | pos_example = target == label_idx
142 | flat_input_idx = flat_input[:, label_idx]
143 | flat_target_idx = flat_target[:, label_idx]
144 |
145 | loss_idx = self._compute_dice_loss(flat_input_idx.view(-1, 1), flat_target_idx.view(-1, 1))
146 | if loss is None:
147 | loss = loss_idx
148 | else:
149 | loss += loss_idx
150 | return loss
151 |
152 | def _binary_class(self, input, target, mask=None):
153 | flat_input = input.view(-1)
154 | flat_target = target.view(-1).float()
155 | flat_input = torch.sigmoid(flat_input) if self.with_logits else flat_input
156 |
157 | if mask is not None:
158 | mask = mask.float()
159 | flat_input = flat_input * mask
160 | flat_target = flat_target * mask
161 | else:
162 | mask = torch.ones_like(target)
163 |
164 | if self.ohem_ratio > 0:
165 | pos_example = target > 0.5
166 | neg_example = target <= 0.5
167 | mask_neg_num = mask <= 0.5
168 |
169 | pos_num = pos_example.sum() - (pos_example & mask_neg_num).sum()
170 | neg_num = neg_example.sum()
171 | keep_num = min(int(pos_num * self.ohem_ratio), neg_num)
172 |
173 | neg_scores = torch.masked_select(flat_input, neg_example.bool())
174 | neg_scores_sort, _ = torch.sort(neg_scores, )
175 | threshold = neg_scores_sort[-keep_num+1]
176 | cond = (flat_input > threshold) | pos_example.view(-1)
177 | ohem_mask = torch.where(cond, 1, 0)
178 | flat_input = flat_input * ohem_mask
179 | flat_target = flat_target * ohem_mask
180 |
181 | return self._compute_dice_loss(flat_input, flat_target)
182 |
183 | def __str__(self):
184 | return f"Dice Loss smooth:{self.smooth}, ohem: {self.ohem_ratio}, alpha: {self.alpha}"
185 |
186 | def __repr__(self):
187 | return str(self)
188 |
--------------------------------------------------------------------------------
/loss/focal_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | from typing import List
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class FocalLoss(nn.Module):
13 | """
14 | Focal loss(https://arxiv.org/pdf/1708.02002.pdf)
15 | Shape:
16 | - input: (N, C)
17 | - target: (N)
18 | - Output: Scalar loss
19 | Examples:
20 | >>> loss = FocalLoss(gamma=2, alpha=[1.0]*7)
21 | >>> input = torch.randn(3, 7, requires_grad=True)
22 | >>> target = torch.empty(3, dtype=torch.long).random_(7)
23 | >>> output = loss(input, target)
24 | >>> output.backward()
25 | """
26 | def __init__(self, gamma=0, alpha: List[float] = None, reduction="none"):
27 | super(FocalLoss, self).__init__()
28 | self.gamma = gamma
29 | self.alpha = alpha
30 | if alpha is not None:
31 | self.alpha = torch.FloatTensor(alpha)
32 | self.reduction = reduction
33 |
34 | def forward(self, input, target):
35 | # [N, 1]
36 | target = target.unsqueeze(-1)
37 | # [N, C]
38 | pt = F.softmax(input, dim=-1)
39 | logpt = F.log_softmax(input, dim=-1)
40 | # [N]
41 | pt = pt.gather(1, target).squeeze(-1)
42 | logpt = logpt.gather(1, target).squeeze(-1)
43 |
44 | if self.alpha is not None:
45 | # [N] at[i] = alpha[target[i]]
46 | at = self.alpha.gather(0, target.squeeze(-1))
47 | logpt = logpt * at
48 |
49 | loss = -1 * (1 - pt) ** self.gamma * logpt
50 | if self.reduction == "none":
51 | return loss
52 | if self.reduction == "mean":
53 | return loss.mean()
54 | return loss.sum()
55 |
56 | @staticmethod
57 | def convert_binary_pred_to_two_dimension(x, is_logits=True):
58 | """
59 | Args:
60 |             x: (*), logit (or probability) that each instance has label 1
61 |             is_logits: if True, a sigmoid is applied to x; otherwise x is already a probability
62 |         Returns:
63 |             y: (*, 2), where y[*, 0] is the log prob that the instance has label 0,
64 |                 and y[*, 1] is the log prob that the instance has label 1
65 | """
66 | probs = torch.sigmoid(x) if is_logits else x
67 | probs = probs.unsqueeze(-1)
68 | probs = torch.cat([1-probs, probs], dim=-1)
69 | logprob = torch.log(probs+1e-4) # 1e-4 to prevent being rounded to 0 in fp16
70 | return logprob
71 |
72 | def __str__(self):
73 | return f"Focal Loss gamma:{self.gamma}"
74 |
75 | def __repr__(self):
76 | return str(self)
77 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/metrics/__init__.py
--------------------------------------------------------------------------------
/metrics/classification_acc_f1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: classification_acc_f1.py
5 | # description:
6 | # compute acc & f1 scores for classification tasks.
7 |
8 | import torch
9 | from pytorch_lightning.metrics.metric import TensorMetric
10 | from metrics.functional.cls_acc_f1 import collect_confusion_matrix, compute_precision_recall_f1_scores
11 |
12 |
13 | class ClassificationF1Metric(TensorMetric):
14 | """
15 | compute acc and f1 scores for text classification tasks.
16 | """
17 | def __init__(self, reduce_group=None, reduce_op=None, num_classes=2, f1_type="micro"):
18 | super(ClassificationF1Metric, self).__init__(name="classification_f1_metric", reduce_group=reduce_group, reduce_op=reduce_op)
19 | self.num_classes = num_classes
20 | self.f1_type = f1_type
21 |
22 | def forward(self, pred_labels, gold_labels):
23 | """
24 | Description:
25 | collect confusion matrix for one batch.
26 | Args:
27 | pred_labels: a tensor in shape of [eval_batch_size]
28 |             gold_labels: a tensor in shape of [eval_batch_size]
29 |         Returns:
30 |             a tensor of [true_positive, false_positive, false_negative]
31 | """
32 | confusion_matrix = collect_confusion_matrix(pred_labels, gold_labels, num_classes=self.num_classes)
33 |
34 | return confusion_matrix
35 |
36 |
37 | def compute_f1(self, all_confusion_matrix):
38 | """
39 | Args:
40 |             all_confusion_matrix: accumulated [true_positive, false_positive, false_negative] counts over the whole corpus.
41 |         Returns:
42 |             three tensors -> precision, recall, f1
43 | """
44 | precision, recall, f1 = compute_precision_recall_f1_scores(all_confusion_matrix, num_classes=self.num_classes, f1_type=self.f1_type)
45 | precision, recall, f1 = torch.tensor(precision, dtype=torch.float), torch.tensor(recall, dtype=torch.float), torch.tensor(f1, dtype=torch.float)
46 | # The metric you returned including Precision, Recall, F1 (e.g., 0.91638) must be a `torch.Tensor` instance.
47 | return precision, recall, f1
--------------------------------------------------------------------------------
/metrics/functional/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/metrics/functional/__init__.py
--------------------------------------------------------------------------------
/metrics/functional/cls_acc_f1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: cls_acc_f1.py
5 | # description:
6 | # compute acc and f1 scores for text classification task.
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 |
12 | def collect_confusion_matrix(y_pred_labels, y_gold_labels, num_classes=2):
13 | """
14 |     collect the confusion matrix for a text classification task.
15 |     Args:
16 |         y_pred_labels: [batch_size] index of predicted labels.
17 |         y_gold_labels: [batch_size] index of gold labels.
18 | Returns:
19 | A LongTensor composed by [true_positive, false_positive, false_negative]
20 | """
21 | if num_classes <= 0:
22 | raise ValueError
23 |
24 | if num_classes == 1 or num_classes == 2:
25 | num_classes = 1
26 | y_true_onehot = y_gold_labels.bool()
27 | y_pred_onehot = y_pred_labels.bool()
28 | else:
29 | y_true_onehot = F.one_hot(y_gold_labels, num_classes=num_classes)
30 | y_pred_onehot = F.one_hot(y_pred_labels, num_classes=num_classes)
31 |
32 | if num_classes == 1:
33 | y_true_onehot = y_true_onehot.bool()
34 | y_pred_onehot = y_pred_onehot.bool()
35 |
36 | true_positive = (y_true_onehot & y_pred_onehot).long().sum()
37 | false_positive = (y_pred_onehot & ~ y_true_onehot).long().sum()
38 | false_negative = (~ y_pred_onehot & y_true_onehot).long().sum()
39 |
40 | stack_confusion_matrix = torch.stack([true_positive, false_positive, false_negative])
41 | return stack_confusion_matrix
42 |
43 | multi_label_confusion_matrix = []
44 |
45 | for idx in range(num_classes):
46 |         index_item = torch.tensor([idx], dtype=torch.long, device=y_true_onehot.device)
47 | y_true_item_onehot = torch.index_select(y_true_onehot, 1, index_item)
48 | y_pred_item_onehot = torch.index_select(y_pred_onehot, 1, index_item)
49 |
50 | true_sum_item = torch.sum(y_true_item_onehot)
51 | pred_sum_item = torch.sum(y_pred_item_onehot)
52 |
53 | true_positive_item = torch.sum(y_true_item_onehot.multiply(y_pred_item_onehot))
54 |
55 | false_positive_item = pred_sum_item - true_positive_item
56 | false_negative_item = true_sum_item - true_positive_item
57 |
58 | confusion_matrix_item = torch.tensor([true_positive_item, false_positive_item, false_negative_item],
59 | dtype=torch.long)
60 |
61 | multi_label_confusion_matrix.append(confusion_matrix_item)
62 |
63 | stack_confusion_matrix = torch.stack(multi_label_confusion_matrix, dim=0)
64 |
65 | return stack_confusion_matrix
66 |
67 | def compute_precision_recall_f1_scores(confusion_matrix, num_classes=2, f1_type="micro"):
68 | """
69 | compute precision, recall and f1 scores.
70 | Description:
71 | f1: 2 * precision * recall / (precision + recall)
72 | - precision = true_positive / true_positive + false_positive
73 | - recall = true_positive / true_positive + false_negative
74 | Returns:
75 | precision, recall, f1
76 | """
77 |
78 | if num_classes == 2 or num_classes == 1:
79 | confusion_matrix = confusion_matrix.to("cpu").numpy().tolist()
80 | true_positive, false_positive, false_negative = tuple(confusion_matrix)
81 | precision = true_positive / (true_positive + false_positive + 1e-10)
82 | recall = true_positive / (true_positive + false_negative + 1e-10)
83 | f1 = 2 * precision * recall / (precision + recall + 1e-10)
84 | precision, recall, f1 = round(precision, 5), round(recall, 5), round(f1, 5)
85 | return precision, recall, f1
86 |
87 | if f1_type == "micro":
88 | precision, recall, f1 = micro_precision_recall_f1(confusion_matrix, num_classes)
89 | elif f1_type == "macro":
90 | precision, recall, f1 = macro_precision_recall_f1(confusion_matrix)
91 | else:
92 | raise ValueError
93 |
94 | return precision, recall, f1
95 |
96 |
97 | def micro_precision_recall_f1(all_confusion_matrix, num_classes):
98 | precision_lst = []
99 | recall_lst = []
100 |
101 | all_confusion_matrix_lst = all_confusion_matrix.to("cpu").numpy().tolist()
102 | for idx in range(num_classes):
103 | matrix_item = all_confusion_matrix_lst[idx]
104 | true_positive_item, false_positive_item, false_negative_item = tuple(matrix_item)
105 |
106 | precision_item = true_positive_item / (true_positive_item + false_positive_item + 1e-10)
107 | recall_item = true_positive_item / (true_positive_item + false_negative_item + 1e-10)
108 |
109 | precision_lst.append(precision_item)
110 | recall_lst.append(recall_item)
111 |
112 | avg_precision = sum(precision_lst) / num_classes
113 | avg_recall = sum(recall_lst) / num_classes
114 | avg_f1 = 2 * avg_recall * avg_precision / (avg_recall + avg_precision + 1e-10)
115 |
116 | avg_precision, avg_recall, avg_f1 = round(avg_precision, 5), round(avg_recall, 5), round(avg_f1, 5)
117 |
118 | return avg_precision, avg_recall, avg_f1
119 |
120 |
121 | def macro_precision_recall_f1(all_confusion_matrix, ):
122 |     confusion_matrix = torch.sum(all_confusion_matrix, 0, keepdim=False)
123 | confusion_matrix_lst = confusion_matrix.to("cpu").numpy().tolist()
124 | true_positive, false_positive, false_negative = tuple(confusion_matrix_lst)
125 |
126 | precision = true_positive / (true_positive + false_positive + 1e-10)
127 | recall = true_positive / (true_positive + false_negative + 1e-10)
128 | f1 = 2 * precision * recall / (precision + recall + 1e-10)
129 |
130 | precision, recall, f1 = round(precision, 5), round(recall, 5), round(f1, 5)
131 |
132 | return precision, recall, f1
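133 |
134 |
135 | # Illustrative example (made-up numbers): for a binary task whose accumulated
136 | # confusion matrix is [tp=8, fp=2, fn=4], the helper above returns
137 | # precision=0.8, recall=0.66667 and f1=0.72727 (rounded to 5 digits).
138 | #
139 | #   confusion_matrix = torch.tensor([8, 2, 4], dtype=torch.long)
140 | #   compute_precision_recall_f1_scores(confusion_matrix, num_classes=2)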
--------------------------------------------------------------------------------
/metrics/functional/ner_span_f1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: ner_span_f1.py
5 |
6 |
7 | import torch
8 | from typing import Tuple, List
9 |
10 |
11 | class Tag(object):
12 | def __init__(self, term, tag, begin, end):
13 | self.term = term
14 | self.tag = tag
15 | self.begin = begin
16 | self.end = end
17 |
18 | def to_tuple(self):
19 | return tuple([self.term, self.begin, self.end])
20 |
21 | def __str__(self):
22 | return str({key: value for key, value in self.__dict__.items()})
23 |
24 | def __repr__(self):
25 | return str({key: value for key, value in self.__dict__.items()})
26 |
27 |
28 | def bmes_decode(char_label_list: List[Tuple[str, str]]) -> List[Tag]:
29 | """
30 | decode inputs to tags
31 | Args:
32 | char_label_list: list of tuple (word, bmes-tag)
33 | Returns:
34 | tags
35 | Examples:
36 | >>> x = [("Hi", "O"), ("Beijing", "S-LOC")]
37 | >>> bmes_decode(x)
38 | [{'term': 'Beijing', 'tag': 'LOC', 'begin': 1, 'end': 2}]
39 |
40 | """
41 | idx = 0
42 | length = len(char_label_list)
43 | tags = []
44 | while idx < length:
45 | term, label = char_label_list[idx]
46 | current_label = label[0]
47 |
48 | if current_label == "O":
49 | idx += 1
50 | continue
51 | if current_label == "S":
52 | tags.append(Tag(term, label[2:], idx, idx + 1))
53 | idx += 1
54 | continue
55 | if current_label == "B":
56 | end = idx + 1
57 | while end + 1 < length and char_label_list[end][1][0] == "M":
58 | end += 1
59 |
60 | if end == len(char_label_list):
61 | entity = "".join(char_label_list[i][0] for i in range(idx, end))
62 | tags.append(Tag(entity, label[2:], idx, end))
63 | idx = end
64 | continue
65 |
66 | if char_label_list[end][1][0] == "E": # end with E
67 | entity = "".join(char_label_list[i][0] for i in range(idx, end + 1))
68 | tags.append(Tag(entity, label[2:], idx, end + 1))
69 | idx = end + 1
70 | else: # end with M/B
71 | entity = "".join(char_label_list[i][0] for i in range(idx, end))
72 | tags.append(Tag(entity, label[2:], idx, end))
73 | idx = end
74 | continue
75 | else:
76 | idx += 1
77 | continue
78 | return tags
79 |
80 |
81 | def bmes_decode_flat_query_span_f1(start_preds, end_preds, match_logits, start_end_label_mask, start_labels, end_labels, match_labels, answerable_pred=None):
82 | sum_true_positive, sum_false_positive, sum_false_negative = 0, 0, 0
83 | start_preds, end_preds, match_logits, start_end_label_mask = start_preds.to("cpu").numpy().tolist(), end_preds.to("cpu").numpy().tolist(), match_logits.to("cpu").numpy().tolist(), start_end_label_mask.to("cpu").numpy().tolist()
84 | start_labels, end_labels, match_labels = start_labels.to("cpu").numpy().tolist(), end_labels.to("cpu").numpy().tolist(), match_labels.to("cpu").numpy().tolist()
85 | batch_size, seq_len = len(start_labels), len(start_labels[0])
86 |
87 | if answerable_pred is not None:
88 | answerable_pred = answerable_pred.to("cpu").numpy().tolist()
89 | else:
90 | answerable_pred = [1] * batch_size
91 |
92 | for start_pred_item, end_pred_item, match_logits_item, start_end_label_mask_item, start_label_item, end_label_item, match_label_item, answerable_item in \
93 | zip(start_preds, end_preds, match_logits, start_end_label_mask, start_labels, end_labels, match_labels, answerable_pred):
94 | if answerable_item == 0:
95 | start_pred_item = [0] * len(start_pred_item)
96 | end_pred_item = [0] * len(end_pred_item)
97 |
98 | pred_entity_lst = extract_flat_spans(start_pred_item, end_pred_item, match_logits_item, start_end_label_mask_item,)
99 | gold_entity_lst = extract_flat_spans(start_label_item, end_label_item, match_label_item, start_end_label_mask_item)
100 |
101 | true_positive_item, false_positive_item, false_negative_item = count_confusion_matrix(pred_entity_lst, gold_entity_lst)
102 | sum_true_positive += true_positive_item
103 | sum_false_negative += false_negative_item
104 | sum_false_positive += false_positive_item
105 |
106 | batch_confusion_matrix = torch.tensor([sum_true_positive, sum_false_positive, sum_false_negative], dtype=torch.long)
107 | return batch_confusion_matrix
108 |
109 | def count_confusion_matrix(pred_entities, gold_entities):
110 | true_positive, false_positive, false_negative = 0, 0, 0
111 | for span_item in pred_entities:
112 | if span_item in gold_entities:
113 | true_positive += 1
114 | gold_entities.remove(span_item)
115 | else:
116 | false_positive += 1
117 |
118 | # these entities are not predicted.
119 | for span_item in gold_entities:
120 | false_negative += 1
121 |
122 | return true_positive, false_positive, false_negative
123 |
124 | def extract_flat_spans(start_pred, end_pred, match_pred, label_mask):
125 | """
126 | Extract flat-ner spans from start/end/match logits
127 | Args:
128 | start_pred: [seq_len], 1/True for start, 0/False for non-start
129 |         end_pred: [seq_len], 1/True for end, 0/False for non-end
130 | match_pred: [seq_len, seq_len], 1/True for match, 0/False for non-match
131 | label_mask: [seq_len], 1 for valid boundary.
132 | Returns:
133 | tags: list of tuple (start, end)
134 | Examples:
135 | >>> start_pred = [0, 1]
136 | >>> end_pred = [0, 1]
137 | >>> match_pred = [[0, 0], [0, 1]]
138 | >>> label_mask = [1, 1]
139 | >>> extract_flat_spans(start_pred, end_pred, match_pred, label_mask)
140 | """
141 | pseudo_tag = "TAG"
142 | pseudo_input = "a"
143 |
144 | bmes_labels = ["O"] * len(start_pred)
145 | start_positions = [idx for idx, tmp in enumerate(start_pred) if tmp and label_mask[idx]]
146 | end_positions = [idx for idx, tmp in enumerate(end_pred) if tmp and label_mask[idx]]
147 |
148 | for start_item in start_positions:
149 | bmes_labels[start_item] = f"B-{pseudo_tag}"
150 | for end_item in end_positions:
151 | if end_item in start_positions:
152 | bmes_labels[end_item] = f"B-{pseudo_tag}"
153 | # bmes_labels[end_item] = f"S-{pseudo_tag}"
154 | else:
155 | bmes_labels[end_item] = f"E-{pseudo_tag}"
156 |
157 | for tmp_start in start_positions:
158 | tmp_end = [tmp for tmp in end_positions if tmp >= tmp_start]
159 | if len(tmp_end) == 0:
160 | continue
161 | else:
162 | tmp_end = min(tmp_end)
163 | # if match_pred[tmp_start][tmp_end]:
164 | if tmp_start != tmp_end:
165 | for i in range(tmp_start+1, tmp_end):
166 | bmes_labels[i] = f"M-{pseudo_tag}"
167 | else:
168 | bmes_labels[tmp_end] = f"S-{pseudo_tag}"
169 |
170 | tags = bmes_decode([(pseudo_input, label) for label in bmes_labels])
171 |
172 | return [(tag.begin, tag.end) for tag in tags]
173 |
174 |
175 | def remove_overlap(spans):
176 | """
177 | remove overlapped spans greedily for flat-ner
178 | Args:
179 | spans: list of tuple (start, end), which means [start, end] is a ner-span
180 | Returns:
181 | spans without overlap
182 | """
183 | output = []
184 | occupied = set()
185 | for start, end in spans:
186 |         if any(x in occupied for x in range(start, end + 1)):
187 | continue
188 | output.append((start, end))
189 | for x in range(start, end + 1):
190 | occupied.add(x)
191 | return output
192 |
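193 |
194 | # Illustrative example (made-up spans): with the greedy strategy above,
195 | # remove_overlap([(1, 3), (2, 4), (5, 6)]) keeps (1, 3), skips the overlapping
196 | # span (2, 4), and keeps (5, 6), returning [(1, 3), (5, 6)].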
--------------------------------------------------------------------------------
/metrics/functional/squad/README.md:
--------------------------------------------------------------------------------
1 | ## Evaluate the saved models using the official SQuAD 2.0 evaluation script.
2 | - To run the evaluation, use `python3 evaluate_v2.py <data_file> <pred_file>`.
3 | - A sample prediction file (on Dev 2.0) can be found at `./tests/data/sample_prediction_v2.0.json`.
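4 | - Example (file names are placeholders): `python3 evaluate_v2.py dev-v2.0.json predictions.json --na-prob-file na_prob.json --out-file eval.json`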
--------------------------------------------------------------------------------
/metrics/functional/squad/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/metrics/functional/squad/__init__.py
--------------------------------------------------------------------------------
/metrics/functional/squad/eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: eval.sh
5 | # description:
6 | #    bash script for evaluating prediction files.
7 | # Example:
8 | # bash eval.sh mrc-with-dice-loss/metrics/functional/squad/evaluate_v1.py /data/dev-v1.1.json predictions_10_10.json
9 |
10 | EVAL_SCRIPT=$1
11 | DATA_FILE=$2
12 | PRED_FILE=$3
13 |
14 | python3 ${EVAL_SCRIPT} ${DATA_FILE} ${PRED_FILE} 1
--------------------------------------------------------------------------------
/metrics/functional/squad/evaluate_v1.py:
--------------------------------------------------------------------------------
1 | """ Official evaluation script for v1.1 of the SQuAD dataset. """
2 | from __future__ import print_function
3 | from collections import Counter
4 | import string
5 | import re
6 | import json
7 | import sys
8 |
9 |
10 | def normalize_answer(s):
11 | """Lower text and remove punctuation, articles and extra whitespace."""
12 | def remove_articles(text):
13 | return re.sub(r'\b(a|an|the)\b', ' ', text)
14 |
15 | def white_space_fix(text):
16 | return ' '.join(text.split())
17 |
18 | def remove_punc(text):
19 | exclude = set(string.punctuation)
20 | return ''.join(ch for ch in text if ch not in exclude)
21 |
22 | def lower(text):
23 | return text.lower()
24 |
25 | return white_space_fix(remove_articles(remove_punc(lower(s))))
26 |
27 |
28 | def f1_score(prediction, ground_truth):
29 | prediction_tokens = normalize_answer(prediction).split()
30 | ground_truth_tokens = normalize_answer(ground_truth).split()
31 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
32 | num_same = sum(common.values())
33 | if num_same == 0:
34 | return 0
35 | precision = 1.0 * num_same / len(prediction_tokens)
36 | recall = 1.0 * num_same / len(ground_truth_tokens)
37 | f1 = (2 * precision * recall) / (precision + recall)
38 | return f1
39 |
40 |
41 | def exact_match_score(prediction, ground_truth):
42 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
43 |
44 |
45 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
46 | scores_for_ground_truths = []
47 | for ground_truth in ground_truths:
48 | score = metric_fn(prediction, ground_truth)
49 | scores_for_ground_truths.append(score)
50 | return max(scores_for_ground_truths)
51 |
52 |
53 | def evaluate(dataset, predictions, stdout=False):
54 | f1 = exact_match = total = 0
55 | for article in dataset:
56 | for paragraph in article['paragraphs']:
57 | for qa in paragraph['qas']:
58 | total += 1
59 | if qa['id'] not in predictions:
60 | message = 'Unanswered question ' + qa['id'] + \
61 | ' will receive score 0.'
62 | print(message, file=sys.stderr)
63 | continue
64 | ground_truths = list(map(lambda x: x['text'], qa['answers']))
65 | prediction = predictions[qa['id']]
66 | exact_match += metric_max_over_ground_truths(
67 | exact_match_score, prediction, ground_truths)
68 | f1 += metric_max_over_ground_truths(
69 | f1_score, prediction, ground_truths)
70 |
71 | exact_match = 100.0 * exact_match / total
72 | f1 = 100.0 * f1 / total
73 |
74 | result_dict = {"exact_match": exact_match, "f1": f1}
75 | if stdout:
76 | print(json.dumps(result_dict))
77 | return result_dict
78 |
79 |
80 | def main(data_file, prediction_file, stdout=False):
81 | with open(data_file, "r") as datafile:
82 | dataset_json = json.load(datafile)
83 | dataset = dataset_json["data"]
84 |
85 | with open(prediction_file, "r") as predfile:
86 | predictions = json.load(predfile)
87 |
88 | eval_result = evaluate(dataset, predictions, stdout=stdout)
89 | return eval_result
90 |
91 |
92 | if __name__ == '__main__':
93 | data_file = sys.argv[1]
94 | prediction_file = sys.argv[2]
95 | stdout_sign = sys.argv[3]
96 | if stdout_sign == "1":
97 | stdout = True
98 | else:
99 | stdout = False
100 | main(data_file, prediction_file, stdout=stdout)
101 |
--------------------------------------------------------------------------------
/metrics/functional/squad/evaluate_v2.py:
--------------------------------------------------------------------------------
1 | """Official evaluation script for SQuAD version 2.0.
2 |
3 | In addition to basic functionality, we also compute additional statistics and
4 | plot precision-recall curves if an additional na_prob.json file is provided.
5 | This file is expected to map question ID's to the model's predicted probability
6 | that a question is unanswerable.
7 | """
8 | import argparse
9 | import collections
10 | import json
11 | import os
12 | import re
13 | import string
14 | import sys
15 |
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.')
19 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.')
20 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.')
21 | parser.add_argument('--out-file', '-o', metavar='eval.json',
22 | help='Write accuracy metrics to file (default is stdout).')
23 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json',
24 | help='Model estimates of probability of no answer.')
25 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0,
26 | help='Predict "" if no-answer probability exceeds this (default = 1.0).')
27 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None,
28 | help='Save precision-recall curves to directory.')
29 | parser.add_argument('--verbose', '-v', action='store_true')
30 | if len(sys.argv) == 1:
31 | parser.print_help()
32 | sys.exit(1)
33 | return parser.parse_args()
34 |
35 |
36 | def make_qid_to_has_ans(dataset):
37 | qid_to_has_ans = {}
38 | for article in dataset:
39 | for p in article['paragraphs']:
40 | for qa in p['qas']:
41 | qid_to_has_ans[qa['id']] = bool(qa['answers'])
42 | return qid_to_has_ans
43 |
44 |
45 | def normalize_answer(s):
46 | """Lower text and remove punctuation, articles and extra whitespace."""
47 |
48 | def remove_articles(text):
49 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
50 | return re.sub(regex, ' ', text)
51 |
52 | def white_space_fix(text):
53 | return ' '.join(text.split())
54 |
55 | def remove_punc(text):
56 | exclude = set(string.punctuation)
57 | return ''.join(ch for ch in text if ch not in exclude)
58 |
59 | def lower(text):
60 | return text.lower()
61 |
62 | return white_space_fix(remove_articles(remove_punc(lower(s))))
63 |
64 |
65 | def get_tokens(s):
66 | if not s: return []
67 | return normalize_answer(s).split()
68 |
69 |
70 | def compute_exact(a_gold, a_pred):
71 | return int(normalize_answer(a_gold) == normalize_answer(a_pred))
72 |
73 |
74 | def compute_f1(a_gold, a_pred):
75 | gold_toks = get_tokens(a_gold)
76 | pred_toks = get_tokens(a_pred)
77 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
78 | num_same = sum(common.values())
79 | if len(gold_toks) == 0 or len(pred_toks) == 0:
80 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
81 | return int(gold_toks == pred_toks)
82 | if num_same == 0:
83 | return 0
84 | precision = 1.0 * num_same / len(pred_toks)
85 | recall = 1.0 * num_same / len(gold_toks)
86 | f1 = (2 * precision * recall) / (precision + recall)
87 | return f1
88 |
89 |
90 | def get_raw_scores(dataset, preds):
91 | exact_scores = {}
92 | f1_scores = {}
93 | for article in dataset:
94 | for p in article['paragraphs']:
95 | for qa in p['qas']:
96 | qid = qa['id']
97 | gold_answers = [a['text'] for a in qa['answers']
98 | if normalize_answer(a['text'])]
99 | if not gold_answers:
100 | # For unanswerable questions, only correct answer is empty string
101 | gold_answers = ['']
102 | if qid not in preds:
103 | print('Missing prediction for %s' % qid)
104 | continue
105 | a_pred = preds[qid]
106 | # Take max over all gold answers
107 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
108 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
109 | return exact_scores, f1_scores
110 |
111 |
112 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
113 | new_scores = {}
114 | for qid, s in scores.items():
115 | pred_na = na_probs[qid] > na_prob_thresh
116 | if pred_na:
117 | new_scores[qid] = float(not qid_to_has_ans[qid])
118 | else:
119 | new_scores[qid] = s
120 | return new_scores
121 |
122 |
123 | def make_eval_dict(exact_scores, f1_scores, qid_list=None):
124 | if not qid_list:
125 | total = len(exact_scores)
126 | return collections.OrderedDict([
127 | ('exact', 100.0 * sum(exact_scores.values()) / total),
128 | ('f1', 100.0 * sum(f1_scores.values()) / total),
129 | ('total', total),
130 | ])
131 | else:
132 | total = len(qid_list)
133 | return collections.OrderedDict([
134 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
135 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
136 | ('total', total),
137 | ])
138 |
139 |
140 | def merge_eval(main_eval, new_eval, prefix):
141 | for k in new_eval:
142 | main_eval['%s_%s' % (prefix, k)] = new_eval[k]
143 |
144 |
145 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None):
146 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
147 | true_pos = 0.0
148 | cur_p = 1.0
149 | cur_r = 0.0
150 | precisions = [1.0]
151 | recalls = [0.0]
152 | avg_prec = 0.0
153 | for i, qid in enumerate(qid_list):
154 | if qid_to_has_ans[qid]:
155 | true_pos += scores[qid]
156 | cur_p = true_pos / float(i + 1)
157 | cur_r = true_pos / float(num_true_pos)
158 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
159 | # i.e., if we can put a threshold after this point
160 | avg_prec += cur_p * (cur_r - recalls[-1])
161 | precisions.append(cur_p)
162 | recalls.append(cur_r)
163 | return {'ap': 100.0 * avg_prec}
164 |
165 |
166 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
167 | qid_to_has_ans, out_image_dir):
168 | if out_image_dir and not os.path.exists(out_image_dir):
169 | os.makedirs(out_image_dir)
170 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
171 | if num_true_pos == 0:
172 | return
173 | pr_exact = make_precision_recall_eval(
174 | exact_raw, na_probs, num_true_pos, qid_to_has_ans,
175 | out_image=os.path.join(out_image_dir, 'pr_exact.png'),
176 | title='Precision-Recall curve for Exact Match score')
177 | pr_f1 = make_precision_recall_eval(
178 | f1_raw, na_probs, num_true_pos, qid_to_has_ans,
179 | out_image=os.path.join(out_image_dir, 'pr_f1.png'),
180 | title='Precision-Recall curve for F1 score')
181 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
182 | pr_oracle = make_precision_recall_eval(
183 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
184 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
185 | title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
186 | merge_eval(main_eval, pr_exact, 'pr_exact')
187 | merge_eval(main_eval, pr_f1, 'pr_f1')
188 | merge_eval(main_eval, pr_oracle, 'pr_oracle')
189 |
190 |
191 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
192 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
193 | cur_score = num_no_ans
194 | best_score = cur_score
195 | best_thresh = 0.0
196 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
197 | for i, qid in enumerate(qid_list):
198 | if qid not in scores: continue
199 | if qid_to_has_ans[qid]:
200 | diff = scores[qid]
201 | else:
202 | if preds[qid]:
203 | diff = -1
204 | else:
205 | diff = 0
206 | cur_score += diff
207 | if cur_score > best_score:
208 | best_score = cur_score
209 | best_thresh = na_probs[qid]
210 | return 100.0 * best_score / len(scores), best_thresh
211 |
212 |
213 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
214 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
215 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
216 | main_eval['best_exact'] = best_exact
217 | main_eval['best_exact_thresh'] = exact_thresh
218 | main_eval['best_f1'] = best_f1
219 | main_eval['best_f1_thresh'] = f1_thresh
220 |
221 |
222 | def evaluate(data_file, pred_file, na_prob_file, na_prob_thresh, out_file):
223 | with open(data_file) as f:
224 | dataset_json = json.load(f)
225 | dataset = dataset_json['data']
226 | with open(pred_file) as f:
227 | preds = json.load(f)
228 | if na_prob_file:
229 | with open(na_prob_file) as f:
230 | na_probs = json.load(f)
231 | else:
232 | na_probs = {k: 0.0 for k in preds}
233 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
234 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
235 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
236 | exact_raw, f1_raw = get_raw_scores(dataset, preds)
237 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, na_prob_thresh)
238 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, na_prob_thresh)
239 | out_eval = make_eval_dict(exact_thresh, f1_thresh)
240 | if has_ans_qids:
241 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
242 | merge_eval(out_eval, has_ans_eval, 'HasAns')
243 | if no_ans_qids:
244 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
245 | merge_eval(out_eval, no_ans_eval, 'NoAns')
246 | if na_prob_file:
247 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
248 | if out_file:
249 | with open(out_file, 'w') as f:
250 | json.dump(out_eval, f)
251 | else:
252 | print(json.dumps(out_eval, indent=2))
253 |
254 |
255 | if __name__ == '__main__':
256 |     # parse the command-line arguments defined in parse_args() above
257 |     args = parse_args()
258 |     na_prob_file = args.na_prob_file
259 |     na_prob_thresh = args.na_prob_thresh
260 |     out_file = args.out_file
261 |     evaluate(args.data_file, args.pred_file, na_prob_file, na_prob_thresh, out_file)
262 |
--------------------------------------------------------------------------------
/metrics/mrc_ner_span_f1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | from pytorch_lightning.metrics.metric import TensorMetric
6 | from metrics.functional.ner_span_f1 import bmes_decode_flat_query_span_f1
7 |
8 |
9 | class MRCNERSpanF1(TensorMetric):
10 | """
11 | Query Span F1
12 | Args:
13 | flat: is flat-ner
14 | """
15 | def __init__(self, reduce_group=None, reduce_op=None, flat=False):
16 | super(MRCNERSpanF1, self).__init__(name="query_span_f1",
17 | reduce_group=reduce_group,
18 | reduce_op=reduce_op)
19 | self.flat = flat
20 |
21 | def forward(self, start_preds, end_preds, match_logits, start_end_label_mask, start_labels, end_labels, match_labels, answerable_pred=None):
22 | return bmes_decode_flat_query_span_f1(start_preds, end_preds, match_logits, start_end_label_mask, start_labels,
23 | end_labels, match_labels, answerable_pred=answerable_pred)
24 |
25 |
--------------------------------------------------------------------------------
/metrics/squad_em_f1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # file:
5 | # squad_em_f1.py
6 | # description:
7 | # compute exact match / f1-score for SQuAD task.
8 |
9 | import os
10 | import json
11 | from metrics.functional.squad.postprocess_predication import compute_predictions_logits
12 | from metrics.functional.squad.evaluate_v1 import evaluate as evaluate_squad_v1
13 |
14 | class SquadEvalMetric:
15 | def __init__(self,
16 | n_best_size: int = 20,
17 | max_answer_length: int = 20,
18 | do_lower_case: bool = False,
19 | verbose_logging: bool = False,
20 | version_2_with_negative: bool = False,
21 | null_score_diff_threshold: float = 0,
22 | data_dir: str = "",
23 | output_dir: str = ""):
24 |
25 | self.n_best_size = n_best_size
26 | self.max_answer_length = max_answer_length
27 | self.do_lower_case = do_lower_case
28 | self.verbose_logging = verbose_logging
29 | self.version_2_with_negative = version_2_with_negative
30 | self.null_score_diff_threshold = null_score_diff_threshold
31 |
32 | self.data_dir = data_dir
33 | self.output_dir = output_dir
34 |
35 |
36 | def forward(self, all_examples, all_features, all_results, tokenizer, prefix = "dev", sigmoid=True):
37 | if not self.version_2_with_negative:
38 | with open(os.path.join(self.data_dir, "dev-v1.1.json"), "r") as f:
39 | text_dataset = json.load(f)["data"]
40 | else:
41 | with open(os.path.join(self.data_dir, "dev-v2.0.json"), "r") as f:
42 | text_dataset = json.load(f)["data"]
43 |
44 | output_prediction_file = os.path.join(self.output_dir, "predictions_{}.json".format(prefix))
45 | output_nbest_file = os.path.join(self.output_dir, "nbest_predictions_{}.json".format(prefix))
46 |
47 | if self.version_2_with_negative:
48 | output_null_log_odds_file = os.path.join(self.output_dir, "null_odds_{}.json".format(prefix))
49 | else:
50 | output_null_log_odds_file = None
51 |
52 | all_predictions = compute_predictions_logits(all_examples, all_features, all_results,
53 | self.n_best_size,
54 | self.max_answer_length,
55 | self.do_lower_case,
56 | output_prediction_file,
57 | output_nbest_file,
58 | output_null_log_odds_file,
59 | self.verbose_logging,
60 | self.version_2_with_negative,
61 | self.null_score_diff_threshold,
62 | tokenizer,
63 | sigmoid=sigmoid)
64 | if not self.version_2_with_negative:
65 | eval_results = evaluate_squad_v1(text_dataset, all_predictions)
66 | exact_match, f1 = eval_results["exact_match"], eval_results["f1"]
67 | else:
68 |             raise ValueError("Evaluation for SQuAD 2.0 is not implemented yet")
69 |
70 | return exact_match, f1
71 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/models/__init__.py
--------------------------------------------------------------------------------
/models/bert_classification.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: bert_classification.py
5 | # description:
6 | # model for fine-tuning BERT on text classification tasks.
7 |
8 | import torch.nn as nn
9 | from torch import Tensor
10 | from transformers import BertModel, BertPreTrainedModel
11 |
12 | from models.classifier import truncated_normal_
13 | from models.model_config import BertForSequenceClassificationConfig
14 |
15 | class BertForSequenceClassification(BertPreTrainedModel):
16 | """Fine-tune BERT model for text classification."""
17 | def __init__(self, config: BertForSequenceClassificationConfig,):
18 | super(BertForSequenceClassification, self).__init__(config)
19 | self.bert = BertModel(config,)
20 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
21 | self.cls_classifier = nn.Linear(config.hidden_size, config.num_labels)
22 | self.cls_classifier.weight = truncated_normal_(self.cls_classifier.weight, mean=0, std=0.02)
23 | self.init_weights()
24 |
25 | def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor):
26 | """
27 | Args:
28 |             input_ids: input tokens, tensor of shape [batch_size, seq_len].
29 | token_type_ids: 1 for text_b tokens and 0 for text_a tokens. tensor of shape [batch_size, seq_len].
30 | attention_mask: 1 for non-[PAD] tokens and 0 for [PAD] tokens. tensor of shape [batch_size, seq_len].
31 | Returns:
32 | cls_outputs: output logits for the [CLS] token. tensor of shape [batch_size, num_labels].
33 | """
34 | bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
35 | bert_cls_output = bert_outputs[1]
36 | bert_cls_output = self.dropout(bert_cls_output)
37 | cls_logits = self.cls_classifier(bert_cls_output)
38 |
39 | return cls_logits
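40 |
41 |
42 | # A minimal usage sketch (the checkpoint path and input tensors are placeholders):
43 | #
44 | #   config = BertForSequenceClassificationConfig.from_pretrained("/path/to/bert_dir", num_labels=2)
45 | #   model = BertForSequenceClassification(config)
46 | #   cls_logits = model(input_ids, token_type_ids, attention_mask)  # [batch_size, num_labels]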
--------------------------------------------------------------------------------
/models/bert_qa.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: bert_qa.py
5 | # description:
6 | # BERT for question answering task.
7 |
8 |
9 | import torch.nn as nn
10 | from torch import Tensor
11 | from transformers import BertModel, BertPreTrainedModel
12 | from models.classifier import truncated_normal_, BertMLP
13 |
14 |
15 | class BertForQuestionAnswering(BertPreTrainedModel):
16 | """Finetuning Bert Model for the question answering task."""
17 | def __init__(self, config):
18 | super(BertForQuestionAnswering, self).__init__(config)
19 |
20 | self.bert = BertModel(config, add_pooling_layer=False)
21 | if config.multi_layer_classifier:
22 | self.qa_classifier = BertMLP(config)
23 | else:
24 | self.qa_classifier = nn.Linear(config.hidden_size, 2)
25 | self.qa_classifier.weight = truncated_normal_(self.qa_classifier.weight, mean=0, std=0.02)
26 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
27 | self.init_weights()
28 |
29 | def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor):
30 | """
31 | Args:
32 | input_ids: Bert input tokens, tensor of shape [batch, seq_len]
33 | token_type_ids: 0 for query, 1 for context, tensor of shape [batch, seq_len]
34 | attention_mask: attention mask, tensor of shape [batch, seq_len]
35 | Returns:
36 | start_logits: start/non-start logits of shape [batch, seq_len]
37 | end_logits: end/non-end logits of shape [batch, seq_len]
38 | """
39 |
40 | bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
41 |
42 | sequence_heatmap = self.dropout(bert_outputs[0]) # [batch, seq_len, hidden]
43 | logits = self.qa_classifier(sequence_heatmap)
44 | start_logits, end_logits = logits.split(1, dim=-1)
45 | start_logits = start_logits.squeeze(-1)
46 | end_logits = end_logits.squeeze(-1)
47 |
48 | return start_logits, end_logits
49 |
50 |
51 |
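52 | # A minimal decoding sketch (input tensors are placeholders; the repo's own SQuAD
53 | # post-processing is more elaborate): a simple way to read off a span is to take
54 | # the argmax of each logit tensor along the sequence axis.
55 | #
56 | #   start_logits, end_logits = model(input_ids, token_type_ids, attention_mask)
57 | #   start_index = start_logits.argmax(dim=-1)  # [batch]
58 | #   end_index = end_logits.argmax(dim=-1)      # [batch]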
--------------------------------------------------------------------------------
/models/bert_query_ner.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: bert_query_ner.py
5 |
6 | import torch
7 | import torch.nn as nn
8 | from transformers import BertModel, BertPreTrainedModel
9 |
10 | from models.classifier import SpanClassifier, MultiLayerPerceptronClassifier
11 |
12 |
13 | class BertForQueryNER(BertPreTrainedModel):
14 | def __init__(self, config):
15 | super(BertForQueryNER, self).__init__(config)
16 | self.bert = BertModel(config)
17 |
18 | self.construct_entity_span = config.construct_entity_span
19 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
20 |
21 | if self.construct_entity_span == "start_end_match":
22 | self.start_outputs = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=config.num_labels, activate_func=config.activate_func)
23 | self.end_outputs = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=config.num_labels, activate_func=config.activate_func)
24 | self.span_embedding = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size*2, num_labels=1, activate_func=config.activate_func)
25 | elif self.construct_entity_span == "match":
26 | self.span_nn = SpanClassifier(config.hidden_size, config.hidden_dropout_prob)
27 | elif self.construct_entity_span == "start_and_end":
28 | self.start_outputs = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=config.num_labels, activate_func=config.activate_func)
29 | self.end_outputs = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=config.num_labels, activate_func=config.activate_func)
30 | elif self.construct_entity_span == "start_end":
31 | self.start_end_outputs = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=2, activate_func=config.activate_func)
32 | else:
33 | raise ValueError
34 |
35 | self.pred_answerable = config.pred_answerable
36 | if self.pred_answerable:
37 | self.answerable_cls_output = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=1, activate_func=config.activate_func)
38 |
39 | self.init_weights()
40 |
41 | def forward(self, input_ids, token_type_ids=None, attention_mask=None):
42 | """
43 | Args:
44 | input_ids: bert input tokens, tensor of shape [batch, seq_len]
45 | token_type_ids: 0 for query, 1 for context, tensor of shape [batch, seq_len]
46 | attention_mask: attention mask, tensor of shape [batch, seq_len]
47 | Returns:
48 | start_logits: start/non-start probs of shape [batch, seq_len]
49 | end_logits: end/non-end probs of shape [batch, seq_len]
50 | match_logits: start-end-match probs of shape [batch, seq_len, seq_len]
51 | """
52 |
53 | bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
54 |
55 | sequence_heatmap = bert_outputs[0] # [batch, seq_len, hidden]
56 | sequence_cls = bert_outputs[1]
57 |
58 | batch_size, seq_len, hid_size = sequence_heatmap.size()
59 | if self.construct_entity_span == "match" :
60 | start_logits = end_logits = torch.ones_like(input_ids).float()
61 | span_logits = self.span_nn(sequence_heatmap)
62 | elif self.construct_entity_span == "start_end_match":
63 | sequence_heatmap = self.dropout(sequence_heatmap)
64 | start_logits = self.start_outputs(sequence_heatmap).view(batch_size, seq_len, -1) # [batch, seq_len, 1]
65 | end_logits = self.end_outputs(sequence_heatmap).view(batch_size, seq_len, -1) # [batch, seq_len, 1]
66 |
67 |             # for every position $i$ in the sequence, concatenate the representation of position $j$
68 |             # to predict whether $i$ and $j$ are the start and end positions of an entity.
69 | # [batch, seq_len, seq_len, hidden]
70 | start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
71 | # [batch, seq_len, seq_len, hidden]
72 | end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
73 | # [batch, seq_len, seq_len, hidden*2]
74 | span_matrix = torch.cat([start_extend, end_extend], 3)
75 | # [batch, seq_len, seq_len]
76 | span_logits = self.span_embedding(span_matrix).squeeze(-1)
77 | elif self.construct_entity_span == "start_and_end":
78 | sequence_heatmap = self.dropout(sequence_heatmap)
79 | start_logits = self.start_outputs(sequence_heatmap).view(batch_size, seq_len, -1) # [batch, seq_len, 1]
80 | end_logits = self.end_outputs(sequence_heatmap).view(batch_size, seq_len, -1) # [batch, seq_len, 1]
81 |
82 | span_logits = None
83 | elif self.construct_entity_span == "start_end":
84 | sequence_heatmap = self.dropout(sequence_heatmap)
85 | start_end_logits = self.start_end_outputs(sequence_heatmap)
86 | start_logits, end_logits = start_end_logits.split(1, dim=-1)
87 |
88 | span_logits = None
89 | else:
90 | raise ValueError
91 |
92 | if self.pred_answerable:
93 | cls_logits = self.answerable_cls_output(sequence_cls)
94 | return start_logits, end_logits, span_logits, cls_logits
95 |
96 | return start_logits, end_logits, span_logits
97 |
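98 | # Shape sketch for the "start_end_match" branch (sizes are hypothetical): with
99 | # batch=2, seq_len=4 and hidden=768, start_extend and end_extend are both
100 | # [2, 4, 4, 768], span_matrix is [2, 4, 4, 1536], and span_logits is [2, 4, 4],
101 | # i.e. one score per candidate (start, end) position pair.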
--------------------------------------------------------------------------------
/models/classifier.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: classifier.py
5 | # description:
6 | # modules for building models.
7 |
8 | import math
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import init
12 | from torch.nn import functional as F
13 |
14 |
15 | def truncated_normal_(tensor, mean=0, std=1):
16 | size = tensor.shape
17 | tmp = tensor.new_empty(size + (4,)).normal_()
18 | valid = (tmp < 2) & (tmp > -2)
19 | ind = valid.max(-1, keepdim=True)[1]
20 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
21 | tensor.data.mul_(std).add_(mean)
22 | return tensor
23 |
24 |
25 | class BertMLP(nn.Module):
26 | def __init__(self, config):
27 | super().__init__()
28 | self.dense_layer = nn.Linear(config.hidden_size, config.hidden_size)
29 | self.dense_to_labels_layer = nn.Linear(config.hidden_size, config.num_labels)
30 | self.activation = nn.Tanh()
31 | if config.truncated_normal:
32 | self.dense_layer.weight = truncated_normal_(self.dense_layer.weight, mean=0, std=0.02)
33 | self.dense_to_labels_layer.weight = truncated_normal_(self.dense_to_labels_layer.weight, mean=0, std=0.02)
34 |
35 | def forward(self, sequence_hidden_states):
36 | sequence_output = self.dense_layer(sequence_hidden_states)
37 | sequence_output = self.activation(sequence_output)
38 | sequence_output = self.dense_to_labels_layer(sequence_output)
39 | return sequence_output
40 |
41 |
42 | class MultiLayerPerceptronClassifier(nn.Module):
43 | def __init__(self, hidden_size=None, num_labels=None, activate_func="gelu"):
44 | super().__init__()
45 | self.dense_layer = nn.Linear(hidden_size, hidden_size)
46 | self.dense_to_labels_layer = nn.Linear(hidden_size, num_labels)
47 | if activate_func == "tanh":
48 | self.activation = nn.Tanh()
49 | elif activate_func == "relu":
50 | self.activation = nn.ReLU()
51 | elif activate_func == "gelu":
52 | self.activation = nn.GELU()
53 | else:
54 | raise ValueError
55 |
56 | def forward(self, sequence_hidden_states):
57 | sequence_output = self.dense_layer(sequence_hidden_states)
58 | sequence_output = self.activation(sequence_output)
59 | sequence_output = self.dense_to_labels_layer(sequence_output)
60 | return sequence_output
61 |
62 |
63 | class SpanClassifier(nn.Module):
64 | def __init__(self, hidden_size: int, dropout_rate: float):
65 | super(SpanClassifier, self).__init__()
66 | self.start_proj = nn.Linear(hidden_size, hidden_size)
67 | self.end_proj = nn.Linear(hidden_size, hidden_size)
68 | self.biaffine = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
69 | self.concat_proj = nn.Linear(hidden_size * 2, 1)
70 | self.dropout = nn.Dropout(dropout_rate)
71 | self.reset_parameters()
72 |
73 | def forward(self, input_features):
74 | bsz, seq_len, dim = input_features.size()
75 | # B, L, h
76 | start_feature = self.dropout(F.gelu(self.start_proj(input_features)))
77 | # B, L, h
78 | end_feature = self.dropout(F.gelu(self.end_proj(input_features)))
79 | # B, L, L
80 | biaffine_logits = torch.bmm(torch.matmul(start_feature, self.biaffine), end_feature.transpose(1, 2))
81 |
82 | start_extend = start_feature.unsqueeze(2).expand(-1, -1, seq_len, -1)
83 | # [B, L, L, h]
84 | end_extend = end_feature.unsqueeze(1).expand(-1, seq_len, -1, -1)
85 | # [B, L, L, h]
86 | span_matrix = torch.cat([start_extend, end_extend], 3)
87 | # [B, L, L]
88 | concat_logits = self.concat_proj(span_matrix).squeeze(-1)
89 | # B, L, L
90 | return biaffine_logits + concat_logits
91 |
92 | def reset_parameters(self) -> None:
93 | init.kaiming_uniform_(self.biaffine, a=math.sqrt(5))
94 |
95 |
96 |
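97 | # Shape sketch (sizes are hypothetical): SpanClassifier(hidden_size=768, dropout_rate=0.1)
98 | # maps input_features of shape [batch=2, seq_len=5, 768] to a score matrix of shape
99 | # [2, 5, 5], the sum of the biaffine logits and the concat-projection logits.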
--------------------------------------------------------------------------------
/models/model_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: model_config.py
5 | # description:
6 | # user defined configuration class for NLP tasks.
7 |
8 | from transformers import BertConfig
9 |
10 |
11 | class BertForQAConfig(BertConfig):
12 | def __init__(self, **kwargs):
13 | super(BertForQAConfig, self).__init__(**kwargs)
14 | self.hidden_size = kwargs.get("hidden_size", 768)
15 | self.multi_layer_classifier = kwargs.get("multi_layer_classifier", True)
16 | self.truncated_normal = kwargs.get("truncated_normal", True)
17 |
18 | class BertForSequenceClassificationConfig(BertConfig):
19 | def __init__(self, **kwargs):
20 | super(BertForSequenceClassificationConfig, self).__init__(**kwargs)
21 | self.hidden_dropout_prob = kwargs.get("hidden_dropout_prob", 0.0)
22 | self.num_labels = kwargs.get("num_labels", 2)
23 | self.hidden_size = kwargs.get("hidden_size", 768)
24 | self.truncated_normal = kwargs.get("truncated_normal", False)
25 |
26 | class BertForQueryNERConfig(BertConfig):
27 | def __init__(self, **kwargs):
28 | super(BertForQueryNERConfig, self).__init__(**kwargs)
29 | self.hidden_dropout_prob = kwargs.get("hidden_dropout_prob", 0.1)
30 | self.truncated_normal = kwargs.get("truncated_normal", False)
31 | self.construct_entity_span = kwargs.get("construct_entity_span", "start_end_match")
32 | self.pred_answerable = kwargs.get("pred_answerable", True)
33 | self.num_labels = kwargs.get("num_labels", 2)
34 | self.activate_func = kwargs.get("activate_func", "gelu")
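35 |
36 |
37 | # A minimal construction sketch (values are placeholders): the task-specific options
38 | # are passed as keyword arguments on top of the usual BertConfig ones.
39 | #
40 | #   config = BertForQueryNERConfig(construct_entity_span="start_end_match", pred_answerable=True)
41 | #   config = BertForSequenceClassificationConfig(hidden_dropout_prob=0.1, num_labels=2)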
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==0.9.0
2 | tokenizers==0.9.3
3 | transformers==3.5.1
--------------------------------------------------------------------------------
/scripts/download_ckpt.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # description:
5 | # download pretrained model ckpt
6 |
7 | BERT_PRETRAIN_CKPT=$1
8 | MODEL_NAME=$2
9 |
10 | if [[ $MODEL_NAME == "bert_cased_base" ]]; then
11 | mkdir -p $BERT_PRETRAIN_CKPT
12 |     echo "Download BERT Cased Base"
13 | wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip -P $BERT_PRETRAIN_CKPT
14 | unzip $BERT_PRETRAIN_CKPT/cased_L-12_H-768_A-12.zip -d $BERT_PRETRAIN_CKPT
15 | rm $BERT_PRETRAIN_CKPT/cased_L-12_H-768_A-12.zip
16 | mv $BERT_PRETRAIN_CKPT/cased_L-12_H-768_A-12 $BERT_PRETRAIN_CKPT/bert_cased_base
17 | elif [[ $MODEL_NAME == "bert_cased_large" ]]; then
18 |     echo "Download BERT Cased Large"
19 | wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-24_H-1024_A-16.zip -P $BERT_PRETRAIN_CKPT
20 | unzip $BERT_PRETRAIN_CKPT/cased_L-24_H-1024_A-16.zip -d $BERT_PRETRAIN_CKPT
21 | rm $BERT_PRETRAIN_CKPT/cased_L-24_H-1024_A-16.zip
22 | mv $BERT_PRETRAIN_CKPT/cased_L-24_H-1024_A-16 $BERT_PRETRAIN_CKPT/bert_cased_large
23 | elif [[ $MODEL_NAME == "bert_uncased_base" ]]; then
24 |     echo "Download BERT Uncased Base"
25 | wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip -P $BERT_PRETRAIN_CKPT
26 | unzip $BERT_PRETRAIN_CKPT/uncased_L-12_H-768_A-12.zip -d $BERT_PRETRAIN_CKPT
27 | rm $BERT_PRETRAIN_CKPT/uncased_L-12_H-768_A-12.zip
28 | mv $BERT_PRETRAIN_CKPT/uncased_L-12_H-768_A-12 $BERT_PRETRAIN_CKPT/bert_uncased_base
29 | elif [[ $MODEL_NAME == "bert_uncased_large" ]]; then
30 |     echo "Download BERT Uncased Large"
31 | wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip -P $BERT_PRETRAIN_CKPT
32 | unzip $BERT_PRETRAIN_CKPT/uncased_L-24_H-1024_A-16.zip -d $BERT_PRETRAIN_CKPT
33 | rm $BERT_PRETRAIN_CKPT/uncased_L-24_H-1024_A-16.zip
34 | mv $BERT_PRETRAIN_CKPT/uncased_L-24_H-1024_A-16 $BERT_PRETRAIN_CKPT/bert_uncased_large
35 | elif [[ $MODEL_NAME == "bert_tiny" ]]; then
36 |     echo "Download BERT Uncased Tiny; helps to debug on GPU."
37 |     wget https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip -P $BERT_PRETRAIN_CKPT
38 |     unzip $BERT_PRETRAIN_CKPT/uncased_L-2_H-128_A-2.zip -d $BERT_PRETRAIN_CKPT
39 | rm $BERT_PRETRAIN_CKPT/uncased_L-2_H-128_A-2.zip
40 | mv $BERT_PRETRAIN_CKPT/uncased_L-2_H-128_A-2 $BERT_PRETRAIN_CKPT/bert_uncased_tiny
41 | else
42 |     echo 'Unknown argument 2 (MODEL_NAME)'
43 | fi
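44 |
45 | # Example invocation (the target directory is a placeholder):
46 | #   bash scripts/download_ckpt.sh /path/to/pretrained_ckpt bert_cased_base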
--------------------------------------------------------------------------------
/scripts/glue_mrpc/bert_base_ce.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: glue/bert_base_ce.sh
5 | # result:
6 | # dev f1/acc: 90.03/85.78
7 | # test f1/acc: 87.36/82.72
8 |
9 | FILE=bert_base_ce
10 | MODEL_SCALE=base
11 | TASK=mrpc
12 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc
13 | DATA_DIR=/data/xiaoya/datasets/mrpc
14 | BERT_DIR=/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12
15 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP
16 |
17 | # NEED CHANGE
18 | LR=2e-5
19 | LR_SCHEDULER=linear
20 | TRAIN_BATCH_SIZE=24
21 | ACC_GRAD=1
22 | DROPOUT=0.1
23 | WEIGHT_DECAY=0.01
24 | WARMUP_PROPORTION=0.1
25 | LOSS_TYPE=ce
26 | DICE_SMOOTH=1e-4
27 | DICE_OHEM=0.1
28 | DICE_ALPHA=0.01
29 | FOCAL_GAMMA=0.1
30 |
31 | # DO NOT NEED CHANGE
32 | PRECISION=32
33 | MAX_SEQ_LEN=128
34 | MAX_EPOCH=3
35 | PROGRESS_BAR=1
36 | VAL_CHECK_INTERVAL=0.25
37 | DISTRIBUTE=ddp
38 | GRAD_CLIP=1
39 |
40 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY}
41 | mkdir -p ${OUTPUT_DIR}
42 | CACHE_DIR=${OUTPUT_DIR}/cache
43 | mkdir -p ${CACHE_DIR}
44 |
45 |
46 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
47 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/tasks/glue/train.py \
48 | --gpus="1" \
49 | --task_name ${TASK} \
50 | --max_seq_length ${MAX_SEQ_LEN} \
51 | --train_batch_size ${TRAIN_BATCH_SIZE} \
52 | --precision=${PRECISION} \
53 | --default_root_dir ${OUTPUT_DIR} \
54 | --output_dir ${OUTPUT_DIR} \
55 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
56 | --val_check_interval ${VAL_CHECK_INTERVAL} \
57 | --loss_type ${LOSS_TYPE} \
58 | --data_dir ${DATA_DIR} \
59 | --bert_config_dir ${BERT_DIR} \
60 | --bert_hidden_dropout ${DROPOUT} \
61 | --lr ${LR} \
62 | --lr_scheduler ${LR_SCHEDULER} \
63 | --accumulate_grad_batches ${ACC_GRAD} \
64 | --output_dir ${OUTPUT_DIR} \
65 | --max_epochs ${MAX_EPOCH} \
66 | --gradient_clip_val ${GRAD_CLIP} \
67 | --pad_to_max_length \
68 | --weight_decay ${WEIGHT_DECAY} \
69 | --warmup_proportion ${WARMUP_PROPORTION} \
70 | --overwrite_cache
71 |
--------------------------------------------------------------------------------
/scripts/glue_mrpc/bert_base_dice.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # result:
5 | # Dev f1/acc: 90.42/86.03
6 | # Test f1/acc: 88.23/83.59
7 | # gpu4: /data/xiaoya/outputs/dice_loss/glue_mrpc/2021.01.24/dice_night8_base_dice_128_20_1_5_3e-5_linear_0.2_0.002_0.003
8 |
9 |
10 | FILE=reproduce_dice_bertbase
11 | MODEL_SCALE=base
12 | TASK=mrpc
13 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc
14 | DATA_DIR=/data/xiaoya/datasets/mrpc
15 | BERT_DIR=/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12
16 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP
17 |
18 | # NEED CHANGE
19 | LR=3e-5
20 | LR_SCHEDULER=linear
21 | TRAIN_BATCH_SIZE=20
22 | ACC_GRAD=1
23 | DROPOUT=0.2
24 | WEIGHT_DECAY=0.003
25 | WARMUP_PROPORTION=0.002
26 | LOSS_TYPE=dice
27 | DICE_SMOOTH=1
28 | DICE_OHEM=0
29 | DICE_ALPHA=0.01
30 |
31 | # DO NOT NEED CHANGE
32 | PRECISION=32
33 | MAX_SEQ_LEN=128
34 | MAX_EPOCH=5
35 | PROGRESS_BAR=1
36 | VAL_CHECK_INTERVAL=0.25
37 | DISTRIBUTE=ddp
38 | GRAD_CLIP=1
39 |
40 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY}
41 | mkdir -p ${OUTPUT_DIR}
42 | CACHE_DIR=${OUTPUT_DIR}/cache
43 | mkdir -p ${CACHE_DIR}
44 |
45 |
46 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
47 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/tasks/glue/train.py \
48 | --gpus="1" \
49 | --task_name ${TASK} \
50 | --max_seq_length ${MAX_SEQ_LEN} \
51 | --train_batch_size ${TRAIN_BATCH_SIZE} \
52 | --precision=${PRECISION} \
53 | --default_root_dir ${OUTPUT_DIR} \
54 | --output_dir ${OUTPUT_DIR} \
55 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
56 | --val_check_interval ${VAL_CHECK_INTERVAL} \
57 | --loss_type ${LOSS_TYPE} \
58 | --data_dir ${DATA_DIR} \
59 | --bert_config_dir ${BERT_DIR} \
60 | --bert_hidden_dropout ${DROPOUT} \
61 | --lr ${LR} \
62 | --lr_scheduler ${LR_SCHEDULER} \
63 | --accumulate_grad_batches ${ACC_GRAD} \
64 | --output_dir ${OUTPUT_DIR} \
65 | --max_epochs ${MAX_EPOCH} \
66 | --gradient_clip_val ${GRAD_CLIP} \
67 | --pad_to_max_length \
68 | --weight_decay ${WEIGHT_DECAY} \
69 | --warmup_proportion ${WARMUP_PROPORTION} \
70 | --overwrite_cache \
71 | --dice_square \
72 | --dice_smooth ${DICE_SMOOTH} \
73 | --dice_ohem ${DICE_OHEM} \
74 | --dice_alpha ${DICE_ALPHA}
75 |
76 |
--------------------------------------------------------------------------------
/scripts/glue_mrpc/bert_base_focal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # Result:
5 | # - Dev f1/acc: 89.31/84.80
6 | # - Test f1/acc: 88.06/83.59
7 | # gpu3: /data/xiaoya/outputs/dice_loss/glue_mrpc/2021.01.25/mrpc_focal_base4_base_focal_128_38_1_3_3e-5_linear_0.1_0.1_0.002
8 |
9 |
10 | FILE=mrpc_focal_base4
11 | MODEL_SCALE=base
12 | TASK=mrpc
13 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc
14 | DATA_DIR=/data/xiaoya/datasets/mrpc
15 | BERT_DIR=/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12
16 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP
17 |
18 | # NEED CHANGE
19 | LR=3e-5
20 | LR_SCHEDULER=linear
21 | TRAIN_BATCH_SIZE=38
22 | ACC_GRAD=1
23 | DROPOUT=0.1
24 | WEIGHT_DECAY=0.002
25 | WARMUP_PROPORTION=0.1
26 | LOSS_TYPE=focal
27 | DICE_SMOOTH=1e-4
28 | DICE_OHEM=0.1
29 | DICE_ALPHA=0.01
30 | FOCAL_GAMMA=2
31 | OPTIMIZER=debias
32 |
33 | # DO NOT NEED CHANGE
34 | PRECISION=32
35 | MAX_SEQ_LEN=128
36 | MAX_EPOCH=3
37 | PROGRESS_BAR=1
38 | VAL_CHECK_INTERVAL=0.25
39 | DISTRIBUTE=ddp
40 | GRAD_CLIP=1
41 |
42 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY}
43 | mkdir -p ${OUTPUT_DIR}
44 | CACHE_DIR=${OUTPUT_DIR}/cache
45 | mkdir -p ${CACHE_DIR}
46 |
47 |
48 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
49 | CUDA_VISIBLE_DEVICES=3 python3 ${REPO_PATH}/tasks/glue/train.py \
50 | --gpus="1" \
51 | --task_name ${TASK} \
52 | --max_seq_length ${MAX_SEQ_LEN} \
53 | --train_batch_size ${TRAIN_BATCH_SIZE} \
54 | --precision=${PRECISION} \
55 | --default_root_dir ${OUTPUT_DIR} \
56 | --output_dir ${OUTPUT_DIR} \
57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
58 | --val_check_interval ${VAL_CHECK_INTERVAL} \
59 | --loss_type ${LOSS_TYPE} \
60 | --data_dir ${DATA_DIR} \
61 | --bert_config_dir ${BERT_DIR} \
62 | --bert_hidden_dropout ${DROPOUT} \
63 | --lr ${LR} \
64 | --lr_scheduler ${LR_SCHEDULER} \
65 | --accumulate_grad_batches ${ACC_GRAD} \
66 | --output_dir ${OUTPUT_DIR} \
67 | --max_epochs ${MAX_EPOCH} \
68 | --gradient_clip_val ${GRAD_CLIP} \
69 | --pad_to_max_length \
70 | --weight_decay ${WEIGHT_DECAY} \
71 | --warmup_proportion ${WARMUP_PROPORTION} \
72 | --overwrite_cache \
73 | --optimizer ${OPTIMIZER} \
74 | --focal_gamma ${FOCAL_GAMMA}
75 |
76 |
77 |
--------------------------------------------------------------------------------
/scripts/glue_mrpc/bert_large_ce.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # result:
5 | # - Dev f1/acc: 91.53/87.75
6 | # - Test f1/acc: 87.98/83.13
7 | # gpu4: /data/xiaoya/outputs/dice_loss/glue_mrpc/2021.01.25/ce_large_large_ce_128_16_2_3_2e-5_linear_0.1_0.1_0.002
8 |
9 |
10 | FILE=ce_large
11 | MODEL_SCALE=large
12 | TASK=mrpc
13 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc
14 | DATA_DIR=/data/xiaoya/datasets/mrpc
15 | BERT_DIR=/data/xiaoya/pretrain_lm/bert_cased_large
16 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP
17 |
18 | # NEED CHANGE
19 | LR=2e-5
20 | LR_SCHEDULER=linear
21 | TRAIN_BATCH_SIZE=16
22 | ACC_GRAD=2
23 | DROPOUT=0.1
24 | WEIGHT_DECAY=0.002
25 | WARMUP_PROPORTION=0.1
26 | LOSS_TYPE=ce
27 | DICE_SMOOTH=1e-4
28 | DICE_OHEM=0.1
29 | DICE_ALPHA=0.01
30 | FOCAL_GAMMA=0.1
31 | OPTIMIZER=debias
32 |
33 | # DO NOT NEED CHANGE
34 | PRECISION=32
35 | MAX_SEQ_LEN=128
36 | MAX_EPOCH=3
37 | PROGRESS_BAR=1
38 | VAL_CHECK_INTERVAL=0.25
39 | DISTRIBUTE=ddp
40 | GRAD_CLIP=1
41 |
42 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY}
43 | mkdir -p ${OUTPUT_DIR}
44 | CACHE_DIR=${OUTPUT_DIR}/cache
45 | mkdir -p ${CACHE_DIR}
46 |
47 |
48 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
49 | CUDA_VISIBLE_DEVICES=2 python3 ${REPO_PATH}/tasks/glue/train.py \
50 | --gpus="1" \
51 | --task_name ${TASK} \
52 | --max_seq_length ${MAX_SEQ_LEN} \
53 | --train_batch_size ${TRAIN_BATCH_SIZE} \
54 | --precision=${PRECISION} \
55 | --default_root_dir ${OUTPUT_DIR} \
56 | --output_dir ${OUTPUT_DIR} \
57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
58 | --val_check_interval ${VAL_CHECK_INTERVAL} \
59 | --loss_type ${LOSS_TYPE} \
60 | --data_dir ${DATA_DIR} \
61 | --bert_config_dir ${BERT_DIR} \
62 | --bert_hidden_dropout ${DROPOUT} \
63 | --lr ${LR} \
64 | --lr_scheduler ${LR_SCHEDULER} \
65 | --accumulate_grad_batches ${ACC_GRAD} \
66 | --output_dir ${OUTPUT_DIR} \
67 | --max_epochs ${MAX_EPOCH} \
68 | --gradient_clip_val ${GRAD_CLIP} \
69 | --pad_to_max_length \
70 | --weight_decay ${WEIGHT_DECAY} \
71 | --warmup_proportion ${WARMUP_PROPORTION} \
72 | --overwrite_cache \
73 | --optimizer ${OPTIMIZER}
74 |
75 |
--------------------------------------------------------------------------------
/scripts/glue_mrpc/bert_large_dice.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # result:
5 | # - Dev f1/acc: 91.31/87.50
6 | # - Test f1/acc: 88.98/84.64
7 | # gpu6: /data/xiaoya/outputs/dice_loss/glue_mrpc/2021.02.19/dice_large4_large_dice_128_12_1_5_2e-5_linear_0.1_0.06_0.001
8 |
9 |
10 | FILE=dice_large4
11 | MODEL_SCALE=large
12 | TASK=mrpc
13 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc
14 | DATA_DIR=/data/xiaoya/datasets/mrpc
15 | BERT_DIR=/data/xiaoya/pretrain_lm/bert_cased_large
16 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP
17 |
18 | # NEED CHANGE
19 | LR=2e-5
20 | LR_SCHEDULER=linear
21 | TRAIN_BATCH_SIZE=12
22 | ACC_GRAD=1
23 | DROPOUT=0.1
24 | WEIGHT_DECAY=0.001
25 | WARMUP_PROPORTION=0.06
26 | LOSS_TYPE=dice
27 | DICE_SMOOTH=1
28 | DICE_OHEM=0
29 | DICE_ALPHA=0.01
30 | FOCAL_GAMMA=1
31 | OPTIMIZER=debias
32 |
33 | # DO NOT NEED CHANGE
34 | PRECISION=32
35 | MAX_SEQ_LEN=128
36 | MAX_EPOCH=5
37 | PROGRESS_BAR=1
38 | VAL_CHECK_INTERVAL=0.25
39 | DISTRIBUTE=ddp
40 | GRAD_CLIP=1
41 |
42 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY}
43 | mkdir -p ${OUTPUT_DIR}
44 | CACHE_DIR=${OUTPUT_DIR}/cache
45 | mkdir -p ${CACHE_DIR}
46 |
47 |
48 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
49 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/tasks/glue/train.py \
50 | --gpus="1" \
51 | --task_name ${TASK} \
52 | --max_seq_length ${MAX_SEQ_LEN} \
53 | --train_batch_size ${TRAIN_BATCH_SIZE} \
54 | --precision=${PRECISION} \
55 | --default_root_dir ${OUTPUT_DIR} \
56 | --output_dir ${OUTPUT_DIR} \
57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
58 | --val_check_interval ${VAL_CHECK_INTERVAL} \
59 | --loss_type ${LOSS_TYPE} \
60 | --data_dir ${DATA_DIR} \
61 | --bert_config_dir ${BERT_DIR} \
62 | --bert_hidden_dropout ${DROPOUT} \
63 | --lr ${LR} \
64 | --lr_scheduler ${LR_SCHEDULER} \
65 | --accumulate_grad_batches ${ACC_GRAD} \
66 | --output_dir ${OUTPUT_DIR} \
67 | --max_epochs ${MAX_EPOCH} \
68 | --gradient_clip_val ${GRAD_CLIP} \
69 | --pad_to_max_length \
70 | --weight_decay ${WEIGHT_DECAY} \
71 | --warmup_proportion ${WARMUP_PROPORTION} \
72 | --overwrite_cache \
73 | --optimizer ${OPTIMIZER} \
74 | --dice_square \
75 | --dice_smooth ${DICE_SMOOTH} \
76 | --dice_ohem ${DICE_OHEM} \
77 | --dice_alpha ${DICE_ALPHA}
78 |
79 |
--------------------------------------------------------------------------------
/scripts/glue_mrpc/bert_large_focal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # Result:
5 | # - Dev f1/acc: 90.91/86.76
6 | # - Test f1/acc: 88.35/84.06
7 | # gpu3: /data/xiaoya/outputs/dice_loss/glue_mrpc/2021.01.25/mrpc_focal_large2_large_focal_128_16_2_3_2e-5_linear_0.1_0.1_0.002
8 |
9 |
10 | FILE=mrpc_focal_large2
11 | MODEL_SCALE=large
12 | TASK=mrpc
13 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc
14 | DATA_DIR=/data/xiaoya/datasets/mrpc
15 | BERT_DIR=/data/xiaoya/pretrain_lm/bert_cased_large
16 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP
17 |
18 | # NEED CHANGE
19 | LR=2e-5
20 | LR_SCHEDULER=linear
21 | TRAIN_BATCH_SIZE=16
22 | ACC_GRAD=2
23 | DROPOUT=0.1
24 | WEIGHT_DECAY=0.002
25 | WARMUP_PROPORTION=0.1
26 | LOSS_TYPE=focal
27 | DICE_SMOOTH=1e-4
28 | DICE_OHEM=0.1
29 | DICE_ALPHA=0.01
30 | FOCAL_GAMMA=3
31 | OPTIMIZER=debias
32 |
33 | # NO NEED TO CHANGE
34 | PRECISION=32
35 | MAX_SEQ_LEN=128
36 | MAX_EPOCH=3
37 | PROGRESS_BAR=1
38 | VAL_CHECK_INTERVAL=0.25
39 | DISTRIBUTE=ddp
40 | GRAD_CLIP=1
41 |
42 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY}
43 | mkdir -p ${OUTPUT_DIR}
44 | CACHE_DIR=${OUTPUT_DIR}/cache
45 | mkdir -p ${CACHE_DIR}
46 |
47 |
48 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
49 | CUDA_VISIBLE_DEVICES=2 python3 ${REPO_PATH}/tasks/glue/train.py \
50 | --gpus="1" \
51 | --task_name ${TASK} \
52 | --max_seq_length ${MAX_SEQ_LEN} \
53 | --train_batch_size ${TRAIN_BATCH_SIZE} \
54 | --precision=${PRECISION} \
55 | --default_root_dir ${OUTPUT_DIR} \
56 | --output_dir ${OUTPUT_DIR} \
57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
58 | --val_check_interval ${VAL_CHECK_INTERVAL} \
59 | --loss_type ${LOSS_TYPE} \
60 | --data_dir ${DATA_DIR} \
61 | --bert_config_dir ${BERT_DIR} \
62 | --bert_hidden_dropout ${DROPOUT} \
63 | --lr ${LR} \
64 | --lr_scheduler ${LR_SCHEDULER} \
65 | --accumulate_grad_batches ${ACC_GRAD} \
66 | --output_dir ${OUTPUT_DIR} \
67 | --max_epochs ${MAX_EPOCH} \
68 | --gradient_clip_val ${GRAD_CLIP} \
69 | --pad_to_max_length \
70 | --weight_decay ${WEIGHT_DECAY} \
71 | --warmup_proportion ${WARMUP_PROPORTION} \
72 | --overwrite_cache \
73 | --optimizer ${OPTIMIZER} \
74 | --focal_gamma ${FOCAL_GAMMA}
75 |
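76 | # Note: FOCAL_GAMMA is the gamma exponent in focal loss, which scales the
77 | # cross entropy by (1 - p_t)^gamma so that well-classified examples contribute
78 | # less; the exact formulation used here lives in loss/focal_loss.py.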
--------------------------------------------------------------------------------
/scripts/glue_mrpc/eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: scripts/glue_mrpc/eval.sh
5 |
6 |
7 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP
8 |
9 | EVAL_DIR=$1
10 | CKPT_PATH=${EVAL_DIR}/$2
11 |
12 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
13 | CUDA_VISIBLE_DEVICES=0 python3 ${REPO_PATH}/tasks/glue/evaluate_models.py \
14 | --gpus="1" \
15 | --path_to_model_checkpoint ${CKPT_PATH}
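16 |
17 | # usage (paths and checkpoint name below are only illustrative):
18 | #   bash scripts/glue_mrpc/eval.sh /data/xiaoya/outputs/dice_loss/glue_mrpc/<run_dir> epoch=2.ckpt
19 | # the first argument is a training output directory, the second is the name of
20 | # a checkpoint file inside that directory.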
--------------------------------------------------------------------------------
/scripts/mrc_squad1/bert_base_ce.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
6 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
7 |
8 | DATA_DIR=/userhome/xiaoya/dataset/squad1
9 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-12_H-768_A-12
10 |
11 | LOSS_TYPE=ce
12 | LR=3e-5
13 | LR_SCHEDULE=linear
14 | OPTIMIZER=adamw
15 | WARMUP_PROPORTION=0.002
16 |
17 | GRAD_CLIP=1.0
18 | ACC_GRAD=1
19 | MAX_EPOCH=2
20 | BERT_DROPOUT=0.1
21 | WEIGHT_DECAY=0.002
22 |
23 | TRAIN_BATCH_SIZE=12
24 | MAX_QUERY_LEN=64
25 | MAX_SEQ_LEN=384
26 | DOC_STRIDE=128
27 |
28 | PRECISION=16
29 | PROGRESS_BAR=1
30 | VAL_CHECK_INTERVAL=0.125
31 | DISTRIBUTE=ddp
32 |
33 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad
34 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/reproduce_bert_base_ce
35 |
36 | mkdir -p ${OUTPUT_DIR}
37 | CACHE_DIR=${OUTPUT_DIR}/cache
38 | mkdir -p ${CACHE_DIR}
39 |
40 | python ${REPO_PATH}/tasks/squad/train.py \
41 | --gpus="1" \
42 | --precision=${PRECISION} \
43 | --train_batch_size ${TRAIN_BATCH_SIZE} \
44 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
45 | --val_check_interval ${VAL_CHECK_INTERVAL} \
46 | --max_query_length ${MAX_QUERY_LEN} \
47 | --max_seq_length ${MAX_SEQ_LEN} \
48 | --doc_stride ${DOC_STRIDE} \
49 | --optimizer ${OPTIMIZER} \
50 | --loss_type ${LOSS_TYPE} \
51 | --data_dir ${DATA_DIR} \
52 | --bert_hidden_dropout ${BERT_DROPOUT} \
53 | --bert_config_dir ${BERT_DIR} \
54 | --lr ${LR} \
55 | --lr_scheduler ${LR_SCHEDULE} \
56 | --accumulate_grad_batches ${ACC_GRAD} \
57 | --default_root_dir ${OUTPUT_DIR} \
58 | --output_dir ${OUTPUT_DIR} \
59 | --max_epochs ${MAX_EPOCH} \
60 | --gradient_clip_val ${GRAD_CLIP} \
61 | --weight_decay ${WEIGHT_DECAY} \
62 | --do_lower_case \
63 | --warmup_proportion ${WARMUP_PROPORTION}
64 |
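65 | # Note: MAX_QUERY_LEN, MAX_SEQ_LEN and DOC_STRIDE follow the usual BERT SQuAD
66 | # preprocessing: the question is truncated to MAX_QUERY_LEN tokens, and contexts
67 | # that do not fit into MAX_SEQ_LEN are split into overlapping windows that
68 | # advance DOC_STRIDE tokens at a time.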
--------------------------------------------------------------------------------
/scripts/mrc_squad1/bert_base_dice.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
6 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
7 |
8 | DATA_DIR=/userhome/xiaoya/dataset/squad1
9 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-12_H-768_A-12
10 |
11 | LOSS_TYPE=dice
12 | LR=3e-5
13 | LR_SCHEDULE=linear
14 | GRAD_CLIP=1.0
15 | OPTIMIZER=adamw
16 | WARMUP=0
17 | WARMUP_PROPORTION=0
18 | ACC_GRAD=5
19 | MAX_EPOCH=10
20 | BERT_DROPOUT=0.1
21 | WEIGHT_DECAY=0
22 | TRAIN_BATCH_SIZE=4
23 | MAX_QUERY_LEN=64
24 | MAX_SEQ_LEN=384
25 | DOC_STRIDE=128
26 |
27 | DICE_SMOOTH=1
28 | DICE_OHEM=0.3
29 | DICE_ALPHA=0.01
30 |
31 | PRECISION=16
32 | PROGRESS_BAR=1
33 | VAL_CHECK_INTERVAL=0.125
34 |
35 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad
36 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/reproduce_bert_base_dice
37 |
38 | echo "INFO -> OUTPUT_DIR is ${OUTPUT_DIR}"
39 | mkdir -p ${OUTPUT_DIR}
40 | CACHE_DIR=${OUTPUT_DIR}/cache
41 | mkdir -p ${CACHE_DIR}
42 |
43 | python ${REPO_PATH}/tasks/squad/train.py \
44 | --gpus="1" \
45 | --precision=${PRECISION} \
46 | --train_batch_size ${TRAIN_BATCH_SIZE} \
47 | --dice_smooth ${DICE_SMOOTH} \
48 | --dice_ohem ${DICE_OHEM} \
49 | --dice_alpha ${DICE_ALPHA} \
50 | --dice_square \
51 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
52 | --val_check_interval ${VAL_CHECK_INTERVAL} \
53 | --max_query_length ${MAX_QUERY_LEN} \
54 | --max_seq_length ${MAX_SEQ_LEN} \
55 | --doc_stride ${DOC_STRIDE} \
56 | --optimizer ${OPTIMIZER} \
57 | --loss_type ${LOSS_TYPE} \
58 | --data_dir ${DATA_DIR} \
59 | --bert_hidden_dropout ${BERT_DROPOUT} \
60 | --bert_config_dir ${BERT_DIR} \
61 | --lr ${LR} \
62 | --lr_scheduler ${LR_SCHEDULE} \
63 | --accumulate_grad_batches ${ACC_GRAD} \
64 | --default_root_dir ${OUTPUT_DIR} \
65 | --output_dir ${OUTPUT_DIR} \
66 | --max_epochs ${MAX_EPOCH} \
67 | --gradient_clip_val ${GRAD_CLIP} \
68 | --weight_decay ${WEIGHT_DECAY} \
69 | --do_lower_case \
70 | --warmup_proportion ${WARMUP_PROPORTION}
71 |
72 |
73 |
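74 | # Note: with TRAIN_BATCH_SIZE=4 and ACC_GRAD=5, gradients are accumulated over
75 | # 5 steps, so each optimizer update effectively covers 4 * 5 = 20 examples.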
--------------------------------------------------------------------------------
/scripts/mrc_squad1/bert_base_focal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
6 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
7 |
8 | DATA_DIR=/userhome/xiaoya/dataset/squad1
9 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-12_H-768_A-12
10 |
11 | LR=3e-5
12 | LR_SCHEDULE=onecycle
13 | OPTIMIZER=adamw
14 | WARMUP_PROPORTION=0
15 | GRAD_CLIP=1.0
16 | MAX_EPOCH=3
17 | ACC_GRAD=6
18 | BERT_DROPOUT=0.1
19 | WEIGHT_DECAY=0.01
20 | BATCH_SIZE=2
21 | MAX_QUERY_LEN=64
22 | MAX_SEQ_LEN=384
23 | DOC_STRIDE=128
24 |
25 | LOSS_TYPE=focal
26 | FOCAL_GAMMA=2
27 |
28 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad
29 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/reproduce_bert_base_focal
30 |
31 | mkdir -p ${OUTPUT_DIR}
32 | CACHE_DIR=${OUTPUT_DIR}/cache
33 | mkdir -p ${CACHE_DIR}
34 |
35 |
36 | PRECISION=16
37 | PROGRESS_BAR=1
38 | VAL_CHECK_INTERVAL=0.5
39 |
40 | python ${REPO_PATH}/tasks/squad/train.py \
41 | --gpus="1" \
42 | --train_batch_size ${BATCH_SIZE} \
43 | --precision=${PRECISION} \
44 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
45 | --val_check_interval ${VAL_CHECK_INTERVAL} \
46 | --max_query_length ${MAX_QUERY_LEN} \
47 | --max_seq_length ${MAX_SEQ_LEN} \
48 | --doc_stride ${DOC_STRIDE} \
49 | --optimizer ${OPTIMIZER} \
50 | --loss_type ${LOSS_TYPE} \
51 | --data_dir ${DATA_DIR} \
52 | --bert_hidden_dropout ${BERT_DROPOUT} \
53 | --bert_config_dir ${BERT_DIR} \
54 | --lr ${LR} \
55 | --lr_scheduler ${LR_SCHEDULE} \
56 | --warmup_proportion ${WARMUP_PROPORTION} \
57 | --accumulate_grad_batches ${ACC_GRAD} \
58 | --default_root_dir ${OUTPUT_DIR} \
59 | --output_dir ${OUTPUT_DIR} \
60 | --max_epochs ${MAX_EPOCH} \
61 | --gradient_clip_val ${GRAD_CLIP} \
62 | --do_lower_case \
63 | --weight_decay ${WEIGHT_DECAY} \
64 |     --focal_gamma ${FOCAL_GAMMA}
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/scripts/mrc_squad1/bert_large_ce.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # description:
5 | # predictions_4_7387.json
6 | # EM -> 83.98; F1 -> 90.89
7 |
8 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
9 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
10 |
11 | MODEL_SCALE=large
12 | DATA_DIR=/userhome/xiaoya/dataset/squad1
13 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-24_H-1024_A-16
14 |
15 | LOSS_TYPE=ce
16 | LR=3e-5
17 | LR_SCHEDULE=linear
18 | OPTIMIZER=adamw
19 | WARMUP_PROPORTION=0.002
20 | GRAD_CLIP=1.0
21 | ACC_GRAD=6
22 | MAX_EPOCH=2
23 |
24 | BERT_DROPOUT=0.1
25 | WEIGHT_DECAY=0.002
26 | TRAIN_BATCH_SIZE=4
27 | MAX_QUERY_LEN=64
28 | MAX_SEQ_LEN=384
29 | DOC_STRIDE=128
30 |
31 | PRECISION=16
32 | PROGRESS_BAR=1
33 | VAL_CHECK_INTERVAL=0.125
34 | DISTRIBUTE=ddp
35 |
36 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad
37 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/reproduce_bert_large_ce
38 |
39 | mkdir -p ${OUTPUT_DIR}
40 | CACHE_DIR=${OUTPUT_DIR}/cache
41 | mkdir -p ${CACHE_DIR}
42 |
43 | python ${REPO_PATH}/tasks/squad/train.py \
44 | --gpus="1" \
45 | --precision=${PRECISION} \
46 | --train_batch_size ${TRAIN_BATCH_SIZE} \
47 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
48 | --val_check_interval ${VAL_CHECK_INTERVAL} \
49 | --max_query_length ${MAX_QUERY_LEN} \
50 | --max_seq_length ${MAX_SEQ_LEN} \
51 | --doc_stride ${DOC_STRIDE} \
52 | --optimizer ${OPTIMIZER} \
53 | --loss_type ${LOSS_TYPE} \
54 | --data_dir ${DATA_DIR} \
55 | --bert_hidden_dropout ${BERT_DROPOUT} \
56 | --bert_config_dir ${BERT_DIR} \
57 | --lr ${LR} \
58 | --lr_scheduler ${LR_SCHEDULE} \
59 | --accumulate_grad_batches ${ACC_GRAD} \
60 | --default_root_dir ${OUTPUT_DIR} \
61 | --output_dir ${OUTPUT_DIR} \
62 | --max_epochs ${MAX_EPOCH} \
63 | --gradient_clip_val ${GRAD_CLIP} \
64 | --weight_decay ${WEIGHT_DECAY} \
65 | --do_lower_case \
66 | --warmup_proportion ${WARMUP_PROPORTION}
67 |
68 |
--------------------------------------------------------------------------------
/scripts/mrc_squad1/bert_large_dice.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # description:
5 | # predictions_1_7387.json
6 | # EM -> 83.98; F1 -> 90.89
7 |
8 |
9 | FILE_NAME=dice_large
10 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
11 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
12 | echo "DEBUG INFO -> repo dir is : ${REPO_PATH} "
13 |
14 | MODEL_SCALE=large
15 | DATA_DIR=/userhome/xiaoya/dataset/squad1
16 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-24_H-1024_A-16
17 |
18 | LOSS_TYPE=dice
19 | LR=2e-5
20 | LR_SCHEDULE=polydecay
21 | OPTIMIZER=torch.adam
22 | WARMUP_PROPORTION=0.06
23 |
24 | GRAD_CLIP=1.0
25 | ACC_GRAD=20
26 | MAX_EPOCH=5
27 |
28 | BERT_DROPOUT=0.1
29 | WEIGHT_DECAY=0.002
30 |
31 | # data
32 | TRAIN_BATCH_SIZE=2
33 | MAX_QUERY_LEN=64
34 | MAX_SEQ_LEN=384
35 | DOC_STRIDE=128
36 |
37 | DICE_SMOOTH=1
38 | DICE_OHEM=0.1
39 | DICE_ALPHA=0.01
40 |
41 | PRECISION=16 # default AMP configuration in pytorch-lightning==0.9.0 is 'O2'
42 | PROGRESS_BAR=1
43 | VAL_CHECK_INTERVAL=0.125
44 | DISTRIBUTE=ddp
45 |
46 |
47 | # define OUTPUT directory
48 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad
49 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/${FILE_NAME}_${MODEL_SCALE}_${MAX_EPOCH}_${GRAD_CLIP}_${ACC_GRAD}_${WARMUP_PROPORTION}_${OPTIMIZER}_${LR}_${BERT_DROPOUT}_${TRAIN_BATCH_SIZE}_${MAX_QUERY_LEN}_${MAX_SEQ_LEN}_${DOC_STRIDE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
50 |
51 |
52 | echo "INFO -> OUTPUT_DIR is ${OUTPUT_DIR}"
53 | mkdir -p ${OUTPUT_DIR}
54 | CACHE_DIR=${OUTPUT_DIR}/cache
55 | mkdir -p ${CACHE_DIR}
56 |
57 |
58 | python ${REPO_PATH}/tasks/squad/train.py \
59 | --gpus="1" \
60 | --precision=${PRECISION} \
61 | --train_batch_size ${TRAIN_BATCH_SIZE} \
62 | --dice_smooth ${DICE_SMOOTH} \
63 | --dice_ohem ${DICE_OHEM} \
64 | --dice_alpha ${DICE_ALPHA} \
65 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
66 | --val_check_interval ${VAL_CHECK_INTERVAL} \
67 | --max_query_length ${MAX_QUERY_LEN} \
68 | --max_seq_length ${MAX_SEQ_LEN} \
69 | --doc_stride ${DOC_STRIDE} \
70 | --optimizer ${OPTIMIZER} \
71 | --loss_type ${LOSS_TYPE} \
72 | --data_dir ${DATA_DIR} \
73 | --bert_hidden_dropout ${BERT_DROPOUT} \
74 | --bert_config_dir ${BERT_DIR} \
75 | --lr ${LR} \
76 | --lr_scheduler ${LR_SCHEDULE} \
77 | --accumulate_grad_batches ${ACC_GRAD} \
78 | --default_root_dir ${OUTPUT_DIR} \
79 | --output_dir ${OUTPUT_DIR} \
80 | --max_epochs ${MAX_EPOCH} \
81 | --gradient_clip_val ${GRAD_CLIP} \
82 | --weight_decay ${WEIGHT_DECAY} \
83 | --do_lower_case \
84 | --dice_square \
85 | --warmup_proportion ${WARMUP_PROPORTION}
86 |
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/scripts/mrc_squad1/bert_large_focal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | FILE_NAME=focal_large
6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
7 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
8 |
9 | DATA_DIR=/userhome/xiaoya/dataset/squad1
10 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-12_H-768_A-12
11 |
12 | LR=3e-5
13 | LR_SCHEDULE=onecycle
14 | OPTIMIZER=adamw
15 | WARMUP_PROPORTION=0.002
16 | GRAD_CLIP=1.0
17 | MAX_EPOCH=2
18 | ACC_GRAD=6
19 |
20 | BERT_DROPOUT=0.1
21 | WEIGHT_DECAY=0.002
22 |
23 | TRAIN_BATCH_SIZE=4
24 | MAX_QUERY_LEN=64
25 | MAX_SEQ_LEN=384
26 | DOC_STRIDE=128
27 |
28 | LOSS_TYPE=focal
29 | FOCAL_GAMMA=2
30 |
31 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad
32 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/${FILE_NAME}_${MAX_EPOCH}_${GRAD_CLIP}_${ACC_GRAD}_${WARMUP_PROPORTION}_${OPTIMIZER}_${LR}_${BERT_DROPOUT}_${WEIGHT_DECAY}_${TRAIN_BATCH_SIZE}_${MAX_QUERY_LEN}_${MAX_SEQ_LEN}_${DOC_STRIDE}_${FOCAL_GAMMA}
33 |
34 | echo "INFO -> OUTPUT_DIR is ${OUTPUT_DIR}"
35 | mkdir -p ${OUTPUT_DIR}
36 | CACHE_DIR=${OUTPUT_DIR}/cache
37 | mkdir -p ${CACHE_DIR}
38 |
39 | PRECISION=16
40 | PROGRESS_BAR=1
41 | VAL_CHECK_INTERVAL=0.125
42 |
43 | python ${REPO_PATH}/tasks/squad/train.py \
44 | --gpus="0,1,2" \
45 | --train_batch_size ${TRAIN_BATCH_SIZE} \
46 | --precision=${PRECISION} \
47 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
48 | --val_check_interval ${VAL_CHECK_INTERVAL} \
49 | --max_query_length ${MAX_QUERY_LEN} \
50 | --max_seq_length ${MAX_SEQ_LEN} \
51 | --doc_stride ${DOC_STRIDE} \
52 | --optimizer ${OPTIMIZER} \
53 | --loss_type ${LOSS_TYPE} \
54 | --data_dir ${DATA_DIR} \
55 | --bert_hidden_dropout ${BERT_DROPOUT} \
56 | --bert_config_dir ${BERT_DIR} \
57 | --lr ${LR} \
58 | --lr_scheduler ${LR_SCHEDULE} \
59 | --warmup_proportion ${WARMUP_PROPORTION} \
60 | --accumulate_grad_batches ${ACC_GRAD} \
61 | --default_root_dir ${OUTPUT_DIR} \
62 | --output_dir ${OUTPUT_DIR} \
63 | --max_epochs ${MAX_EPOCH} \
64 | --gradient_clip_val ${GRAD_CLIP} \
65 | --do_lower_case \
66 | --weight_decay ${WEIGHT_DECAY} \
67 | --focal_gamma ${FOCAL_GAMMA}
68 |
69 |
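70 | # Note: --gpus="0,1,2" requests three GPUs; with a DDP-style backend each device
71 | # processes TRAIN_BATCH_SIZE examples per step, so one optimizer update covers
72 | # roughly TRAIN_BATCH_SIZE * num_gpus * ACC_GRAD = 4 * 3 * 6 = 72 examples.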
--------------------------------------------------------------------------------
/scripts/mrc_squad1/eval_pred_file.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
6 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
7 |
8 | DATA_DIR=$1
9 | OUTPUT_DIR=$2
10 | DATA_FILE=${DATA_DIR}/dev-v1.1.json
11 |
12 | python3 ${REPO_PATH}/tasks/squad/evaluate_predictions.py ${DATA_FILE} ${OUTPUT_DIR}
13 |
14 |
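15 | # usage (paths below are only illustrative):
16 | #   bash scripts/mrc_squad1/eval_pred_file.sh /userhome/xiaoya/dataset/squad1 /path/to/output_dir
17 | # the first argument must contain dev-v1.1.json; the second is the directory
18 | # holding the prediction files written during training.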
--------------------------------------------------------------------------------
/scripts/mrc_squad1/eval_saved_ckpt.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP
6 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
7 |
8 | OUTPUT_DIR=/data/xiaoya/outputs/dice_loss/squad/gpu4_ce_base_2_1.0_1__adamw_3e-5_0.1_12_64_384_128
9 | MODEL_CKPT=${OUTPUT_DIR}/epoch=0_v2.ckpt
10 | HPARAMS_PATH=${OUTPUT_DIR}/lightning_logs/version_0/hparams.yaml
11 |
12 | CUDA_VISIBLE_DEVICES=3 python ${REPO_PATH}/tasks/squad/evaluate_models.py \
13 | --gpus="1" \
14 | --path_to_model_checkpoint ${MODEL_CKPT} \
15 | --path_to_model_hparams_file ${HPARAMS_PATH}
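16 |
17 | # Note: OUTPUT_DIR, MODEL_CKPT and HPARAMS_PATH above are example paths from one
18 | # particular run; point them at your own checkpoint and its matching
19 | # lightning_logs/version_*/hparams.yaml before running this script.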
--------------------------------------------------------------------------------
/scripts/ner_enconll03/bert_dice.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # author: xiaoya li
5 | # first created: 2021.02.02
6 | # file: scripts/ner_enconll03/bert_dice.sh
7 |
8 |
9 | TIME=2021.09.06
10 | FILE_NAME=enconll_dice
11 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
12 | MODEL_SCALE=large
13 | DATA_DIR=/userhome/xiaoya/dataset/en_conll03
14 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large
15 |
16 | TRAIN_BATCH_SIZE=36
17 | EVAL_BATCH_SIZE=1
18 | MAX_LENGTH=256
19 |
20 | OPTIMIZER=torch.adam
21 | LR_SCHEDULE=polydecay
22 | LR=3e-5
23 |
24 | BERT_DROPOUT=0.2
25 | ACC_GRAD=8
26 | MAX_EPOCH=10
27 | GRAD_CLIP=1.0
28 | WEIGHT_DECAY=0.01
29 | WARMUP_PROPORTION=0.01
30 |
31 | LOSS_TYPE=dice
32 | W_START=1
33 | W_END=1
34 | W_SPAN=0.3
35 | DICE_SMOOTH=1
36 | DICE_OHEM=0.0
37 | DICE_ALPHA=0.01
38 | FOCAL_GAMMA=2
39 |
40 | PRECISION=16
41 | PROGRESS_BAR=1
42 | VAL_CHECK_INTERVAL=0.25
43 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
44 |
45 | if [[ ${LOSS_TYPE} == "bce" ]]; then
46 | LOSS_SIGN=${LOSS_TYPE}
47 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
48 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
49 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
50 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
51 | fi
52 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
53 |
54 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner/${TIME}
55 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN}
56 |
57 | mkdir -p ${OUTPUT_DIR}
58 |
59 | CUDA_VISIBLE_DEVICES=0 python ${REPO_PATH}/tasks/mrc_ner/train.py \
60 | --gpus="1" \
61 | --precision=${PRECISION} \
62 | --train_batch_size ${TRAIN_BATCH_SIZE} \
63 | --eval_batch_size ${EVAL_BATCH_SIZE} \
64 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
65 | --val_check_interval ${VAL_CHECK_INTERVAL} \
66 | --max_length ${MAX_LENGTH} \
67 | --optimizer ${OPTIMIZER} \
68 | --data_dir ${DATA_DIR} \
69 | --bert_hidden_dropout ${BERT_DROPOUT} \
70 | --bert_config_dir ${BERT_DIR} \
71 | --lr ${LR} \
72 | --lr_scheduler ${LR_SCHEDULE} \
73 | --accumulate_grad_batches ${ACC_GRAD} \
74 | --default_root_dir ${OUTPUT_DIR} \
75 | --output_dir ${OUTPUT_DIR} \
76 | --max_epochs ${MAX_EPOCH} \
77 | --gradient_clip_val ${GRAD_CLIP} \
78 | --weight_decay ${WEIGHT_DECAY} \
79 | --loss_type ${LOSS_TYPE} \
80 | --weight_start ${W_START} \
81 | --weight_end ${W_END} \
82 | --weight_span ${W_SPAN} \
83 | --dice_smooth ${DICE_SMOOTH} \
84 | --dice_ohem ${DICE_OHEM} \
85 | --dice_alpha ${DICE_ALPHA} \
86 | --dice_square \
87 | --warmup_proportion ${WARMUP_PROPORTION} \
88 | --span_loss_candidates gold_pred_random \
89 | --construct_entity_span start_and_end \
90 | --flat_ner \
91 | --pred_answerable "train_infer" \
92 | --answerable_task_ratio 0.4 \
93 | --activate_func relu \
94 | --data_sign en_conll03
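95 | # Note: W_START, W_END and W_SPAN are understood as the weights on the
96 | # start-position, end-position and span-matching loss terms of the MRC-NER
97 | # objective when they are summed into the total loss (see tasks/mrc_ner/train.py).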
--------------------------------------------------------------------------------
/scripts/ner_enconll03/bert_focal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # author: xiaoya li
5 | # first created: 2021.02.02
6 | # file: scripts/ner_enconll03/bert_focal.sh
7 |
8 | TIME=2021.07.23
9 | FILE_NAME=enconll03_focal
10 | REPO_PATH=/userhome/xiaoya/mrc-with-dice-loss
11 | MODEL_SCALE=large
12 | DATA_DIR=/userhome/xiaoya/dataset/en_conll03
13 | BERT_DIR=/userhome/xiaoya/bert/bert-large-cased
14 |
15 | TRAIN_BATCH_SIZE=18
16 | EVAL_BATCH_SIZE=1
17 | MAX_LENGTH=256
18 |
19 | OPTIMIZER=torch.adam
20 | LR_SCHEDULE=polydecay
21 | LR=2e-5
22 |
23 | BERT_DROPOUT=0.2
24 | ACC_GRAD=4
25 | MAX_EPOCH=10
26 | GRAD_CLIP=1.0
27 | WEIGHT_DECAY=0.002
28 | WARMUP_PROPORTION=0.06
29 |
30 | LOSS_TYPE=focal
31 | W_START=1
32 | W_END=1
33 | W_SPAN=0.3
34 | DICE_SMOOTH=1
35 | DICE_OHEM=0.0
36 | DICE_ALPHA=0.01
37 | FOCAL_GAMMA=3
38 |
39 | PRECISION=16
40 | PROGRESS_BAR=1
41 | VAL_CHECK_INTERVAL=0.25
42 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
43 |
44 | if [[ ${LOSS_TYPE} == "bce" ]]; then
45 | LOSS_SIGN=${LOSS_TYPE}
46 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
47 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
48 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
49 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
50 | fi
51 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
52 |
53 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner/${TIME}
54 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN}
55 |
56 | mkdir -p ${OUTPUT_DIR}
57 |
58 | CUDA_VISIBLE_DEVICES=0 python ${REPO_PATH}/tasks/mrc_ner/train.py \
59 | --gpus="1" \
60 | --precision=${PRECISION} \
61 | --train_batch_size ${TRAIN_BATCH_SIZE} \
62 | --eval_batch_size ${EVAL_BATCH_SIZE} \
63 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
64 | --val_check_interval ${VAL_CHECK_INTERVAL} \
65 | --max_length ${MAX_LENGTH} \
66 | --optimizer ${OPTIMIZER} \
67 | --data_dir ${DATA_DIR} \
68 | --bert_hidden_dropout ${BERT_DROPOUT} \
69 | --bert_config_dir ${BERT_DIR} \
70 | --lr ${LR} \
71 | --lr_scheduler ${LR_SCHEDULE} \
72 | --accumulate_grad_batches ${ACC_GRAD} \
73 | --default_root_dir ${OUTPUT_DIR} \
74 | --output_dir ${OUTPUT_DIR} \
75 | --max_epochs ${MAX_EPOCH} \
76 | --gradient_clip_val ${GRAD_CLIP} \
77 | --weight_decay ${WEIGHT_DECAY} \
78 | --loss_type ${LOSS_TYPE} \
79 | --weight_start ${W_START} \
80 | --weight_end ${W_END} \
81 | --weight_span ${W_SPAN} \
82 | --dice_smooth ${DICE_SMOOTH} \
83 | --dice_ohem ${DICE_OHEM} \
84 | --dice_alpha ${DICE_ALPHA} \
85 |     --focal_gamma ${FOCAL_GAMMA} \
86 | --warmup_proportion ${WARMUP_PROPORTION} \
87 | --span_loss_candidates gold_pred_random \
88 | --construct_entity_span start_and_end \
89 | --flat_ner \
90 | --pred_answerable "train_infer" \
91 | --answerable_task_ratio 0.2 \
92 | --activate_func relu \
93 | --data_sign en_conll03
--------------------------------------------------------------------------------
/scripts/ner_enontonotes5/bert_dice.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | FILE_NAME=brain_enonto_dice
6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
7 | MODEL_SCALE=base
8 | DATA_DIR=/userhome/xiaoya/dataset/new_mrc_ner/new_en_onto5
9 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large
10 |
11 | TRAIN_BATCH_SIZE=12
12 | EVAL_BATCH_SIZE=1
13 | MAX_LENGTH=300
14 |
15 | OPTIMIZER=torch.adam
16 | LR_SCHEDULE=linear
17 | LR=2e-5
18 |
19 | BERT_DROPOUT=0.1
20 | ACC_GRAD=6
21 | MAX_EPOCH=5
22 | GRAD_CLIP=1.0
23 | WEIGHT_DECAY=0.002
24 | WARMUP_PROPORTION=0.1
25 |
26 | LOSS_TYPE=dice
27 | W_START=1
28 | W_END=1
29 | W_SPAN=0.3
30 | DICE_SMOOTH=1
31 | DICE_OHEM=0.3
32 | DICE_ALPHA=0.01
33 | FOCAL_GAMMA=2
34 |
35 | PRECISION=16
36 | PROGRESS_BAR=1
37 | VAL_CHECK_INTERVAL=0.25
38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
39 |
40 | if [[ ${LOSS_TYPE} == "bce" ]]; then
41 | LOSS_SIGN=${LOSS_TYPE}
42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
46 | fi
47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
48 |
49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner
50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN}
51 |
52 | mkdir -p ${OUTPUT_DIR}
53 |
54 | CUDA_VISIBLE_DEVICES=1 python ${REPO_PATH}/tasks/mrc_ner/train.py \
55 | --gpus="1" \
56 | --precision=${PRECISION} \
57 | --train_batch_size ${TRAIN_BATCH_SIZE} \
58 | --eval_batch_size ${EVAL_BATCH_SIZE} \
59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
60 | --val_check_interval ${VAL_CHECK_INTERVAL} \
61 | --max_length ${MAX_LENGTH} \
62 | --optimizer ${OPTIMIZER} \
63 | --data_dir ${DATA_DIR} \
64 | --bert_hidden_dropout ${BERT_DROPOUT} \
65 | --bert_config_dir ${BERT_DIR} \
66 | --lr ${LR} \
67 | --lr_scheduler ${LR_SCHEDULE} \
68 | --accumulate_grad_batches ${ACC_GRAD} \
69 | --default_root_dir ${OUTPUT_DIR} \
70 | --output_dir ${OUTPUT_DIR} \
71 | --max_epochs ${MAX_EPOCH} \
72 | --gradient_clip_val ${GRAD_CLIP} \
73 | --weight_decay ${WEIGHT_DECAY} \
74 | --loss_type ${LOSS_TYPE} \
75 | --weight_start ${W_START} \
76 | --weight_end ${W_END} \
77 | --weight_span ${W_SPAN} \
78 | --dice_smooth ${DICE_SMOOTH} \
79 | --dice_ohem ${DICE_OHEM} \
80 | --dice_alpha ${DICE_ALPHA} \
81 | --dice_square \
82 | --warmup_proportion ${WARMUP_PROPORTION} \
83 | --span_loss_candidates gold_pred_random \
84 | --construct_entity_span start_and_end \
85 | --num_labels 1 \
86 | --flat_ner \
87 | --pred_answerable "train_infer" \
88 | --answerable_task_ratio 0.2 \
89 | --activate_func relu \
90 | --data_sign en_onto
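91 | # Note: with the values above, LOSS_SIGN expands to dice_1_0.3_0.01, so the run
92 | # directory created under OUTPUT_BASE_DIR encodes the loss type and its
93 | # hyperparameters together with the other training settings.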
--------------------------------------------------------------------------------
/scripts/ner_enontonotes5/bert_focal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | FILE_NAME=brain_enonto_focal
6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
7 | MODEL_SCALE=base
8 | DATA_DIR=/userhome/xiaoya/dataset/mrc_ner/new_enonto5
9 | BERT_DIR=/userhome/xiaoya/bert/english_bert_large
10 |
11 | TRAIN_BATCH_SIZE=4
12 | EVAL_BATCH_SIZE=1
13 | MAX_LENGTH=300
14 |
15 | OPTIMIZER=torch.adam
16 | LR_SCHEDULE=linear
17 | LR=1e-5
18 |
19 | BERT_DROPOUT=0.1
20 | ACC_GRAD=12
21 | MAX_EPOCH=5
22 | GRAD_CLIP=1.0
23 | WEIGHT_DECAY=0.002
24 | WARMUP_PROPORTION=0.1
25 |
26 | LOSS_TYPE=focal
27 | W_START=1
28 | W_END=1
29 | W_SPAN=0.3
30 | DICE_SMOOTH=1
31 | DICE_OHEM=0.4
32 | DICE_ALPHA=0.01
33 | FOCAL_GAMMA=4
34 |
35 | PRECISION=16
36 | PROGRESS_BAR=1
37 | VAL_CHECK_INTERVAL=0.25
38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
39 |
40 | if [[ ${LOSS_TYPE} == "bce" ]]; then
41 | LOSS_SIGN=${LOSS_TYPE}
42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
46 | fi
47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
48 |
49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner
50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN}
51 |
52 | mkdir -p ${OUTPUT_DIR}
53 |
54 | CUDA_VISIBLE_DEVICES=1 python ${REPO_PATH}/tasks/mrc_ner/train.py \
55 | --gpus="1" \
56 | --precision=${PRECISION} \
57 | --train_batch_size ${TRAIN_BATCH_SIZE} \
58 | --eval_batch_size ${EVAL_BATCH_SIZE} \
59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
60 | --val_check_interval ${VAL_CHECK_INTERVAL} \
61 | --max_length ${MAX_LENGTH} \
62 | --optimizer ${OPTIMIZER} \
63 | --data_dir ${DATA_DIR} \
64 | --bert_hidden_dropout ${BERT_DROPOUT} \
65 | --bert_config_dir ${BERT_DIR} \
66 | --lr ${LR} \
67 | --lr_scheduler ${LR_SCHEDULE} \
68 | --accumulate_grad_batches ${ACC_GRAD} \
69 | --default_root_dir ${OUTPUT_DIR} \
70 | --output_dir ${OUTPUT_DIR} \
71 | --max_epochs ${MAX_EPOCH} \
72 | --gradient_clip_val ${GRAD_CLIP} \
73 | --weight_decay ${WEIGHT_DECAY} \
74 | --loss_type ${LOSS_TYPE} \
75 | --weight_start ${W_START} \
76 | --weight_end ${W_END} \
77 | --weight_span ${W_SPAN} \
78 | --warmup_proportion ${WARMUP_PROPORTION} \
79 | --span_loss_candidates gold_pred_random \
80 | --construct_entity_span start_and_end \
81 | --num_labels 1 \
82 | --flat_ner \
83 | --focal_gamma ${FOCAL_GAMMA} \
84 | --pred_answerable "train_infer" \
85 | --answerable_task_ratio 0.2 \
86 | --activate_func relu \
87 | --data_sign en_onto
--------------------------------------------------------------------------------
/scripts/ner_zhmsra/bert_dice.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | FILE_NAME=reproduce_zhmsra_dice
6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
7 | MODEL_SCALE=base
8 | DATA_DIR=/userhome/xiaoya/dataset/new_mrc_ner/new_zh_msra
9 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert
10 |
11 | TRAIN_BATCH_SIZE=10
12 | EVAL_BATCH_SIZE=1
13 | MAX_LENGTH=275
14 |
15 | OPTIMIZER=torch.adam
16 | LR_SCHEDULE=polydecay
17 | LR=2e-5
18 |
19 | BERT_DROPOUT=0.1
20 | ACC_GRAD=3
21 | MAX_EPOCH=5
22 | GRAD_CLIP=1.0
23 | WEIGHT_DECAY=0.002
24 | WARMUP_PROPORTION=0.02
25 |
26 | LOSS_TYPE=dice
27 | W_START=1
28 | W_END=1
29 | W_SPAN=0.2
30 | DICE_SMOOTH=1
31 | DICE_OHEM=0.3
32 | DICE_ALPHA=0.01
33 | FOCAL_GAMMA=2
34 |
35 | PRECISION=16
36 | PROGRESS_BAR=1
37 | VAL_CHECK_INTERVAL=0.25
38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
39 |
40 | if [[ ${LOSS_TYPE} == "bce" ]]; then
41 | LOSS_SIGN=${LOSS_TYPE}
42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
46 | fi
47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
48 |
49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner
50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN}
51 |
52 | mkdir -p ${OUTPUT_DIR}
53 |
54 | CUDA_VISIBLE_DEVICES=1 python ${REPO_PATH}/tasks/mrc_ner/train.py \
55 | --gpus="1" \
56 | --precision=${PRECISION} \
57 | --train_batch_size ${TRAIN_BATCH_SIZE} \
58 | --eval_batch_size ${EVAL_BATCH_SIZE} \
59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
60 | --val_check_interval ${VAL_CHECK_INTERVAL} \
61 | --max_length ${MAX_LENGTH} \
62 | --optimizer ${OPTIMIZER} \
63 | --data_dir ${DATA_DIR} \
64 | --bert_hidden_dropout ${BERT_DROPOUT} \
65 | --bert_config_dir ${BERT_DIR} \
66 | --lr ${LR} \
67 | --lr_scheduler ${LR_SCHEDULE} \
68 | --accumulate_grad_batches ${ACC_GRAD} \
69 | --default_root_dir ${OUTPUT_DIR} \
70 | --output_dir ${OUTPUT_DIR} \
71 | --max_epochs ${MAX_EPOCH} \
72 | --gradient_clip_val ${GRAD_CLIP} \
73 | --weight_decay ${WEIGHT_DECAY} \
74 | --loss_type ${LOSS_TYPE} \
75 | --weight_start ${W_START} \
76 | --weight_end ${W_END} \
77 | --weight_span ${W_SPAN} \
78 | --dice_smooth ${DICE_SMOOTH} \
79 | --dice_ohem ${DICE_OHEM} \
80 | --dice_alpha ${DICE_ALPHA} \
81 | --dice_square \
82 | --warmup_proportion ${WARMUP_PROPORTION} \
83 | --span_loss_candidates gold_pred_random \
84 | --construct_entity_span start_and_end \
85 | --num_labels 1 \
86 | --flat_ner \
87 | --is_chinese \
88 | --pred_answerable "train_infer" \
89 | --answerable_task_ratio 0.3 \
90 | --activate_func relu \
91 | --data_sign zh_msra
--------------------------------------------------------------------------------
/scripts/ner_zhmsra/bert_focal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | FILE_NAME=brain_zhmsra_focal
6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
7 | MODEL_SCALE=base
8 | DATA_DIR=/userhome/xiaoya/dataset/mrc_ner/new_msra
9 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert
10 |
11 | TRAIN_BATCH_SIZE=9
12 | EVAL_BATCH_SIZE=1
13 | MAX_LENGTH=275
14 |
15 | OPTIMIZER=torch.adam
16 | LR_SCHEDULE=linear
17 | LR=1e-5
18 |
19 | BERT_DROPOUT=0.1
20 | ACC_GRAD=3
21 | MAX_EPOCH=5
22 | GRAD_CLIP=1.0
23 | WEIGHT_DECAY=0.002
24 | WARMUP_PROPORTION=0.08
25 |
26 | LOSS_TYPE=focal
27 | W_START=1
28 | W_END=1
29 | W_SPAN=0.3
30 | DICE_SMOOTH=1
31 | DICE_OHEM=0.4
32 | DICE_ALPHA=0.01
33 | FOCAL_GAMMA=3
34 |
35 | PRECISION=16
36 | PROGRESS_BAR=1
37 | VAL_CHECK_INTERVAL=0.25
38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
39 |
40 | if [[ ${LOSS_TYPE} == "bce" ]]; then
41 | LOSS_SIGN=${LOSS_TYPE}
42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
46 | fi
47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
48 |
49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner
50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN}
51 |
52 | mkdir -p ${OUTPUT_DIR}
53 |
54 | CUDA_VISIBLE_DEVICES=1 python ${REPO_PATH}/tasks/mrc_ner/train.py \
55 | --gpus="1" \
56 | --precision=${PRECISION} \
57 | --train_batch_size ${TRAIN_BATCH_SIZE} \
58 | --eval_batch_size ${EVAL_BATCH_SIZE} \
59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
60 | --val_check_interval ${VAL_CHECK_INTERVAL} \
61 | --max_length ${MAX_LENGTH} \
62 | --optimizer ${OPTIMIZER} \
63 | --data_dir ${DATA_DIR} \
64 | --bert_hidden_dropout ${BERT_DROPOUT} \
65 | --bert_config_dir ${BERT_DIR} \
66 | --lr ${LR} \
67 | --lr_scheduler ${LR_SCHEDULE} \
68 | --accumulate_grad_batches ${ACC_GRAD} \
69 | --default_root_dir ${OUTPUT_DIR} \
70 | --output_dir ${OUTPUT_DIR} \
71 | --max_epochs ${MAX_EPOCH} \
72 | --gradient_clip_val ${GRAD_CLIP} \
73 | --weight_decay ${WEIGHT_DECAY} \
74 | --loss_type ${LOSS_TYPE} \
75 | --weight_start ${W_START} \
76 | --weight_end ${W_END} \
77 | --weight_span ${W_SPAN} \
78 | --focal_gamma ${FOCAL_GAMMA} \
79 | --warmup_proportion ${WARMUP_PROPORTION} \
80 | --span_loss_candidates gold_pred_random \
81 | --construct_entity_span start_end_match \
82 | --num_labels 1 \
83 | --flat_ner \
84 | --is_chinese \
85 | --pred_answerable "train_infer" \
86 | --answerable_task_ratio 0.2 \
87 | --activate_func relu \
88 | --data_sign zh_msra
--------------------------------------------------------------------------------
/scripts/ner_zhonto4/bert_dice.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | FILE_NAME=reproduce_zhonto_dice
6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
7 | MODEL_SCALE=base
8 | DATA_DIR=/userhome/xiaoya/dataset/new_mrc_ner/new_zh_onto4
9 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert
10 |
11 | TRAIN_BATCH_SIZE=8
12 | EVAL_BATCH_SIZE=1
13 | MAX_LENGTH=300
14 |
15 | OPTIMIZER=torch.adam
16 | LR_SCHEDULE=polydecay
17 | LR=2e-5
18 |
19 | BERT_DROPOUT=0.1
20 | ACC_GRAD=2
21 | MAX_EPOCH=5
22 | GRAD_CLIP=1.0
23 | WEIGHT_DECAY=0.002
24 | WARMUP_PROPORTION=0.1
25 |
26 | LOSS_TYPE=dice
27 | W_START=1
28 | W_END=1
29 | W_SPAN=0.3
30 | DICE_SMOOTH=1
31 | DICE_OHEM=0.3
32 | DICE_ALPHA=0.01
33 | FOCAL_GAMMA=2
34 |
35 | PRECISION=16
36 | PROGRESS_BAR=1
37 | VAL_CHECK_INTERVAL=0.25
38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
39 |
40 | if [[ ${LOSS_TYPE} == "bce" ]]; then
41 | LOSS_SIGN=${LOSS_TYPE}
42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
46 | fi
47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
48 |
49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner
50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN}
51 |
52 | mkdir -p ${OUTPUT_DIR}
53 |
54 | CUDA_VISIBLE_DEVICES=0 python ${REPO_PATH}/tasks/mrc_ner/train.py \
55 | --gpus="1" \
56 | --precision=${PRECISION} \
57 | --train_batch_size ${TRAIN_BATCH_SIZE} \
58 | --eval_batch_size ${EVAL_BATCH_SIZE} \
59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
60 | --val_check_interval ${VAL_CHECK_INTERVAL} \
61 | --max_length ${MAX_LENGTH} \
62 | --optimizer ${OPTIMIZER} \
63 | --data_dir ${DATA_DIR} \
64 | --bert_hidden_dropout ${BERT_DROPOUT} \
65 | --bert_config_dir ${BERT_DIR} \
66 | --lr ${LR} \
67 | --lr_scheduler ${LR_SCHEDULE} \
68 | --accumulate_grad_batches ${ACC_GRAD} \
69 | --default_root_dir ${OUTPUT_DIR} \
70 | --output_dir ${OUTPUT_DIR} \
71 | --max_epochs ${MAX_EPOCH} \
72 | --gradient_clip_val ${GRAD_CLIP} \
73 | --weight_decay ${WEIGHT_DECAY} \
74 | --loss_type ${LOSS_TYPE} \
75 | --weight_start ${W_START} \
76 | --weight_end ${W_END} \
77 | --weight_span ${W_SPAN} \
78 | --dice_smooth ${DICE_SMOOTH} \
79 | --dice_ohem ${DICE_OHEM} \
80 | --dice_alpha ${DICE_ALPHA} \
81 | --dice_square \
82 | --warmup_proportion ${WARMUP_PROPORTION} \
83 | --span_loss_candidates gold_pred_random \
84 | --construct_entity_span start_and_end \
85 | --num_labels 1 \
86 | --flat_ner \
87 | --is_chinese \
88 | --pred_answerable "train_infer" \
89 | --answerable_task_ratio 0.2 \
90 | --activate_func relu
--------------------------------------------------------------------------------
/scripts/ner_zhonto4/bert_focal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | FILE_NAME=brain_zhonto_focal
6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
7 | MODEL_SCALE=base
8 | DATA_DIR=/userhome/xiaoya/dataset/mrc_ner/new_onto4
9 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert
10 |
11 | TRAIN_BATCH_SIZE=8
12 | EVAL_BATCH_SIZE=1
13 | MAX_LENGTH=300
14 |
15 | OPTIMIZER=torch.adam
16 | LR_SCHEDULE=polydecay
17 | LR=2e-5
18 |
19 | BERT_DROPOUT=0.1
20 | ACC_GRAD=2
21 | MAX_EPOCH=5
22 | GRAD_CLIP=1.0
23 | WEIGHT_DECAY=0.002
24 | WARMUP_PROPORTION=0.1
25 |
26 | LOSS_TYPE=focal
27 | W_START=1
28 | W_END=1
29 | W_SPAN=0.3
30 | DICE_SMOOTH=1
31 | DICE_OHEM=0.4
32 | DICE_ALPHA=0.01
33 | FOCAL_GAMMA=4
34 |
35 | PRECISION=16
36 | PROGRESS_BAR=1
37 | VAL_CHECK_INTERVAL=0.25
38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
39 |
40 | if [[ ${LOSS_TYPE} == "bce" ]]; then
41 | LOSS_SIGN=${LOSS_TYPE}
42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
46 | fi
47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
48 |
49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner
50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN}
51 |
52 | mkdir -p ${OUTPUT_DIR}
53 |
54 | CUDA_VISIBLE_DEVICES=1 python ${REPO_PATH}/tasks/mrc_ner/train.py \
55 | --gpus="1" \
56 | --precision=${PRECISION} \
57 | --train_batch_size ${TRAIN_BATCH_SIZE} \
58 | --eval_batch_size ${EVAL_BATCH_SIZE} \
59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
60 | --val_check_interval ${VAL_CHECK_INTERVAL} \
61 | --max_length ${MAX_LENGTH} \
62 | --optimizer ${OPTIMIZER} \
63 | --data_dir ${DATA_DIR} \
64 | --bert_hidden_dropout ${BERT_DROPOUT} \
65 | --bert_config_dir ${BERT_DIR} \
66 | --lr ${LR} \
67 | --lr_scheduler ${LR_SCHEDULE} \
68 | --accumulate_grad_batches ${ACC_GRAD} \
69 | --default_root_dir ${OUTPUT_DIR} \
70 | --output_dir ${OUTPUT_DIR} \
71 | --max_epochs ${MAX_EPOCH} \
72 | --gradient_clip_val ${GRAD_CLIP} \
73 | --weight_decay ${WEIGHT_DECAY} \
74 | --loss_type ${LOSS_TYPE} \
75 | --weight_start ${W_START} \
76 | --weight_end ${W_END} \
77 | --weight_span ${W_SPAN} \
78 | --focal_gamma ${FOCAL_GAMMA} \
79 | --warmup_proportion ${WARMUP_PROPORTION} \
80 | --span_loss_candidates gold_pred_random \
81 | --construct_entity_span start_and_end \
82 | --num_labels 1 \
83 | --flat_ner \
84 | --is_chinese \
85 | --pred_answerable "train_infer" \
86 | --answerable_task_ratio 0.3 \
87 | --activate_func relu
--------------------------------------------------------------------------------
/scripts/prepare_ckpt.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # description:
5 | #   convert a TensorFlow BERT checkpoint into a PyTorch checkpoint (pytorch_model.bin).
6 | # NOTICE:
7 | #   please make sure TensorFlow is installed before running this script.
8 |
9 | # tensorflow is required for loading the parameters of the pretrained models.
10 | pip install tensorflow
11 |
12 |
13 | BERT_BASE_DIR=$1
14 |
15 | transformers-cli convert --model_type bert \
16 | --tf_checkpoint ${BERT_BASE_DIR}/bert_model.ckpt \
17 | --config ${BERT_BASE_DIR}/bert_config.json \
18 | --pytorch_dump_output ${BERT_BASE_DIR}/pytorch_model.bin
19 |
20 | cp ${BERT_BASE_DIR}/bert_config.json ${BERT_BASE_DIR}/config.json
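21 |
22 | # example (the path below is only illustrative):
23 | #   bash scripts/prepare_ckpt.sh /userhome/xiaoya/bert/uncased_L-12_H-768_A-12
24 | # where the directory contains the TensorFlow checkpoint (bert_model.ckpt.*) and
25 | # bert_config.json from the original BERT release.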
--------------------------------------------------------------------------------
/scripts/prepare_mrpc_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # example:
5 | # bash scripts/prepare_mrpc_data.sh /data/xiaoya/outputs/debug
6 |
7 | SAVE_DATA_DIR=$1
8 | DEV_IDS_FILE=$PWD/tasks/glue/mrpc_dev_ids.tsv
9 |
10 | echo "***** INFO ***** -> data dir is : $SAVE_DATA_DIR"
11 | echo "***** INFO ***** -> dev ids file is : $DEV_IDS_FILE"
12 |
13 | # download mrpc data files
14 | wget https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt -P ${SAVE_DATA_DIR}
15 | wget https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt -P ${SAVE_DATA_DIR}
16 |
17 | # process mrpc data
18 | python3 $PWD/tasks/glue/process_mrpc.py ${SAVE_DATA_DIR} ${DEV_IDS_FILE}
--------------------------------------------------------------------------------
/scripts/textcl_tnews/bert_dice.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | FILE_NAME=tnews_dice
6 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP
7 | MODEL_SCALE=base
8 | DATA_DIR=/data/xiaoya/datasets/tnews_public_data
9 | BERT_DIR=/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12
10 |
11 | TRAIN_BATCH_SIZE=36
12 | EVAL_BATCH_SIZE=12
13 | MAX_LENGTH=128
14 |
15 |
16 | OPTIMIZER=torch.adam
17 | LR_SCHEDULE=polydecay
18 | LR=3e-5
19 |
20 | BERT_DROPOUT=0.1
21 | ACC_GRAD=1
22 | MAX_EPOCH=5
23 | GRAD_CLIP=1.0
24 | WEIGHT_DECAY=0.002
25 | WARMUP_PROPORTION=0.1
26 |
27 | LOSS_TYPE=dice
28 | # ce, focal, dice
29 | DICE_SMOOTH=1
30 | DICE_OHEM=0
31 | DICE_ALPHA=0.01
32 | FOCAL_GAMMA=2
33 |
34 | PRECISION=32
35 | PROGRESS_BAR=1
36 | VAL_CHECK_INTERVAL=0.25
37 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
38 |
39 |
40 | if [[ ${LOSS_TYPE} == "ce" ]]; then
41 | LOSS_SIGN=${LOSS_TYPE}
42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
46 | fi
47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
48 |
49 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/tnews
50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${LOSS_SIGN}
51 |
52 | mkdir -p ${OUTPUT_DIR}
53 |
54 | CUDA_VISIBLE_DEVICES=0 python ${REPO_PATH}/tasks/tnews/train.py \
55 | --gpus="1" \
56 | --precision=${PRECISION} \
57 | --train_batch_size ${TRAIN_BATCH_SIZE} \
58 | --eval_batch_size ${EVAL_BATCH_SIZE} \
59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
60 | --val_check_interval ${VAL_CHECK_INTERVAL} \
61 | --max_length ${MAX_LENGTH} \
62 | --optimizer ${OPTIMIZER} \
63 | --data_dir ${DATA_DIR} \
64 | --bert_hidden_dropout ${BERT_DROPOUT} \
65 | --bert_config_dir ${BERT_DIR} \
66 | --lr ${LR} \
67 | --lr_scheduler ${LR_SCHEDULE} \
68 | --accumulate_grad_batches ${ACC_GRAD} \
69 | --default_root_dir ${OUTPUT_DIR} \
70 | --output_dir ${OUTPUT_DIR} \
71 | --max_epochs ${MAX_EPOCH} \
72 | --gradient_clip_val ${GRAD_CLIP} \
73 | --weight_decay ${WEIGHT_DECAY} \
74 | --loss_type ${LOSS_TYPE} \
75 | --dice_smooth ${DICE_SMOOTH} \
76 | --dice_ohem ${DICE_OHEM} \
77 | --dice_alpha ${DICE_ALPHA} \
78 | --dice_square \
79 | --warmup_proportion ${WARMUP_PROPORTION}
--------------------------------------------------------------------------------
/scripts/textcl_tnews/bert_focal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | FILE_NAME=tnews_focal
6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP
7 | MODEL_SCALE=base
8 | DATA_DIR=/userhome/xiaoya/dataset/tnews
9 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert
10 |
11 | TRAIN_BATCH_SIZE=18
12 | EVAL_BATCH_SIZE=12
13 | MAX_LENGTH=128
14 |
15 | OPTIMIZER=torch.adam
16 | LR_SCHEDULE=linear
17 | LR=3e-5
18 |
19 | BERT_DROPOUT=0.2
20 | ACC_GRAD=1
21 | MAX_EPOCH=5
22 | GRAD_CLIP=1.0
23 | WEIGHT_DECAY=0.002
24 | WARMUP_PROPORTION=0.02
25 |
26 | LOSS_TYPE=focal
27 | # ce, focal, dice
28 | DICE_SMOOTH=1
29 | DICE_OHEM=1
30 | DICE_ALPHA=0.01
31 | FOCAL_GAMMA=4
32 |
33 | PRECISION=16
34 | PROGRESS_BAR=1
35 | VAL_CHECK_INTERVAL=0.25
36 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
37 |
38 | if [[ ${LOSS_TYPE} == "ce" ]]; then
39 | LOSS_SIGN=${LOSS_TYPE}
40 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
41 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
42 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
43 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
44 | fi
45 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
46 |
47 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/tnews
48 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${LOSS_SIGN}
49 |
50 | mkdir -p ${OUTPUT_DIR}
51 |
52 | CUDA_VISIBLE_DEVICES=0 python ${REPO_PATH}/tasks/tnews/train.py \
53 | --gpus="1" \
54 | --precision=${PRECISION} \
55 | --train_batch_size ${TRAIN_BATCH_SIZE} \
56 | --eval_batch_size ${EVAL_BATCH_SIZE} \
57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
58 | --val_check_interval ${VAL_CHECK_INTERVAL} \
59 | --max_length ${MAX_LENGTH} \
60 | --optimizer ${OPTIMIZER} \
61 | --data_dir ${DATA_DIR} \
62 | --bert_hidden_dropout ${BERT_DROPOUT} \
63 | --bert_config_dir ${BERT_DIR} \
64 | --lr ${LR} \
65 | --lr_scheduler ${LR_SCHEDULE} \
66 | --accumulate_grad_batches ${ACC_GRAD} \
67 | --default_root_dir ${OUTPUT_DIR} \
68 | --output_dir ${OUTPUT_DIR} \
69 | --max_epochs ${MAX_EPOCH} \
70 | --gradient_clip_val ${GRAD_CLIP} \
71 | --weight_decay ${WEIGHT_DECAY} \
72 | --loss_type ${LOSS_TYPE} \
73 | --focal_gamma ${FOCAL_GAMMA} \
74 | --warmup_proportion ${WARMUP_PROPORTION}
--------------------------------------------------------------------------------
/tasks/glue/download_glue_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | ''' Script for downloading all GLUE data.
4 | Note: for legal reasons, we are unable to host MRPC.
5 | You can either use the version hosted by the SentEval team, which is already tokenized,
6 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
7 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
8 | You should then rename and place specific files in a folder (see below for an example).
9 | mkdir MRPC
10 | cabextract MSRParaphraseCorpus.msi -d MRPC
11 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
12 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
13 | rm MRPC/_*
14 | rm MSRParaphraseCorpus.msi
15 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
16 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
17 | '''
18 |
19 | import os
20 | import sys
21 | import shutil
22 | import argparse
23 | import tempfile
24 | import urllib.request
25 | import zipfile
26 |
27 | import os
28 | import sys
29 | import shutil
30 | import argparse
31 | import tempfile
32 | import urllib.request
33 | import zipfile
34 |
35 |
36 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
37 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
38 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
39 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
40 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
41 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
42 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
43 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
44 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
45 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
46 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
47 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
48 |
49 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
50 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
51 |
52 | def download_and_extract(task, data_dir):
53 | print("Downloading and extracting %s..." % task)
54 | data_file = "%s.zip" % task
55 | urllib.request.urlretrieve(TASK2PATH[task], data_file)
56 | with zipfile.ZipFile(data_file) as zip_ref:
57 | zip_ref.extractall(data_dir)
58 | os.remove(data_file)
59 | print("\tCompleted!")
60 |
61 | def format_mrpc(data_dir, path_to_data):
62 | print("Processing MRPC...")
63 | mrpc_dir = os.path.join(data_dir, "MRPC")
64 | if not os.path.isdir(mrpc_dir):
65 | os.mkdir(mrpc_dir)
66 | if path_to_data:
67 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
68 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
69 | else:
70 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
71 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
72 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
73 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
74 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
75 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
76 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
77 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
78 |
79 | dev_ids = []
80 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
81 | for row in ids_fh:
82 | dev_ids.append(row.strip().split('\t'))
83 |
84 | with open(mrpc_train_file, encoding="utf8") as data_fh, \
85 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
86 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
87 | header = data_fh.readline()
88 | train_fh.write(header)
89 | dev_fh.write(header)
90 | for row in data_fh:
91 | label, id1, id2, s1, s2 = row.strip().split('\t')
92 | if [id1, id2] in dev_ids:
93 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
94 | else:
95 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
96 |
97 | with open(mrpc_test_file, encoding="utf8") as data_fh, \
98 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
99 | header = data_fh.readline()
100 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
101 | for idx, row in enumerate(data_fh):
102 | label, id1, id2, s1, s2 = row.strip().split('\t')
103 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
104 | print("\tCompleted!")
105 |
106 | def download_diagnostic(data_dir):
107 | print("Downloading and extracting diagnostic...")
108 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
109 | os.mkdir(os.path.join(data_dir, "diagnostic"))
110 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
111 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
112 | print("\tCompleted!")
113 | return
114 |
115 | def get_tasks(task_names):
116 | task_names = task_names.split(',')
117 | if "all" in task_names:
118 | tasks = TASKS
119 | else:
120 | tasks = []
121 | for task_name in task_names:
122 | assert task_name in TASKS, "Task %s not found!" % task_name
123 | tasks.append(task_name)
124 | return tasks
125 |
126 | def main(arguments):
127 | parser = argparse.ArgumentParser()
128 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
129 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
130 | type=str, default='all')
131 |     parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
132 | type=str, default='')
133 | args = parser.parse_args(arguments)
134 |
135 | if not os.path.isdir(args.data_dir):
136 | os.mkdir(args.data_dir)
137 | tasks = get_tasks(args.tasks)
138 |
139 | for task in tasks:
140 | print(task)
141 | if task == 'MRPC':
142 | format_mrpc(args.data_dir, args.path_to_mrpc)
143 | elif task == 'diagnostic':
144 | download_diagnostic(args.data_dir)
145 | else:
146 | download_and_extract(task, args.data_dir)
147 |
148 |
149 | if __name__ == '__main__':
150 | sys.exit(main(sys.argv[1:]))
151 | # python3 download_glue_data.py --data_dir /data/xiaoya/datasets/glue --tasks QQP
152 |
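
# Editor's sketch (illustrative only, not part of the script): format_mrpc above routes a row
# to dev.tsv when its (id1, id2) pair appears in dev_ids, which is read as a list of string
# pairs from dev_ids.tsv. The IDs below are the first two rows of tasks/glue/mrpc_dev_ids.tsv;
# the sentences are hypothetical.
dev_ids = [["1606495", "1606619"], ["3444633", "3444733"]]
row = "1\t1606495\t1606619\tSentence one .\tSentence two ."
label, id1, id2, s1, s2 = row.strip().split('\t')
print([id1, id2] in dev_ids)   # True -> this row would be written to dev.tsv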
--------------------------------------------------------------------------------
/tasks/glue/evaluate_models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: glue/evaluate_models.py
5 | # description:
6 | # code for evaluating saved model checkpoints.
7 |
8 | import os
9 | import argparse
10 | from utils.random_seed import set_random_seed
11 | set_random_seed(0)
12 |
13 | from pytorch_lightning import Trainer
14 | from tasks.glue.train import BertForGLUETask
15 | from utils.get_parser import get_parser
16 |
17 |
18 | def init_evaluate_parser(parser) -> argparse.ArgumentParser:
19 | parser.add_argument("--path_to_model_checkpoint", type=str, help="")
20 | parser.add_argument("--path_to_model_hparams_file", type=str, default="")
21 | return parser
22 |
23 |
24 | def evaluate(args):
25 | trainer = Trainer(gpus=args.gpus,
26 | distributed_backend=args.distributed_backend,
27 | deterministic=True)
28 | model = BertForGLUETask.load_from_checkpoint(
29 | checkpoint_path=args.path_to_model_checkpoint,
30 | hparams_file=args.path_to_model_hparams_file,
31 | map_location=None,
32 | batch_size=args.eval_batch_size
33 | )
34 | trainer.test(model=model)
35 |
36 |
37 | def main():
38 | eval_parser = get_parser()
39 | eval_parser = init_evaluate_parser(eval_parser)
40 | eval_parser = BertForGLUETask.add_model_specific_args(eval_parser)
41 | eval_parser = Trainer.add_argparse_args(eval_parser)
42 | args = eval_parser.parse_args()
43 |
44 | if len(args.path_to_model_hparams_file) == 0:
45 | eval_output_dir = "/".join(args.path_to_model_checkpoint.split("/")[:-1])
46 | args.path_to_model_hparams_file = os.path.join(eval_output_dir, "lightning_logs", "version_0", "hparams.yaml")
47 |
48 | evaluate(args)
49 |
50 |
51 | if __name__ == "__main__":
52 | main()
53 |
54 |
--------------------------------------------------------------------------------
/tasks/glue/evaluate_predictions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: evaluate_predictions.py
5 | # description:
6 | #
7 |
8 | import os
9 | import sys
10 | import csv
11 | from glob import glob
12 | from metrics.functional.cls_acc_f1 import compute_acc_f1_from_list
13 |
14 | def eval_single_file(file_path, quotechar=None, num_labels=2):
15 | with open(file_path, "r") as r_f:
16 | data_lines = list(csv.reader(r_f, delimiter="\t", quotechar=quotechar))
17 | # id \t pred \t gold
18 |
19 | pred_labels = []
20 | gold_labels = []
21 | for idx, data_line in enumerate(data_lines):
22 | if idx == 0:
23 | continue
24 | pred_labels.append(data_line[1])
25 | gold_labels.append(data_line[2])
26 | acc, f1, precision, recall = compute_acc_f1_from_list(pred_labels, gold_labels, num_labels=num_labels)
27 |     print(f"acc: {acc}, f1: {f1}, precision: {precision}, recall: {recall}")
28 |     return acc, f1, precision, recall
29 |
30 | def eval_files_in_folder(folder_path, prefix="dev", quotechar=None):
31 | file_lst = glob(os.path.join(folder_path, f"{prefix}-*.txt"))
32 | file_lst = sorted(file_lst)
33 |
34 | best_f1 = 0
35 | acc_when_best_f1 = 0
36 | best_file = ""
37 | for file_item in file_lst:
38 | acc, f1, precision, recall = eval_single_file(file_item)
39 | if f1 > best_f1:
40 | best_f1 = f1
41 | acc_when_best_f1 = acc
42 | best_file = file_item
43 | print(f"INFO -> {file_item}")
44 | print(f"INFO -> acc: {acc}, f1: {f1}, precision: {precision}, recall: {recall}")
45 |
46 | print(f"Summary INFO -> Best f1: {best_f1}, acc: {acc_when_best_f1}")
47 | print(f"Summary INFO -> Best file: {best_file}")
48 |
49 | if __name__ == "__main__":
50 | eval_folder_or_file = sys.argv[1]
51 | if eval_folder_or_file.endswith('.txt'):
52 | eval_single_file(eval_folder_or_file)
53 | else:
54 | eval_files_in_folder(eval_folder_or_file)
55 |
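
# Editor's sketch (illustrative only): eval_single_file expects a tab-separated prediction file
# with a header row (skipped) followed by one "id \t pred \t gold" row per example. The file
# name, header text, and label values below are hypothetical.
import csv

rows = [["idx", "pred", "gold"],   # header row, skipped by eval_single_file
        ["0", "1", "1"],
        ["1", "0", "1"],
        ["2", "0", "0"]]
with open("dev-example.txt", "w") as w_f:
    csv.writer(w_f, delimiter="\t").writerows(rows)
# eval_single_file("dev-example.txt", num_labels=2) would then report acc/f1/precision/recall
# for these three rows.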
--------------------------------------------------------------------------------
/tasks/glue/mrpc_dev_ids.tsv:
--------------------------------------------------------------------------------
1 | 1606495 1606619
2 | 3444633 3444733
3 | 1196962 1197061
4 | 1720166 1720115
5 | 2546175 2546198
6 | 2167771 2167744
7 | 2688145 2688162
8 | 2095803 2095786
9 | 1629064 1629043
10 | 116294 116332
11 | 722278 722383
12 | 1912524 1912648
13 | 726966 726945
14 | 936978 937500
15 | 2906104 2906322
16 | 1095780 1095652
17 | 487951 488007
18 | 3294207 3294290
19 | 1744257 1744378
20 | 1490044 1489975
21 | 2896308 2896334
22 | 2907762 2907649
23 | 101747 101777
24 | 1490811 1490840
25 | 2065523 2065836
26 | 886904 887158
27 | 2204592 2204588
28 | 3084554 3084612
29 | 917965 918315
30 | 2396937 2396818
31 | 3147370 3147525
32 | 71501 71627
33 | 1166473 1166857
34 | 1353356 1353174
35 | 1876120 1876059
36 | 1708040 1708062
37 | 2265271 2265152
38 | 2224884 2224819
39 | 2551891 2551563
40 | 360875 360943
41 | 2320654 2320666
42 | 3013483 3013540
43 | 1224743 1225510
44 | 58540 58567
45 | 1015249 1015204
46 | 3039310 3039413
47 | 726399 726078
48 | 2029631 2029565
49 | 1461629 1461781
50 | 1465073 1464854
51 | 283751 283290
52 | 799346 799268
53 | 533823 533909
54 | 801552 801516
55 | 2829648 2829613
56 | 2074182 2074668
57 | 1958079 1958143
58 | 3140260 3140288
59 | 969381 969512
60 | 2307064 2307235
61 | 271891 271839
62 | 1010655 1010430
63 | 1954 2142
64 | 985015 984975
65 | 218848 218851
66 | 753928 753890
67 | 555617 555528
68 | 487993 487952
69 | 3428298 3428362
70 | 389239 389299
71 | 2249237 2249305
72 | 970740 971209
73 | 2988297 2988555
74 | 2204353 2204418
75 | 544217 544325
76 | 2083598 2083810
77 | 263690 263819
78 | 2587300 2587243
79 | 3070979 3070949
80 | 953733 953537
81 | 2561999 2561941
82 | 684848 684557
83 | 162203 162101
84 | 958161 957782
85 | 2926039 2925982
86 | 2112330 2112376
87 | 3329379 3329416
88 | 961836 962243
89 | 1808166 1808434
90 | 2632692 2632767
91 | 2652187 2652218
92 | 1430357 1430425
93 | 144089 143697
94 | 661390 661218
95 | 325763 325928
96 | 3320577 3320553
97 | 2111629 2111786
98 | 2222998 2223097
99 | 205100 205145
100 | 533903 533818
101 | 629316 629289
102 | 1033204 1033365
103 | 3180014 3179967
104 | 2673104 2673130
105 | 349215 349241
106 | 129995 129864
107 | 368067 368018
108 | 655498 655391
109 | 2139506 2139427
110 | 1661381 1661317
111 | 426112 426210
112 | 1990975 1991132
113 | 162632 162653
114 | 1889954 1889847
115 | 347017 347002
116 | 3150803 3150839
117 | 173879 173832
118 | 431076 431242
119 | 1261116 1261234
120 | 1636060 1635946
121 | 2252795 2252970
122 | 1128884 1128865
123 | 3214517 3214483
124 | 1014983 1014963
125 | 1057995 1057778
126 | 2587767 2587673
127 | 539585 539355
128 | 1756329 1756394
129 | 2963943 2963880
130 | 3242051 3241897
131 | 977772 977804
132 | 2131318 2131372
133 | 103280 103431
134 | 581592 581570
135 | 2083612 2083810
136 | 2155514 2155377
137 | 723557 724115
138 | 1351550 1351155
139 | 611663 611716
140 | 886618 886456
141 | 1989515 1989458
142 | 224932 224868
143 | 751520 751373
144 | 396041 396188
145 | 55187 54831
146 | 132553 132725
147 | 1673112 1673068
148 | 1638813 1639087
149 | 2208376 2208198
150 | 849291 849442
151 | 2638861 2638982
152 | 1967578 1967664
153 | 781439 781461
154 | 1119721 1119714
155 | 34513 34742
156 | 644788 644816
157 | 2495223 2495307
158 | 954526 954607
159 | 195728 196099
160 | 2010705 2010779
161 | 1277539 1277527
162 | 2638975 2638855
163 | 2357324 2357271
164 | 2763381 2763517
165 | 1597193 1597119
166 | 555553 555528
167 | 2796658 2796682
168 | 101746 101775
169 | 2116843 2116883
170 | 1100998 1100441
171 | 3255597 3255668
172 | 1909579 1909408
173 | 2919853 2919804
174 | 315785 315653
175 | 1264509 1264471
176 | 3439114 3439084
177 | 3062202 3062308
178 | 2614947 2614904
179 | 1462409 1462504
180 | 2199097 2199072
181 | 331980 332110
182 | 3267026 3266930
183 | 698948 698933
184 | 461779 461815
185 | 1910610 1910455
186 | 389117 389052
187 | 789691 789665
188 | 1348909 1348954
189 | 261202 260995
190 | 2820371 2820525
191 | 696677 696932
192 | 54181 53570
193 | 589579 589557
194 | 2128530 2128455
195 | 3113791 3113782
196 | 637168 637447
197 | 490355 490378
198 | 780604 780466
199 | 219064 218969
200 | 2823575 2823513
201 | 3181118 3181443
202 | 485999 486011
203 | 2304696 2304863
204 | 2916199 2916164
205 | 2829194 2829229
206 | 1167835 1167651
207 | 1438666 1438643
208 | 98432 98657
209 | 249699 249623
210 | 347022 347003
211 | 2749410 2749625
212 | 2517014 2516995
213 | 2766112 2766084
214 | 2198694 2198937
215 | 548867 548785
216 | 2758265 2758282
217 | 981185 981234
218 | 1354501 1354476
219 | 2758944 2758975
220 | 1865364 1865251
221 | 131979 131957
222 | 490376 490490
223 | 146112 146127
224 | 2763517 2763576
225 | 327839 327748
226 | 3111452 3111428
227 | 1831696 1831660
228 | 515581 515752
229 | 315647 315778
230 | 1783137 1782659
231 | 1393764 1393984
232 | 1980654 1980641
233 | 1989213 1989116
234 | 3022833 3023029
235 | 86007 86373
236 | 1685339 1685429
237 | 1592037 1592076
238 | 2493369 2493428
239 | 1726935 1726879
240 | 3389318 3389271
241 | 3394891 3394775
242 | 2324704 2325023
243 | 2455942 2455978
244 | 192285 192327
245 | 3400796 3400822
246 | 4733 4557
247 | 1050307 1050144
248 | 1112021 1111925
249 | 3376093 3376101
250 | 816867 816831
251 | 3218713 3218830
252 | 1864253 1863810
253 | 3107137 3107119
254 | 3039165 3039036
255 | 3039007 3038845
256 | 2015389 2015410
257 | 1605818 1605806
258 | 2796978 2797024
259 | 1201306 1201329
260 | 2339738 2339771
261 | 3300040 3299992
262 | 2749322 2749663
263 | 2745055 2745022
264 | 3046488 3046824
265 | 2241925 2242066
266 | 86020 86007
267 | 69773 69792
268 | 1057876 1057778
269 | 2965576 2965701
270 | 577854 578500
271 | 221515 221509
272 | 587009 586969
273 | 3264732 3264648
274 | 3023029 3023229
275 | 2523564 2523358
276 | 1552068 1551928
277 | 1439663 1439808
278 | 2377289 2377259
279 | 2283737 2283794
280 | 588637 588864
281 | 1825432 1825301
282 | 2748287 2748550
283 | 1704987 1705268
284 | 54142 53641
285 | 2259788 2259747
286 | 2090911 2091154
287 | 3093023 3092996
288 | 3122429 3122305
289 | 1521034 1520582
290 | 2324708 2325028
291 | 3261484 3261306
292 | 1675025 1675047
293 | 2317018 2317252
294 | 3448488 3448449
295 | 1762569 1762526
296 | 1602860 1602844
297 | 1824224 1824209
298 | 2640607 2640576
299 | 2697659 2697747
300 | 2440680 2440474
301 | 818091 817811
302 | 853475 853342
303 | 2175939 2176090
304 | 314997 315030
305 | 1220668 1220801
306 | 554905 554627
307 | 3035788 3035918
308 | 383417 383558
309 | 1089053 1089297
310 | 1831453 1831491
311 | 2274844 2274714
312 | 2706154 2706185
313 | 2889005 2888954
314 | 1355540 1355592
315 | 2380695 2380822
316 | 1616174 1616206
317 | 1528383 1528083
318 | 635783 635802
319 | 1580638 1580663
320 | 1549586 1549609
321 | 2826681 2826474
322 | 221079 221003
323 | 720572 720486
324 | 3311600 3311633
325 | 460211 460445
326 | 2385288 2385256
327 | 1908763 1908744
328 | 2996241 2996734
329 | 2691044 2691264
330 | 1386884 1386857
331 | 2977500 2977547
332 | 1330643 1330622
333 | 2240399 2240149
334 | 2931098 2931144
335 | 919683 919782
336 | 60122 60445
337 | 805457 805985
338 | 3435735 3435717
339 | 110731 110648
340 | 524136 524119
341 | 3439854 3439874
342 | 2008984 2009175
343 | 260952 260924
344 | 844421 844679
345 | 872784 872834
346 | 1423836 1423708
347 | 2079200 2079131
348 | 753858 753890
349 | 787432 787464
350 | 2110220 2110199
351 | 1186754 1187056
352 | 2110775 2110924
353 | 780408 780363
354 | 52758 52343
355 | 763948 763991
356 | 2810634 2810670
357 | 2584416 2584653
358 | 2268396 2268480
359 | 447728 447699
360 | 2573262 2573319
361 | 1550897 1550977
362 | 941617 941673
363 | 3310210 3310286
364 | 2494149 2494073
365 | 1619244 1619274
366 | 2531749 2531607
367 | 374015 374162
368 | 2221603 2221633
369 | 2362761 2362698
370 | 2834988 2835026
371 | 1605350 1605425
372 | 1630585 1630657
373 | 3464314 3464302
374 | 2842562 2842582
375 | 1076861 1077018
376 | 3028143 3028234
377 | 518089 518133
378 | 2336453 2336545
379 | 3061836 3062031
380 | 2738677 2738741
381 | 2046630 2046644
382 | 1919740 1919926
383 | 1721433 1721267
384 | 1269572 1269682
385 | 1771131 1771091
386 | 1757264 1757375
387 | 1984039 1983986
388 | 1609290 1609098
389 | 2728425 2728251
390 | 2020252 2020081
391 | 665419 665612
392 | 2945693 2945847
393 | 2217613 2217659
394 | 2530671 2530542
395 | 2607718 2607708
396 | 1015010 1014963
397 | 1513190 1513246
398 | 969512 969295
399 | 1657632 1657619
400 | 2385348 2385394
401 | 821523 821385
402 | 2577517 2577531
403 | 862804 862715
404 | 977938 978162
405 | 3073773 3073779
406 | 3107118 3107136
407 | 2047034 2046820
408 | 308567 308525
409 |
--------------------------------------------------------------------------------
/tasks/glue/process_mrpc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: process_mrpc.py
5 |
6 | import os
7 | import sys
8 |
9 | def process_mrpc_data(data_dir, dev_ids_file):
10 | print("Processing MRPC...")
11 | mrpc_train_file = os.path.join(data_dir, "msr_paraphrase_train.txt")
12 | mrpc_test_file = os.path.join(data_dir, "msr_paraphrase_test.txt")
13 |
14 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
15 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
16 |
17 | dev_ids = []
18 | with open(dev_ids_file, encoding="utf8") as ids_fh:
19 | for row in ids_fh:
20 | dev_ids.append(row.strip().split('\t'))
21 |
22 | with open(mrpc_train_file, encoding="utf8") as data_fh, \
23 | open(os.path.join(data_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
24 | open(os.path.join(data_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
25 | header = data_fh.readline()
26 | train_fh.write(header)
27 | dev_fh.write(header)
28 | for row in data_fh:
29 | label, id1, id2, s1, s2 = row.strip().split('\t')
30 | if [id1, id2] in dev_ids:
31 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
32 | else:
33 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
34 |
35 | with open(mrpc_test_file, encoding="utf8") as data_fh, \
36 | open(os.path.join(data_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
37 | header = data_fh.readline()
38 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
39 | for idx, row in enumerate(data_fh):
40 | label, id1, id2, s1, s2 = row.strip().split('\t')
41 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
42 | print("\tCompleted!")
43 |
44 |
45 | if __name__ == "__main__":
46 | data_dir = sys.argv[1]
47 | path_to_dev_ids = sys.argv[2]
48 | process_mrpc_data(data_dir, path_to_dev_ids)
--------------------------------------------------------------------------------
/tasks/mrc_ner/data_preprocess/file_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | from tasks.mrc_ner.data_preprocess.label_utils import iob_iobes
6 |
7 | def load_conll03_sentences(data_path):
8 | dataset = []
9 | with open(data_path, "r") as f:
10 | words, tags = [], []
11 |         # each line of the file corresponds to one word and its tag
12 | for line_idx, line in enumerate(f):
13 | if line != "\n" and ("-DOCSTART-" not in line):
14 | line = line.strip()
15 | if " " not in line:
16 | continue
17 | try:
18 | word, pos_cat, pos_label, tag = line.split(" ")
19 | word = word.strip()
20 | tag = tag.strip()
21 | except:
22 | print(line)
23 | continue
24 |
25 | if len(word) > 0 and len(tag) > 0:
26 | word, tag = str(word), str(tag)
27 | words.append(word)
28 | tags.append(tag)
29 | else:
30 | if len(words) > 0 and line_idx != 0:
31 | assert len(words) == len(tags)
32 | dataset.append((words, tags))
33 | words, tags = [], []
34 |
35 | tokens = [data[0] for data in dataset]
36 | labels = [iob_iobes(data[1]) for data in dataset]
37 | dataset = [(data_tokens, data_labels) for data_tokens, data_labels in zip(tokens, labels)]
38 | return dataset
39 |
40 | def load_conll03_documents(data_path):
41 | dataset = []
42 | with open(data_path, "r") as f:
43 | words, tags = [], []
44 |         # each line of the file corresponds to one word and its tag
45 | for line_idx, line in enumerate(f):
46 | if "-DOCSTART-" not in line:
47 | line = line.strip()
48 | if " " not in line:
49 | continue
50 | try:
51 | word, pos_cat, pos_label, tag = line.split(" ")
52 | word = word.strip()
53 | tag = tag.strip()
54 | except:
55 | print(line)
56 | continue
57 |
58 | if len(word) > 0 and len(tag) > 0:
59 | word, tag = str(word), str(tag)
60 | words.append(word)
61 | tags.append(tag)
62 | else:
63 | if len(words) > 0 and line_idx != 0:
64 | assert len(words) == len(tags)
65 | dataset.append((words, tags))
66 | words, tags = [], []
67 |
68 | tokens = [data[0] for data in dataset]
69 | labels = [iob_iobes(data[1]) for data in dataset]
70 | dataset = [(data_tokens, data_labels) for data_tokens, data_labels in zip(tokens, labels)]
71 | return dataset
72 |
73 | def export_conll(sentence, label, export_file_path, dim=2):
74 | """
75 | Args:
76 |         sentence: a list of sentences of chars [["北", "京", "天", "安", "门"], ["真", "相", "警", "告"]]
77 | label: a list of labels [["B", "M", "E", "S", "O"], ["O", "O", "S", "S"]]
78 | Desc:
79 | export tagging data into conll format
80 | """
81 | with open(export_file_path, "w") as f:
82 | for idx, (sent_item, label_item) in enumerate(zip(sentence, label)):
83 | for char_idx, (tmp_char, tmp_label) in enumerate(zip(sent_item, label_item)):
84 | f.write("{} {}\n".format(tmp_char, tmp_label))
85 | f.write("\n")
86 |
87 |
88 | def load_conll(data_path):
89 | """
90 | Desc:
91 | load data in conll format
92 | Returns:
93 | [([word1, word2, word3, word4], [label1, label2, label3, label4]),
94 | ([word5, word6, word7, wordd8], [label5, label6, label7, label8])]
95 | """
96 | dataset = []
97 | with open(data_path, "r") as f:
98 | words, tags = [], []
99 |         # each line of the file corresponds to one word and its tag
100 | for line in f:
101 | if line != "\n":
102 | # line = line.strip()
103 | word, tag = line.split(" ")
104 | word = word.strip()
105 | tag = tag.strip()
106 | try:
107 | if len(word) > 0 and len(tag) > 0:
108 | word, tag = str(word), str(tag)
109 | words.append(word)
110 | tags.append(tag)
111 | except Exception as e:
112 |                     print("an exception was raised! skipping a word")
113 | else:
114 | if len(words) > 0:
115 | assert len(words) == len(tags)
116 | dataset.append((words, tags))
117 | words, tags = [], []
118 |
119 | return dataset
120 |
121 | def dump_tsv(data_lines, data_path):
122 | """
123 | Desc:
124 | dump data into tsv format for TAGGING data
125 | Input:
126 | the format of data_lines is:
127 | [([word1, word2, word3, word4], [label1, label2, label3, label4]),
128 | ([word5, word6, word7, word8, word9], [label5, label6, label7, label8, label9]),
129 | ([word10, word11, word12, ], [label10, label11, label12])]
130 | """
131 |     print("dump data lines into TSV format: ")
132 | with open(data_path, "w") as f:
133 | for data_item in data_lines:
134 | data_word, data_tag = data_item
135 | data_str = " ".join(data_word)
136 | data_tag = " ".join(data_tag)
137 | f.write(data_str + "\t" + data_tag + "\n")
138 | print("dump data set into data path")
139 | print(data_path)
140 |
141 |
142 | if __name__ == "__main__":
143 | import os
144 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-4])
145 | print(repo_path)
146 | conll_data_file = os.path.join(repo_path, "tests", "data", "enconll03", "test.txt")
147 | conll_dataset = load_conll03_documents(conll_data_file)
148 | doc_len = [len(tmp[0]) for tmp in conll_dataset]
149 | # number of doc is 230
150 | print(f"NUM OF DOC -> {len(doc_len)}")
151 |     print(f"AVG -> {sum(doc_len) / float(len(doc_len))}")
152 | print(f"MAX -> {max(doc_len)}")
153 | print(f"MIN -> {min(doc_len)}")
154 | print(f"512 -> {len([tmp for tmp in doc_len if tmp >= 500])}")
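
# Editor's sketch (illustrative only): load_conll above expects two-column tagging data,
# one "word tag" pair per line and a blank line between sentences. The sample sentences
# below are hypothetical; note the trailing blank line, which load_conll needs in order
# to flush the last sentence.
import tempfile
from tasks.mrc_ner.data_preprocess.file_utils import load_conll

sample = "北 B-LOC\n京 E-LOC\n很 O\n大 O\n\n我 O\n爱 O\n北 B-LOC\n京 E-LOC\n\n"
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write(sample)
    tmp_path = f.name

print(load_conll(tmp_path))
# -> [(['北', '京', '很', '大'], ['B-LOC', 'E-LOC', 'O', 'O']),
#     (['我', '爱', '北', '京'], ['O', 'O', 'B-LOC', 'E-LOC'])]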
--------------------------------------------------------------------------------
/tasks/mrc_ner/data_preprocess/label_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # description:
5 | # utilies for sequence tagging tasks for entity-level tasks
6 | # (such as NER)
7 |
8 |
9 | def get_bmes(span_labels, length, encoding):
10 | tags = ["O" for _ in range(length)]
11 |
12 | for start, end in span_labels:
13 | for i in range(start+1, end+1):
14 | tags[i] = "M"
15 | if "E" in encoding:
16 | tags[end] = "E"
17 | if "B" in encoding:
18 | tags[start] = "B"
19 | if "S" in encoding and start == end:
20 | tags[start] = "S"
21 | return tags
22 |
23 |
24 | def get_span_labels(sentence_tags, inv_label_mapping=None):
25 | """
26 | Desc:
27 |         convert token-level labels into a list of entities;
28 |         it does not matter whether the tagging scheme is BMES, BIO, or BIOES
29 | Returns:
30 | a list of entities
31 | [(start, end, labels), (start, end, labels)]
32 | """
33 |
34 | if inv_label_mapping:
35 | sentence_tags = [inv_label_mapping[i] for i in sentence_tags]
36 |
37 | span_labels = []
38 | last = "O"
39 | start = -1
40 | for i, tag in enumerate(sentence_tags):
41 | pos, _ = (None, "O") if tag == "O" else tag.split("-")
42 | if (pos == "S" or pos == "B" or tag == "O") and last != "O":
43 | span_labels.append((start, i - 1, last.split("-")[-1]))
44 | if pos == "B" or pos == "S" or last == "O":
45 | start = i
46 | last = tag
47 |
48 | if sentence_tags[-1] != "O":
49 |         span_labels.append((start, len(sentence_tags) - 1, sentence_tags[-1].split("-")[-1]))
50 |
51 | return span_labels
52 |
53 |
54 |
55 | def get_tags(span_labels, length, encoding):
56 | """
57 | Desc:
58 | convert a list of entities to token-level labels based on the provided encoding (e.g., BMOES)
59 |         Please note that the left and right boundaries are inclusive.
60 | """
61 | tags = ["O" for _ in range(length)]
62 |
63 | for start, end, tag in span_labels:
64 | for i in range(start, end + 1):
65 | tags[i] = "M-" + tag
66 |
67 | if "E" in encoding:
68 | tags[end] = "E-" + tag
69 | if "B" in encoding:
70 | tags[start] = "B-" + tag
71 | if "S" in encoding and start == end:
72 | tags[start] = "S-" + tag
73 | return tags
74 |
75 |
76 |
77 | def iob_iobes(tags):
78 | """
79 | Desc:
80 | IOB -> IOBES
81 | """
82 | new_tags = []
83 | for i, tag in enumerate(tags):
84 | if tag == "O":
85 | new_tags.append(tag)
86 | elif tag.split("-")[0] == "B":
87 | if i + 1 != len(tags) and tags[i+1].split("-")[0] == "I":
88 | new_tags.append(tag)
89 | else:
90 | new_tags.append(tag.replace("B-", "S-"))
91 | elif tag.split("-")[0] == "I":
92 | if i + 1 < len(tags) and tags[i + 1].split("-")[0] == "I":
93 | new_tags.append(tag)
94 | else:
95 | new_tags.append(tag.replace("I-", "E-"))
96 | else:
97 | raise Exception("invalid IOB format !!")
98 | return new_tags
99 |
100 |
101 |
102 | if __name__ == "__main__":
103 | label_tags = ["O", "B-ORG", "M-ORG", "E-ORG", "B-PER", "M-PER", "E-PER"]
104 | span_labels = get_span_labels(label_tags, )
105 | print("check the content of span_labels")
106 | print(span_labels)
107 |     # expected: [(1, 3, "ORG"), (4, 6, "PER")]
108 |
109 | # -------------------------
110 | # test the functionality of get_tags
111 | # -------------------------
112 | print("-*-"*10)
113 |     print("check the output of get_tags")
114 |     span_label = get_tags([(1, 3, "ORG"), (8, 10, "PER")], 11, "BIOES")
115 | print(span_label)
116 |
117 |
118 |
119 |
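
# Editor's sketch (illustrative only): a quick worked example of iob_iobes on a hypothetical
# IOB sequence. The final I- token of an entity becomes E-, and a single-token B- entity
# becomes S-.
from tasks.mrc_ner.data_preprocess.label_utils import iob_iobes

print(iob_iobes(["B-PER", "I-PER", "O", "B-LOC"]))
# -> ['B-PER', 'E-PER', 'O', 'S-LOC']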
--------------------------------------------------------------------------------
/tasks/mrc_ner/data_preprocess/query_map.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: query_map.py
5 | # ---------------------------------------------
6 | # query collections for different dataset
7 | # ---------------------------------------------
8 |
9 | en_conll03_query = {
10 | "default": {
11 | "ORG": "organization entities are limited to named corporate, governmental, or other organizational entities.",
12 | "PER": "person entities are named persons or family.",
13 | "LOC": "location entities are the name of politically or geographically defined locations such as cities, provinces, countries, international regions, bodies of water, mountains, etc.",
14 | "MISC": "examples of miscellaneous entities include events, nationalities, products and works of art."
15 | },
16 | "labels": ["ORG", "PER", "LOC", "MISC"]
17 | }
18 |
19 | queries_for_dataset = {
20 | "en_conll03": en_conll03_query
21 | }
22 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/tasks/mrc_ner/eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | PRECISION=32
6 | FILE_NAME=eval_mrc_ner
7 | REPO_PATH=/data/xiaoya/workspace/mrc-with-dice-loss
8 | CKPT_PATH=$1
9 |
10 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
11 |
12 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/tasks/mrc_ner/evaluate.py \
13 | --gpus="1" \
14 | --precision=${PRECISION} \
15 | --path_to_model_checkpoint ${CKPT_PATH}
16 |
--------------------------------------------------------------------------------
/tasks/mrc_ner/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: tasks/mrc_ner/evaluate.py
5 | #
6 |
7 | import os
8 | import argparse
9 | from utils.random_seed import set_random_seed
10 | set_random_seed(2333)
11 |
12 | from pytorch_lightning import Trainer
13 | from tasks.mrc_ner.train import BertForNERTask
14 | from utils.get_parser import get_parser
15 |
16 |
17 | def init_evaluate_parser(parser) -> argparse.ArgumentParser:
18 | parser.add_argument("--path_to_model_checkpoint", type=str, help="")
19 | parser.add_argument("--path_to_model_hparams_file", type=str, default="")
20 | return parser
21 |
22 |
23 | def evaluate(args):
24 | trainer = Trainer(gpus=args.gpus,
25 | distributed_backend=args.distributed_backend,
26 | deterministic=True)
27 | model = BertForNERTask.load_from_checkpoint(
28 | checkpoint_path=args.path_to_model_checkpoint,
29 | hparams_file=args.path_to_model_hparams_file,
30 | map_location=None,
31 | batch_size=args.eval_batch_size
32 | )
33 | trainer.test(model=model)
34 |
35 |
36 | def main():
37 | eval_parser = get_parser()
38 | eval_parser = init_evaluate_parser(eval_parser)
39 | eval_parser = BertForNERTask.add_model_specific_args(eval_parser)
40 | eval_parser = Trainer.add_argparse_args(eval_parser)
41 | args = eval_parser.parse_args()
42 |
43 | if len(args.path_to_model_hparams_file) == 0:
44 | eval_output_dir = "/".join(args.path_to_model_checkpoint.split("/")[:-1])
45 | args.path_to_model_hparams_file = os.path.join(eval_output_dir, "lightning_logs", "version_0", "hparams.yaml")
46 |
47 | evaluate(args)
48 |
49 |
50 | if __name__ == "__main__":
51 | main()
52 |
--------------------------------------------------------------------------------
/tasks/mrc_ner/generate_mrc_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # transform data from sequence labeling to mrc formulation
5 | # ----------------------------------------------------------
6 | # input data structure
7 | # ---------------------------------------------------------
8 | # this module generates the mrc-style ner dataset.
9 | # 1. for flat ner, the input file follows the conll format and is tagged in the BMES scheme
10 | # 2. for nested ner, the input file is in json format
11 |
12 | import os
13 | import sys
14 | import json
15 |
16 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-3])
17 | print(repo_path)
18 | if repo_path not in sys.path:
19 | sys.path.insert(0, repo_path)
20 |
21 | from data_preprocess.file_utils import load_conll, load_conll03_documents, load_conll03_sentences
22 | from data_preprocess.label_utils import get_span_labels
23 | from data_preprocess.query_map import queries_for_dataset
24 |
25 |
26 | def generate_query_ner_dataset(source_file_path, dump_file_path, entity_sign="nested",
27 | dataset_name=None, query_sign="default"):
28 | """
29 | Args:
30 |         source_file_path: /data/genia/train.word.json | /data/msra/train.char.bmes
31 |         dump_file_path: /data/genia-mrc/train.mrc.json | /data/msra-mrc/train.mrc.json
32 |         dataset_name: one of ["en_ontonotes5", "en_conll03", ]
33 |         entity_sign: one of ["nested", "flat", "flat_enconll03_docs", "flat_enconll03_sents"]
34 |         query_sign: default is "default"
35 | """
36 | entity_queries = queries_for_dataset[dataset_name][query_sign]
37 | label_lst = queries_for_dataset[dataset_name]["labels"]
38 |
39 | if entity_sign == "nested":
40 | with open(source_file_path, "r") as f:
41 | source_data = json.load(f)
42 | elif entity_sign == "flat":
43 | source_data = load_conll(source_file_path)
44 | elif entity_sign == "flat_enconll03_docs":
45 | source_data = load_conll03_documents(source_file_path)
46 | elif entity_sign == "flat_enconll03_sents":
47 | source_data = load_conll03_sentences(source_file_path)
48 | else:
49 |         raise ValueError("entity_sign must be one of: nested, flat, flat_enconll03_docs, flat_enconll03_sents.")
50 |
51 | target_data = transform_examples_to_qa_features(entity_queries, label_lst, source_data, entity_sign=entity_sign)
52 |
53 | with open(dump_file_path, "w") as f:
54 | json.dump(target_data, f, sort_keys=True, ensure_ascii=False, indent=2)
55 |
56 |
57 | def transform_examples_to_qa_features(query_map, entity_labels, data_instances, entity_sign="nested"):
58 | """
59 | Desc:
60 | convert_examples to qa features
61 | Args:
62 | query_map: {entity label: entity query};
63 | data_instance
64 | """
65 | mrc_ner_dataset = []
66 |
67 | if "flat" in entity_sign.lower():
68 | tmp_qas_id = 0
69 | for idx, (word_lst, label_lst) in enumerate(data_instances):
70 | candidate_span_label = get_span_labels(label_lst)
71 | tmp_query_id = 0
72 | for label_idx, tmp_label in enumerate(entity_labels):
73 | tmp_query_id += 1
74 | tmp_query = query_map[tmp_label]
75 | tmp_context = " ".join(word_lst)
76 |
77 | tmp_start_pos = []
78 | tmp_end_pos = []
79 | tmp_entity_pos = []
80 |
81 | start_end_label = [(start, end) for start, end, label_content in candidate_span_label if
82 | label_content == tmp_label]
83 |
84 | if len(start_end_label) != 0:
85 | for span_item in start_end_label:
86 | start_idx, end_idx = span_item
87 | tmp_start_pos.append(start_idx)
88 | tmp_end_pos.append(end_idx)
89 | tmp_entity_pos.append("{};{}".format(str(start_idx), str(end_idx)))
90 | tmp_impossible = False
91 | else:
92 | tmp_impossible = True
93 |
94 | mrc_ner_dataset.append({
95 | "qas_id": "{}.{}".format(str(tmp_qas_id), str(tmp_query_id)),
96 | "query": tmp_query,
97 | "context": tmp_context,
98 | "entity_label": tmp_label,
99 | "start_position": tmp_start_pos,
100 | "end_position": tmp_end_pos,
101 | "span_position": tmp_entity_pos,
102 | "impossible": tmp_impossible
103 | })
104 | tmp_qas_id += 1
105 |
106 | elif "nested" in entity_sign.lower():
107 | tmp_qas_id = 0
108 | for idx, data_item in enumerate(data_instances):
109 | tmp_query_id = 0
110 | for label_idx, tmp_label in enumerate(entity_labels):
111 | tmp_query_id += 1
112 | tmp_query = query_map[tmp_label]
113 | tmp_context = data_item["context"]
114 |
115 | tmp_start_pos = []
116 | tmp_end_pos = []
117 | tmp_entity_pos = []
118 |
119 | start_end_label = data_item["label"][tmp_label] if tmp_label in data_item["label"].keys() else -1
120 |
121 | if start_end_label == -1:
122 | tmp_impossible = True
123 | else:
124 | for start_end_item in data_item["label"][tmp_label]:
125 | start_idx, end_idx = [int(ix) for ix in start_end_item.split(",")]
126 | tmp_start_pos.append(start_idx)
127 | tmp_end_pos.append(end_idx)
128 | tmp_entity_pos.append(start_end_item)
129 | tmp_impossible = False
130 |
131 | mrc_ner_dataset.append({
132 | "qas_id": "{}.{}".format(str(tmp_qas_id), str(tmp_query_id)),
133 | "query": tmp_query,
134 | "context": tmp_context,
135 | "entity_label": tmp_label,
136 | "start_position": tmp_start_pos,
137 | "end_position": tmp_end_pos,
138 | "span_position": tmp_entity_pos,
139 | "impossible": tmp_impossible
140 | })
141 | tmp_qas_id += 1
142 |
143 | return mrc_ner_dataset
144 |
145 |
146 | if __name__ == "__main__":
147 | source_file_path = "/data/xiaoya/datasets/ner/conll2003_truecase/valid.txt"
148 | target_file_path = "/data/xiaoya/datasets/mrc_ner/en_conll03_truecase_doc/mrc-ner.dev"
149 |     # generate_query_ner_dataset(source_file_path, target_file_path, entity_sign="flat_enconll03_docs", dataset_name="en_conll03", query_sign="default")
150 |
151 | source_file_path = "/data/xiaoya/datasets/ner/conll2003_truecase/train.txt"
152 | target_file_path = "/data/xiaoya/datasets/mrc_ner/en_conll03_truecase_sent/mrc-ner.train"
153 |     # generate_query_ner_dataset(source_file_path, target_file_path, entity_sign="flat_enconll03_sents", dataset_name="en_conll03", query_sign="default")
154 |
155 | source_file_path = "/data/xiaoya/datasets/ner/conll2003/eng.train"
156 | target_file_path = "/data/xiaoya/datasets/mrc_ner/en_conll03_doc/mrc-ner.train"
157 | generate_query_ner_dataset(source_file_path, target_file_path, entity_sign="flat_enconll03_docs",
158 | dataset_name="en_conll03", query_sign="default")
159 |
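
# Editor's sketch (illustrative only): each element appended to mrc_ner_dataset above is a
# dict of the following shape. The values below are hypothetical and assume the en_conll03
# queries from data_preprocess/query_map.py.
example_entry = {
    "qas_id": "0.2",                        # "<example index>.<query index>"
    "query": "person entities are named persons or family.",
    "context": "John lives in Berlin",      # hypothetical sentence
    "entity_label": "PER",
    "start_position": [0],                  # token-level start indices
    "end_position": [0],                    # token-level end indices (inclusive)
    "span_position": ["0;0"],               # "start;end" strings
    "impossible": False                     # True when the context has no entity of this label
}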
--------------------------------------------------------------------------------
/tasks/mrc_ner/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | TIME=2021.02.02
5 | FILE_NAME=debug_onto4
6 | REPO_PATH=/data/xiaoya/workspace/mrc-with-dice-loss
7 | MODEL_SCALE=base
8 | DATA_DIR=/data/nfsdata2/xiaoya/mrc_ner/zh_onto4
9 | BERT_DIR=/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12
10 |
11 | TRAIN_BATCH_SIZE=5
12 | EVAL_BATCH_SIZE=12
13 | MAX_LENGTH=128
14 |
15 | OPTIMIZER=adamw
16 | LR_SCHEDULE=linear
17 | LR=3e-5
18 |
19 | BERT_DROPOUT=0.1
20 | ACC_GRAD=1
21 | MAX_EPOCH=5
22 | GRAD_CLIP=1.0
23 | WEIGHT_DECAY=0.002
24 | WARMUP_PROPORTION=0.002
25 |
26 | LOSS_TYPE=bce
27 | W_START=1
28 | W_END=1
29 | W_SPAN=1
30 | DICE_SMOOTH=1
31 | DICE_OHEM=0.8
32 | DICE_ALPHA=0.01
33 | FOCAL_GAMMA=2
34 |
35 | PRECISION=32
36 | PROGRESS_BAR=1
37 | VAL_CHECK_INTERVAL=0.25
38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
39 |
40 | if [[ ${LOSS_TYPE} == "bce" ]]; then
41 | LOSS_SIGN=${LOSS_TYPE}
42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
46 | fi
47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
48 |
49 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/mrc_ner/${TIME}
50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN}
51 |
52 | mkdir -p ${OUTPUT_DIR}
53 |
54 | CUDA_VISIBLE_DEVICES=3 python ${REPO_PATH}/tasks/mrc_ner/train.py \
55 | --gpus="1" \
56 | --precision=${PRECISION} \
57 | --train_batch_size ${TRAIN_BATCH_SIZE} \
58 | --eval_batch_size ${EVAL_BATCH_SIZE} \
59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
60 | --val_check_interval ${VAL_CHECK_INTERVAL} \
61 | --max_length ${MAX_LENGTH} \
62 | --optimizer ${OPTIMIZER} \
63 | --data_dir ${DATA_DIR} \
64 | --bert_hidden_dropout ${BERT_DROPOUT} \
65 | --bert_config_dir ${BERT_DIR} \
66 | --lr ${LR} \
67 | --lr_scheduler ${LR_SCHEDULE} \
68 | --accumulate_grad_batches ${ACC_GRAD} \
69 | --default_root_dir ${OUTPUT_DIR} \
70 | --output_dir ${OUTPUT_DIR} \
71 | --max_epochs ${MAX_EPOCH} \
72 | --gradient_clip_val ${GRAD_CLIP} \
73 | --weight_decay ${WEIGHT_DECAY} \
74 | --loss_type ${LOSS_TYPE} \
75 | --weight_start ${W_START} \
76 | --weight_end ${W_END} \
77 | --weight_span ${W_SPAN} \
78 | --dice_smooth ${DICE_SMOOTH} \
79 | --dice_ohem ${DICE_OHEM} \
80 | --dice_alpha ${DICE_ALPHA} \
81 | --dice_square \
82 | --focal_gamma ${FOCAL_GAMMA} \
83 | --warmup_proportion ${WARMUP_PROPORTION} \
84 | --span_loss_candidates all \
85 | --do_lower_case \
86 | --is_chinese \
87 | --flat_ner
--------------------------------------------------------------------------------
/tasks/squad/evaluate_models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: squad/evaluate_models.py
5 | # description:
6 | # evaluate the saved model checkpoints.
7 |
8 | import os
9 | from tasks.squad.train import BertForQA
10 | from utils.get_parser import get_parser
11 | from utils.random_seed import set_random_seed
12 | set_random_seed(0)
13 | from pytorch_lightning import Trainer
14 |
15 |
16 |
17 | def init_evaluate_parser(parser):
18 | parser.add_argument("--path_to_model_checkpoint", type=str, )
19 | parser.add_argument("--path_to_model_hparams_file", type=str, default="")
20 | return parser
21 |
22 |
23 | def evaluate(args):
24 | """
25 | Args:
26 | ckpt: model checkpoints.
27 | hparams_file: the string should end with "hparams.yaml"
28 | """
29 | trainer = Trainer(gpus=args.gpus,
30 | distributed_backend=args.distributed_backend,
31 | deterministic=True)
32 |
33 | model = BertForQA.load_from_checkpoint(
34 | checkpoint_path=args.path_to_model_checkpoint,
35 | hparams_file=args.path_to_model_hparams_file,
36 | map_location=None,
37 | batch_size=args.eval_batch_size,)
38 |
39 | trainer.test(model=model)
40 |
41 |
42 | def main():
43 | """evaluate model checkpoints on the dev set. """
44 | eval_parser = init_evaluate_parser(get_parser())
45 | eval_parser = BertForQA.add_model_specific_args(eval_parser)
46 | eval_parser = Trainer.add_argparse_args(eval_parser)
47 | args = eval_parser.parse_args()
48 | if len(args.path_to_model_hparams_file) == 0:
49 | args.path_to_model_hparams_file = os.path.join("/".join(args.path_to_model_checkpoint.split("/")[:-1]), "lightning_logs", "version_0", "hparams.yaml")
50 |
51 | evaluate(args)
52 |
53 |
54 | if __name__ == "__main__":
55 | main()
--------------------------------------------------------------------------------
/tasks/squad/evaluate_predictions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # files: squad/evaluate_predictions.py
5 | # description:
6 | # evaluate the output prediction file
7 | # Example script:
8 | # python3 evaluate_predictions.py /data/dev-v1.1.json /outputs/dice_loss/squad
9 |
10 | import os
11 | import sys
12 | import json
13 | import subprocess
14 | from glob import glob
15 | from utils.random_seed import set_random_seed
16 |
17 | REPO_PATH = "/".join(os.path.realpath(__file__).split("/")[:-3])
18 |
19 |
20 | def main(golden_data_file: str = "dev-v1.1.json",
21 | saved_ckpt_directory: str = None,
22 | version_2_with_negative: bool = False,
23 | eval_result_output: str = "result.json"):
24 | """evaluate model prediction files."""
25 |
26 | if not os.path.exists(golden_data_file):
27 |         raise ValueError(f"golden data file not found: {golden_data_file}. Usage: python3 evaluate_predictions.py <golden_data_file> <saved_ckpt_directory>")
28 | if saved_ckpt_directory is not None:
29 | prediction_files = glob(os.path.join(saved_ckpt_directory, "predictions_*_*.json"))
30 | else:
31 |         raise ValueError("saved_ckpt_directory is required. Usage: python3 evaluate_predictions.py <golden_data_file> <saved_ckpt_directory>")
32 |
33 | if version_2_with_negative:
34 | evaluate_script_file = os.path.join(REPO_PATH, "metrics", "functional", "squad", "evaluate_v2.py")
35 | else:
36 | evaluate_script_file = os.path.join(REPO_PATH, "metrics", "functional", "squad", "evaluate_v1.py")
37 | evaluate_sh_file = os.path.join(REPO_PATH, "metrics", "functional", "squad", "eval.sh")
38 |
39 | chmod_result = os.system(f"chmod 777 {evaluate_sh_file}")
40 | if chmod_result != 0:
41 |         raise ValueError(f"failed to make {evaluate_sh_file} executable")
42 |
43 | evaluate_results = {}
44 | best_em = 0
45 | best_f1 = 0
46 | best_ckpt_path = ""
47 | for prediction_file in prediction_files:
48 | evaluate_key = prediction_file.replace(f"{saved_ckpt_directory}/", "").replace("predictions_", "").replace(".json", "")
49 | cmd = [evaluate_sh_file, evaluate_script_file, golden_data_file, prediction_file]
50 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
51 | stdout, stderr = process.communicate()
52 | process.wait()
53 |
54 | stdout = stdout.strip()
55 | if stderr is not None:
56 | print(stderr)
57 |
58 | evaluate_value = json.loads(stdout)
59 | if best_f1 <= evaluate_value['f1']:
60 | best_f1 = evaluate_value['f1']
61 | best_em = evaluate_value['exact_match']
62 | best_ckpt_path = prediction_file
63 | evaluate_results[evaluate_key] = evaluate_value
64 |
65 | eval_log_file = os.path.join(saved_ckpt_directory, eval_result_output)
66 | with open(eval_log_file, "w") as f:
67 | json.dump(evaluate_results, f, sort_keys=True, indent=2, ensure_ascii=False)
68 |
69 | print(f"BEST CKPT is -> {best_ckpt_path}")
70 | print(f"BEST Exact Match is : -> {best_em}")
71 |     print(f"BEST F1 is : {best_f1}")
72 |
73 |
74 | if __name__ == "__main__":
75 | set_random_seed(0)
76 | golden_data_file = sys.argv[1]
77 | saved_ckpt_directory = sys.argv[2]
78 | main(golden_data_file, saved_ckpt_directory, )
--------------------------------------------------------------------------------
/tasks/tnews/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | TIME=2021.02.08
5 | FILE_NAME=re_debug_tnews
6 | REPO_PATH=/data/xiaoya/workspace/mrc-with-dice-loss
7 | MODEL_SCALE=base
8 | DATA_DIR=/data/xiaoya/datasets/tnews_public_data
9 | BERT_DIR=/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12
10 |
11 | TRAIN_BATCH_SIZE=24
12 | EVAL_BATCH_SIZE=12
13 | MAX_LENGTH=128
14 |
15 | OPTIMIZER=adamw
16 | LR_SCHEDULE=linear
17 | LR=3e-5
18 |
19 | BERT_DROPOUT=0.1
20 | ACC_GRAD=1
21 | MAX_EPOCH=5
22 | GRAD_CLIP=1.0
23 | WEIGHT_DECAY=0.002
24 | WARMUP_PROPORTION=0.002
25 |
26 | LOSS_TYPE=dice
27 | # ce, focal, dice
28 | DICE_SMOOTH=1
29 | DICE_OHEM=0.8
30 | DICE_ALPHA=0.01
31 | FOCAL_GAMMA=2
32 |
33 | PRECISION=32
34 | PROGRESS_BAR=1
35 | VAL_CHECK_INTERVAL=0.25
36 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
37 |
38 | if [[ ${LOSS_TYPE} == "ce" ]]; then
39 | LOSS_SIGN=${LOSS_TYPE}
40 | elif [[ ${LOSS_TYPE} == "focal" ]]; then
41 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA}
42 | elif [[ ${LOSS_TYPE} == "dice" ]]; then
43 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA}
44 | fi
45 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}"
46 |
47 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/tnews/${TIME}
48 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${LOSS_SIGN}
49 |
50 | mkdir -p ${OUTPUT_DIR}
51 |
52 | CUDA_VISIBLE_DEVICES=3 python ${REPO_PATH}/tasks/tnews/train.py \
53 | --gpus="1" \
54 | --precision=${PRECISION} \
55 | --train_batch_size ${TRAIN_BATCH_SIZE} \
56 | --eval_batch_size ${EVAL_BATCH_SIZE} \
57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
58 | --val_check_interval ${VAL_CHECK_INTERVAL} \
59 | --max_length ${MAX_LENGTH} \
60 | --optimizer ${OPTIMIZER} \
61 | --data_dir ${DATA_DIR} \
62 | --bert_hidden_dropout ${BERT_DROPOUT} \
63 | --bert_config_dir ${BERT_DIR} \
64 | --lr ${LR} \
65 | --lr_scheduler ${LR_SCHEDULE} \
66 | --accumulate_grad_batches ${ACC_GRAD} \
67 | --default_root_dir ${OUTPUT_DIR} \
68 | --output_dir ${OUTPUT_DIR} \
69 | --max_epochs ${MAX_EPOCH} \
70 | --gradient_clip_val ${GRAD_CLIP} \
71 | --weight_decay ${WEIGHT_DECAY} \
72 | --loss_type ${LOSS_TYPE} \
73 | --dice_smooth ${DICE_SMOOTH} \
74 | --dice_ohem ${DICE_OHEM} \
75 | --dice_alpha ${DICE_ALPHA} \
76 | --dice_square \
77 | --focal_gamma ${FOCAL_GAMMA} \
78 | --warmup_proportion ${WARMUP_PROPORTION}
--------------------------------------------------------------------------------
/tests/count_length_autotokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # author: xiaoya li
5 | # file: count_length_autotokenizer.py
6 |
7 | import os
8 | import sys
9 |
10 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-2])
11 | print(repo_path)
12 | if repo_path not in sys.path:
13 | sys.path.insert(0, repo_path)
14 |
15 |
16 | from transformers import AutoTokenizer
17 | from datasets.mrc_ner_dataset import MRCNERDataset
18 |
19 |
20 |
21 | class OntoNotesDataConfig:
22 | def __init__(self):
23 | self.data_dir = "/data/nfsdata2/xiaoya/mrc_ner/zh_onto4"
24 | self.model_path = "/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12"
25 | self.do_lower_case = False
26 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, tokenize_chinese_chars=True)
27 | # BertWordPieceTokenizer(os.path.join(self.model_path, "vocab.txt"), lowercase=self.do_lower_case)
28 | self.max_length = 512
29 | self.is_chinese = True
30 | self.threshold = 275
31 | self.data_sign = "zh_onto"
32 |
33 | class ChineseMSRADataConfig:
34 | def __init__(self):
35 | self.data_dir = "/data/xiaoya/datasets/mrc_ner/zh_msra"
36 | self.model_path = "/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12"
37 | self.do_lower_case = False
38 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, tokenize_chinese_chars=True)
39 | self.max_length = 512
40 | self.is_chinese = True
41 | self.threshold = 275
42 | self.data_sign = "zh_msra"
43 |
44 |
45 | class EnglishOntoDataConfig:
46 | def __init__(self):
47 | self.data_dir = "/data/xiaoya/datasets/mrc_ner/en_onto5"
48 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12"
49 | self.do_lower_case = False
50 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
51 | self.max_length = 512
52 | self.is_chinese = False
53 | self.threshold = 275
54 | self.data_sign = "en_onto"
55 |
56 |
57 | class EnglishCoNLLDataConfig:
58 | def __init__(self):
59 | self.data_dir = "/data/xiaoya/datasets/mrc_ner/en_conll03"
60 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12"
61 | if "uncased" in self.model_path:
62 | self.do_lower_case = True
63 | else:
64 | self.do_lower_case = False
65 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, do_lower_case=self.do_lower_case)
66 | self.max_length = 512
67 | self.is_chinese = False
68 | self.threshold = 275
69 | self.data_sign = "en_conll03"
70 |
71 | class EnglishCoNLL03DocDataConfig:
72 | def __init__(self):
73 | self.data_dir = "/data/xiaoya/datasets/mrc_ner/en_conll03_doc"
74 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12"
75 | self.do_lower_case = False
76 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
77 | self.max_length = 512
78 | self.is_chinese = False
79 | self.threshold = 384
80 | self.data_sign = "en_conll03"
81 |
82 | def count_max_length(data_sign):
83 | if data_sign == "zh_onto":
84 | data_config = OntoNotesDataConfig()
85 | elif data_sign == "zh_msra":
86 | data_config = ChineseMSRADataConfig()
87 | elif data_sign == "en_onto":
88 | data_config = EnglishOntoDataConfig()
89 | elif data_sign == "en_conll03":
90 | data_config = EnglishCoNLLDataConfig()
91 | elif data_sign == "en_conll03_doc":
92 | data_config = EnglishCoNLL03DocDataConfig()
93 | else:
94 | raise ValueError
95 | for prefix in ["test", "train", "dev"]:
96 | print("=*"*15)
97 | print(f"INFO -> loading {prefix} data. ")
98 | data_file_path = os.path.join(data_config.data_dir, f"mrc-ner.{prefix}")
99 | dataset = MRCNERDataset(json_path=data_file_path,
100 | tokenizer=data_config.tokenizer,
101 | max_length=data_config.max_length,
102 | is_chinese=data_config.is_chinese,
103 | pad_to_maxlen=False,
104 | data_sign=data_config.data_sign)
105 | max_len = 0
106 | counter = 0
107 | for idx, data_item in enumerate(dataset):
108 | tokens = data_item[0]
109 | num_tokens = tokens.shape[0]
110 | if num_tokens >= max_len:
111 | max_len = num_tokens
112 | if num_tokens > data_config.threshold:
113 | counter += 1
114 |
115 |         print(f"INFO -> Max LEN for {prefix} set is : {max_len}")
116 |         print(f"INFO -> larger than {data_config.threshold} is {counter}")
117 |
118 |
119 |
120 | if __name__ == '__main__':
121 | # for english
122 | data_sign = "en_onto"
123 | # data_sign = "zh_onto"
124 | count_max_length(data_sign)
125 |
--------------------------------------------------------------------------------
/tests/count_length_glue.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: count_length_glue.py
5 |
6 | import os
7 | import sys
8 |
9 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-2])
10 | print(repo_path)
11 | if repo_path not in sys.path:
12 | sys.path.insert(0, repo_path)
13 |
14 | from transformers import AutoTokenizer
15 | from datasets.qqp_dataset import QQPDataset
16 |
17 |
18 | class QQPDataConfig:
19 | def __init__(self):
20 | self.data_dir = "/data/xiaoya/datasets/glue/qqp"
21 | self.model_path = "/data/xiaoya/models/bert_cased_large"
22 | self.do_lower_case = False
23 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False,
24 | do_lower_case=False,
25 | tokenize_chinese_chars=False)
26 | self.max_seq_length = 1024
27 | self.is_chinese = False
28 | self.threshold = 275
29 | self.pad_to_maxlen = False
30 |
31 |
32 | def main():
33 | data_config = QQPDataConfig()
34 |
35 | for mode in ["train", "dev", "test"]:
36 | print("=*"*20)
37 | print(mode)
38 | print("=*"*20)
39 | data_length_collection = []
40 | qqp_dataset = QQPDataset(data_config, data_config.tokenizer, mode=mode, )
41 | for data_item in qqp_dataset:
42 | input_tokens = data_item["input_ids"].shape
43 | print(input_tokens)
44 | exit()
45 |
46 |
47 | if __name__ == "__main__":
48 | main()
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/utils/__init__.py
--------------------------------------------------------------------------------
/utils/get_parser.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: get_parser.py
5 | # description:
6 | # argument parser
7 |
8 | import argparse
9 |
10 | def get_parser() -> argparse.ArgumentParser:
11 | """
12 | return basic arg parser
13 | """
14 | parser = argparse.ArgumentParser(description="argument parser")
15 |
16 | parser.add_argument("--seed", type=int, default=2333)
17 | parser.add_argument("--data_dir", type=str, help="data dir")
18 | parser.add_argument("--bert_config_dir", type=str, help="bert config dir")
19 | parser.add_argument("--pretrained_checkpoint", default="", type=str, help="pretrained checkpoint path")
20 | parser.add_argument("--train_batch_size", type=int, default=32, help="batch size for train dataloader")
21 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for eval dataloader")
22 | parser.add_argument("--lr", type=float, default=2e-5, help="learning rate")
23 | parser.add_argument("--lr_scheduler", type=str, default="onecycle", help="type of lr scheduler")
24 | parser.add_argument("--workers", type=int, default=0, help="num workers for dataloader")
25 |     # the number of data-loader workers should be 0, see:
26 | # https://blog.csdn.net/breeze210/article/details/99679048
27 | parser.add_argument("--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.")
28 | parser.add_argument("--weight_decay", default=0.01, type=float,
29 | help="Weight decay if we apply some.")
30 |     # to avoid argument errors, define a new argument here
31 | parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for.")
32 | parser.add_argument("--adam_epsilon", default=1e-6, type=float,
33 | help="Epsilon for Adam optimizer.")
34 | parser.add_argument("--max_keep_ckpt", default=3, type=int,
35 |                         help="the maximum number of checkpoints to keep.")
36 | parser.add_argument("--output_dir", default="/data", type=str, help="the directory to save model outputs")
37 | parser.add_argument("--debug", action="store_true", help="train with 10 data examples in the debug mode.")
38 |
39 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets")
40 | parser.add_argument("--cache_dir", default="", type=str, help="Where do you want to store the pre-trained models downloaded from s3", )
41 |
42 | # optimizer and loss func
43 | parser.add_argument("--bert_hidden_dropout", type=float, default=0.1, )
44 | parser.add_argument("--final_div_factor", type=float, default=1e4,
45 | help="final div factor of linear decay scheduler")
46 | # TODO: choices=["adamw", "sgd", "debias"]
47 |     parser.add_argument("--optimizer", default="adamw", help="optimizer type")
48 |     # TODO: change choices
49 | # choices=["ce", "bce", "dice", "focal", "adaptive_dice"],
50 | parser.add_argument("--loss_type", default="bce", help="loss type")
51 | ## dice loss
52 | parser.add_argument("--dice_smooth", type=float, default=1e-4, help="smooth value of dice loss")
53 | parser.add_argument("--dice_ohem", type=float, default=0.0, help="ohem ratio of dice loss")
54 | parser.add_argument("--dice_alpha", type=float, default=0.01, help="alpha value of adaptive dice loss")
55 | parser.add_argument("--dice_square", action="store_true", help="use square for dice loss")
56 | ## focal loss
57 | parser.add_argument("--focal_gamma", type=float, default=2, help="gamma for focal loss.")
58 | parser.add_argument("--focal_alpha", type=float, help="alpha for focal loss.")
59 |
60 | # only keep the best checkpoint after training.
61 |     parser.add_argument("--only_keep_the_best_ckpt_after_training", action="store_true", help="only keep the best model checkpoint after training.")
62 |
63 | return parser
64 |
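
# Editor's sketch (illustrative only): the parser returned above is the shared base; task
# scripts in this repo compose it with model-specific and pytorch-lightning trainer arguments,
# roughly as below (the --data_dir value is a placeholder, and the model-specific line is
# commented out because it depends on the task class).
from pytorch_lightning import Trainer
from utils.get_parser import get_parser

parser = get_parser()
# parser = BertForGLUETask.add_model_specific_args(parser)  # task-specific flags (see tasks/glue/train.py)
parser = Trainer.add_argparse_args(parser)                  # adds trainer flags such as --gpus and --max_epochs
args = parser.parse_args(["--data_dir", "/path/to/data", "--loss_type", "dice"])
print(args.lr, args.loss_type)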
--------------------------------------------------------------------------------
/utils/random_seed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: random_seed.py
5 | # refer to :
6 | # issue: https://github.com/PyTorchLightning/pytorch-lightning/issues/1868
7 | # Please Notice:
8 | # set for trainer: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html
9 | # from pytorch_lightning import Trainer, seed_everything
10 | # seed_everything(42)
11 | # sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
12 | # model = Model()
13 | # trainer = Trainer(deterministic=True)
14 |
15 | import random
16 | import torch
17 | import numpy as np
18 | from pytorch_lightning import seed_everything
19 |
20 |
21 | def set_random_seed(seed: int):
22 | """set seeds for reproducibility"""
23 | random.seed(seed)
24 | np.random.seed(seed)
25 | torch.manual_seed(seed)
26 | torch.cuda.manual_seed_all(seed)
27 | seed_everything(seed=seed)
28 | torch.backends.cudnn.deterministic = True
29 | torch.backends.cudnn.benchmark = False
30 |
31 |
32 | if __name__ == '__main__':
33 | # without this line, x would be different in every execution.
34 | set_random_seed(0)
35 |
36 | x = np.random.random()
37 | print(x)
38 |
--------------------------------------------------------------------------------