├── .gitignore ├── LICENSE ├── README.md ├── datasets ├── __init__.py ├── collate_functions.py ├── data_sampling.py ├── mrc_ner_dataset.py ├── mrpc_dataset.py ├── mrpc_processor.py ├── squad_dataset.py ├── tnews_dataset.py └── truncate_dataset.py ├── loss ├── __init__.py ├── dice_loss.py └── focal_loss.py ├── metrics ├── __init__.py ├── classification_acc_f1.py ├── functional │ ├── __init__.py │ ├── cls_acc_f1.py │ ├── ner_span_f1.py │ └── squad │ │ ├── README.md │ │ ├── __init__.py │ │ ├── eval.sh │ │ ├── evaluate_v1.py │ │ ├── evaluate_v2.py │ │ └── postprocess_predication.py ├── mrc_ner_span_f1.py └── squad_em_f1.py ├── models ├── __init__.py ├── bert_classification.py ├── bert_qa.py ├── bert_query_ner.py ├── classifier.py └── model_config.py ├── requirements.txt ├── scripts ├── download_ckpt.sh ├── glue_mrpc │ ├── bert_base_ce.sh │ ├── bert_base_dice.sh │ ├── bert_base_focal.sh │ ├── bert_large_ce.sh │ ├── bert_large_dice.sh │ ├── bert_large_focal.sh │ └── eval.sh ├── mrc_squad1 │ ├── bert_base_ce.sh │ ├── bert_base_dice.sh │ ├── bert_base_focal.sh │ ├── bert_large_ce.sh │ ├── bert_large_dice.sh │ ├── bert_large_focal.sh │ ├── eval_pred_file.sh │ └── eval_saved_ckpt.sh ├── ner_enconll03 │ ├── bert_dice.sh │ └── bert_focal.sh ├── ner_enontonotes5 │ ├── bert_dice.sh │ └── bert_focal.sh ├── ner_zhmsra │ ├── bert_dice.sh │ └── bert_focal.sh ├── ner_zhonto4 │ ├── bert_dice.sh │ └── bert_focal.sh ├── prepare_ckpt.sh ├── prepare_mrpc_data.sh └── textcl_tnews │ ├── bert_dice.sh │ └── bert_focal.sh ├── tasks ├── glue │ ├── download_glue_data.py │ ├── evaluate_models.py │ ├── evaluate_predictions.py │ ├── mrpc_dev_ids.tsv │ ├── process_mrpc.py │ └── train.py ├── mrc_ner │ ├── data_preprocess │ │ ├── file_utils.py │ │ ├── label_utils.py │ │ └── query_map.py │ ├── eval.sh │ ├── evaluate.py │ ├── generate_mrc_dataset.py │ ├── train.py │ └── train.sh ├── pos │ └── train.py ├── squad │ ├── evaluate_models.py │ ├── evaluate_predictions.py │ └── train.py └── tnews │ ├── train.py │ └── train.sh ├── tests ├── count_length_autotokenizer.py └── count_length_glue.py └── utils ├── __init__.py ├── get_parser.py └── random_seed.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Xcode 2 | *.DS_Store 3 | .idea/* 4 | # Logs 5 | logs/* 6 | tasks/pos/* 7 | # 8 | experiments/* 9 | log/* 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # backup files 20 | bk 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | 119 | # Do Not push origin intermediate logging files 120 | *.log 121 | *.out 122 | 123 | # mac book 124 | .DS_Store 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dice Loss for NLP Tasks 2 | 3 | This repository contains code for [Dice Loss for Data-imbalanced NLP Tasks](https://arxiv.org/pdf/1911.02855.pdf) at ACL2020. 4 | 5 | ## Setup 6 | 7 | - Install Package Dependencies 8 | 9 | The code was tested in `Python 3.6.9+` and `Pytorch 1.7.1`. 10 | If you are working on ubuntu GPU machine with CUDA 10.1, please run the following command to setup environment.
11 | ```bash 12 | $ virtualenv -p /usr/bin/python3.6 venv 13 | $ source venv/bin/activate 14 | $ pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 15 | $ pip install -r requirements.txt 16 | ``` 17 | 18 | - Download BERT Model Checkpoints 19 | 20 | Before running the repo, you must download the `BERT-Base` and `BERT-Large` checkpoints from [here](https://github.com/google-research/bert#pre-trained-models) and unzip them to a directory `$BERT_DIR`. 21 | Then convert the original TensorFlow checkpoints for BERT to a PyTorch saved file by running `bash scripts/prepare_ckpt.sh `. 22 | 23 | ## Apply Dice-Loss to NLP Tasks 24 | 25 | In this repository, we apply dice loss to four NLP tasks, including
26 | 1. machine reading comprehension 27 | 2. paraphrase identification task 28 | 3. named entity recognition 29 | 4. text classification 30 | 31 | ### 1. Machine Reading Comprehension 32 | 33 | ***Datasets***
34 | 35 | We take SQuAD 1.1 as an example. 36 | Before training, you should download a copy of the data from [here](https://rajpurkar.github.io/SQuAD-explorer/).
37 | Then move the SQuAD 1.1 train file `train-v1.1.json` and dev file `dev-v1.1.json` to the directory `$DATA_DIR`.
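If you want to double-check that the files are in place, here is a small sanity check; it assumes only the standard top-level `data` key of the SQuAD 1.1 json, and `$DATA_DIR` is the directory above:

```python
import json
import os

# Sanity check: both SQuAD 1.1 files should load and expose the usual "data" list.
data_dir = os.environ.get("DATA_DIR", "/path/to/squad1.1")
for name in ("train-v1.1.json", "dev-v1.1.json"):
    with open(os.path.join(data_dir, name), encoding="utf-8") as f:
        squad = json.load(f)
    print(name, "->", len(squad["data"]), "articles")
```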
38 | 39 | ***Train***
40 | 41 | We choose BERT as the backbone. 42 | During training, the task trainer `BertForQA` will automatically evaluate on the dev set every `$val_check_interval` epoch, 43 | and save the dev predictions into files called `$OUTPUT_DIR/predictions__.json` and `$OUTPUT_DIR/nbest_predictions__.json`. 44 | 45 | Run `scripts/mrc_squad1/bert__.sh` to reproduce our experimental results.
46 | The variable `` should take the value of `[base, large]`.
47 | The variable `` should take the value of `[bce, focal, dice]`, which denotes fine-tuning `BERT-Base` with `binary cross entropy loss`, `focal loss`, or `dice loss`, respectively.
48 | 49 | * Run `bash scripts/mrc_squad1/bert_base_focal.sh` to start training. After training, run `bash scripts/mrc_squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR` for focal loss.
50 | 51 | * Run `bash scripts/mrc_squad1/bert_base_dice.sh` to start training. After training, run `bash scripts/mrc_squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR` for dice loss.
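The three variants differ only in the objective applied to the start/end-position logits. Below is a rough, illustrative sketch using the losses exported by `loss/__init__.py`; it is not the exact code in `tasks/squad/train.py`, and the shapes and hyper-parameters are made up:

```python
import torch
from loss import DiceLoss, FocalLoss  # exported by loss/__init__.py

# Toy batch: one positive-class logit per token and the gold start-position labels.
start_logits = torch.randn(4, 1)
start_labels = torch.tensor([1, 0, 0, 1])

bce = torch.nn.BCEWithLogitsLoss()(start_logits.view(-1), start_labels.float())

two_class_logits = torch.cat([-start_logits, start_logits], dim=-1)  # [N, 2] input for FocalLoss
focal = FocalLoss(gamma=2, reduction="mean")(two_class_logits, start_labels)

dice = DiceLoss(with_logits=True, smooth=1e-4, reduction="mean")(start_logits, start_labels)
print(bce.item(), focal.item(), dice.item())
```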
52 | 53 | 54 | ***Evaluate***
55 | 56 | To evaluate a model checkpoint, please run 57 | ```bash 58 | python3 tasks/squad/evaluate_models.py \ 59 | --gpus="1" \ 60 | --path_to_model_checkpoint $OUTPUT_DIR/epoch=2.ckpt \ 61 | --eval_batch_size 62 | ``` 63 | After evaluation, the prediction files `predictions_dev.json` and `nbest_predictions_dev.json` can be found in `$OUTPUT_DIR`.
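The prediction file maps each question id to the predicted answer string (the usual SQuAD prediction layout); to take a quick look at it:

```python
import json

# Peek at a few saved dev predictions (assumes the usual {qas_id: answer_text} layout).
with open("predictions_dev.json", encoding="utf-8") as f:
    predictions = json.load(f)
for qas_id, answer in list(predictions.items())[:3]:
    print(qas_id, "->", answer)
```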
64 | 65 | To evaluate saved predictions, please run 66 | ```bash 67 | python3 tasks/squad/evaluate_predictions.py 68 | ``` 69 | 70 | ### 2. Paraphrase Identification Task 71 | 72 | ***Datasets***
73 | 74 | We use MRPC (GLUE Version) as an example. 75 | Before running experiments, you should download and save the processed dataset files to `$DATA_DIR`.
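The processed `train.tsv`/`dev.tsv` files follow the GLUE column layout that `datasets/mrpc_processor.py` expects; a minimal sketch of how one row is read (the file path is a placeholder under `$DATA_DIR`):

```python
import csv

# Columns per row: Quality, #1 ID, #2 ID, #1 String, #2 String (GLUE MRPC layout).
with open("train.tsv", encoding="utf-8") as f:
    rows = list(csv.reader(f, delimiter="\t"))
for row in rows[1:3]:  # skip the header line, show two sentence pairs
    label, text_a, text_b = row[0], row[3], row[4]
    print(label, text_a, "|", text_b)
```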
76 | 77 | Run `bash scripts/prepare_mrpc_data.sh $DATA_DIR` to download and process the dataset for the MRPC (GLUE Version) task. 78 | 79 | ***Train***
80 | 81 | Please run `scripts/glue_mrpc/bert__.sh` to train and evaluate on the dev set every `$val_check_interval` epoch. 82 | After training, the task trainer evaluates on the test set with the best checkpoint which achieves the highest F1-score on the dev set.
83 | The variable `` should take the value of `[base, large]`.
84 | The variable `` should take the value of `[focal, dice]`, which denotes fine-tuning `BERT` with `focal loss` or `dice loss`, respectively. 85 | 86 | * Run `bash scripts/glue_mrpc/bert_large_focal.sh` for focal loss.
87 | 88 | * Run `bash scripts/glue_mrpc/bert_large_dice.sh` for dice loss.
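The two scripts differ only in the loss applied on top of the classifier logits. A minimal sketch with the losses exported by `loss/__init__.py`; the hyper-parameters below are illustrative, not the ones used in the paper:

```python
import torch
from loss import DiceLoss, FocalLoss

logits = torch.randn(8, 2, requires_grad=True)  # [batch, num_labels] classifier output
labels = torch.randint(0, 2, (8,))              # gold MRPC labels (0/1)

focal_loss = FocalLoss(gamma=2, alpha=[1.0, 1.0], reduction="mean")
dice_loss = DiceLoss(with_logits=True, smooth=1e-4, alpha=0.01, reduction="mean")

print(focal_loss(logits, labels).item())
print(dice_loss(logits, labels).item())
```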
89 | 90 | The evaluation results on the dev and test sets are saved in the file `$OUTPUT_DIR/eval_result_log.txt`.
91 | At most `$max_keep_ckpt` intermediate model checkpoints are kept. 92 | 93 | ***Evaluate***
94 | 95 | To evaluate a model checkpoint on the test set, please run 96 | ```bash 97 | bash scripts/glue_mrpc/eval.sh \ 98 | $OUTPUT_DIR \ 99 | epoch=*.ckpt 100 | ``` 101 | 102 | ### 3. Named Entity Recognition 103 | 104 | For NER, we use the MRC-NER model as the backbone.
105 | Processed datasets and model architecture can be found [here](https://arxiv.org/pdf/1910.11476.pdf). 106 | 107 | ***Train***
108 | 109 | Please run `scripts//bert_.sh` to train and evaluate on the dev set every `$val_check_interval` epoch. 110 | After training, the task trainer evaluates on the test set with the best checkpoint.
111 | The variable `` should take the value of `[ner_enconll03, ner_enontonotes5, ner_zhmsra, ner_zhonto4]`.
112 | The variable `` should take the value of `[focal, dice]`, which denotes fine-tuning `BERT` with `focal loss` or `dice loss`, respectively. 113 | 114 | For Chinese MSRA,
115 | * Run `scripts/ner_zhmsra/bert_focal.sh` for focal loss.
116 | 117 | * Run `scripts/ner_zhmsra/bert_dice.sh` for dice loss.
118 | 119 | For Chinese OntoNotes4,
120 | * Run `scripts/ner_zhonto4/bert_focal.sh` for focal loss.
121 | 122 | * Run `scripts/ner_zhonto4/bert_dice.sh` for dice loss.
123 | 124 | For English CoNLL03,
125 | * Run `scripts/ner_enconll03/bert_focal.sh`. After training, you will get 93.08 Span-F1 on the test set.
126 | 127 | * Run `scripts/ner_enconll03/bert_dice.sh`. After training, you will get 93.21 Span-F1 on the test set.
128 | 129 | For English OntoNotes5,
130 | * Run `scripts/ner_enontonotes5/bert_focal.sh`. After training, you will get 91.12 Span-F1 on the test set.
131 | 132 | * Run `scripts/ner_enontonotes5/bert_dice.sh`. After training, you will get 92.01 Span-F1 on the test set.
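All of these scripts consume MRC-style json files in which every entry pairs a type query with a context sentence. The fields below are the ones read by `datasets/mrc_ner_dataset.py`; the values are made up for illustration (real files are produced by `tasks/mrc_ner/generate_mrc_dataset.py`):

```python
# One illustrative MRC-NER training entry (English; positions are word-level and inclusive).
example = {
    "entity_label": "ORG",
    "query": "organization entities include companies, agencies and institutions",
    "context": "Google was founded in California .",
    "start_position": [0],
    "end_position": [0],
}
```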
133 | 134 | ***Evaluate***
135 | 136 | To evaluate a model checkpoint, please run 137 | ```bash 138 | CUDA_VISIBLE_DEVICES=0 python3 ${REPO_PATH}/tasks/mrc_ner/evaluate.py \ 139 | --gpus="1" \ 140 | --path_to_model_checkpoint $OUTPUT_DIR/epoch=2.ckpt 141 | ``` 142 | 143 | ### 4. Text Classification 144 | 145 | ***Datasets***
146 | 147 | We use TNews (Chinese Text Classification) as an example. 148 | Before running experiments, you should [download](https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip) and save the processed dataset files to `$DATA_DIR`.
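Each line of the unzipped `train.json`/`dev.json` is one json object; the example below is copied from the comment at the top of `datasets/tnews_dataset.py`:

```python
import json

# One line of $DATA_DIR/train.json as read by datasets/tnews_dataset.py.
line = '{"label": "113", "label_desc": "news_world", "sentence": "日本虎视眈眈“武力夺岛”, 美军向俄后院开火,普京终不再忍!", "keywords": "普京,北方四岛,安倍,俄罗斯,黑海"}'
item = json.loads(line)
print(item["label"], item["label_desc"], item["sentence"])
```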
149 | 150 | ***Train***
151 | 152 | We choose BERT as the backbone.
153 | Please run `scripts/textcl_tnews/bert_.sh` to train and evaluate on the dev set every `$val_check_interval` epoch. 154 | The variable `` should take the value of `[focal, dice]`, which denotes fine-tuning `BERT` with `focal loss` or `dice loss`, respectively. 155 | 156 | * Run `bash scripts/textcl_tnews/bert_focal.sh` for focal loss.
157 | 158 | * Run `bash scripts/textcl_tnews/bert_dice.sh` for dice loss.
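To inspect the data pipeline outside of the training script, a minimal sketch (vocab and data paths are placeholders; `BertWordPieceTokenizer` is the tokenizer class the dataset is written against):

```python
from tokenizers import BertWordPieceTokenizer
from datasets.tnews_dataset import TNewsDataset

# Build the dataset the same way datasets/tnews_dataset.py expects to be used.
tokenizer = BertWordPieceTokenizer("/path/to/bert/vocab.txt", lowercase=False)
dataset = TNewsDataset(prefix="dev", data_dir="/path/to/tnews", tokenizer=tokenizer, max_length=128)
input_ids, token_type_ids, attention_mask, label_id = dataset[0]
print(input_ids.shape, label_id.item())
```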
159 | 160 | The intermediate model checkpoints are saved at most `$max_keep_ckpt` times. 161 | 162 | 163 | ## Citation 164 | 165 | If you find this repository useful , please cite the following: 166 | 167 | ```tex 168 | @article{li2019dice, 169 | title={Dice loss for data-imbalanced NLP tasks}, 170 | author={Li, Xiaoya and Sun, Xiaofei and Meng, Yuxian and Liang, Junjun and Wu, Fei and Li, Jiwei}, 171 | journal={arXiv preprint arXiv:1911.02855}, 172 | year={2019} 173 | } 174 | ``` 175 | 176 | ## Contact 177 | 178 | xxiaoyali [AT] gmail.com OR xiaoya_li [AT] shannonai.com 179 | 180 | Any discussions, suggestions and questions are welcome! 181 | 182 | 183 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/collate_functions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import torch 6 | from typing import List 7 | 8 | 9 | def collate_to_max_length(batch: List[List[torch.Tensor]]) -> List[torch.Tensor]: 10 | """ 11 | pad to maximum length of this batch 12 | Args: 13 | batch: a batch of samples, each contains a list of field data(Tensor): 14 | tokens, attention_mask, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, label_idx 15 | Returns: 16 | output: list of field batched data, which shape is [batch, max_length] 17 | """ 18 | batch_size = len(batch) 19 | max_length = max(x[0].shape[0] for x in batch) 20 | output = [] 21 | 22 | for field_idx in range(7): 23 | pad_output = torch.full([batch_size, max_length], 0, dtype=batch[0][field_idx].dtype) 24 | for sample_idx in range(batch_size): 25 | data = batch[sample_idx][field_idx] 26 | pad_output[sample_idx][: data.shape[0]] = data 27 | output.append(pad_output) 28 | 29 | pad_match_labels = torch.zeros([batch_size, max_length, max_length], dtype=torch.long) 30 | for sample_idx in range(batch_size): 31 | data = batch[sample_idx][7] 32 | pad_match_labels[sample_idx, : data.shape[1], : data.shape[1]] = data 33 | output.append(pad_match_labels) 34 | output.append(torch.stack([x[8] for x in batch])) 35 | if len(batch[0]) == 9: 36 | return output 37 | 38 | output.append(torch.stack([x[9] for x in batch])) 39 | return output 40 | 41 | -------------------------------------------------------------------------------- /datasets/data_sampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import random 6 | 7 | def sample_positive_and_negative_by_ratio(positive_data_lst, negative_data_lst, ratio=0.5): 8 | num_negative_examples = int(len(positive_data_lst) * ratio) 9 | 10 | random.shuffle(negative_data_lst) 11 | truncated_negative_data_lst = random.sample(negative_data_lst, num_negative_examples) 12 | all_data_lst = positive_data_lst + truncated_negative_data_lst 13 | # need to use random data sampler 14 | return all_data_lst 15 | 16 | 17 | def undersample_majority_classes(data_lst, label_lst): 18 | pass 19 | 20 | 21 | def oversample_minority_classes(data_lst, sampling_strategy=None): 22 | collect_data_by_label = {} 23 | 24 | for data_item in data_lst: 25 | data_item_label 
= data_item["label"] 26 | if data_item_label not in collect_data_by_label.keys(): 27 | collect_data_by_label[data_item_label] = [data_item] 28 | else: 29 | collect_data_by_label[data_item_label].append(data_item) 30 | 31 | count_data_by_label = {key: len(value) for key, value in collect_data_by_label.items()} -------------------------------------------------------------------------------- /datasets/mrc_ner_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: mrc_ner_dataset.py 5 | 6 | import json 7 | import torch 8 | from torch.utils.data import Dataset 9 | from transformers import AutoTokenizer 10 | from datasets.data_sampling import sample_positive_and_negative_by_ratio 11 | 12 | 13 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text, return_subtoken_start=False): 14 | """Returns tokenized answer spans that better match the annotated answer.""" 15 | doc_tokens = [str(tmp) for tmp in doc_tokens] 16 | answer_tokens = tokenizer.encode(orig_answer_text, add_special_tokens=False) 17 | tok_answer_text = " ".join([str(tmp) for tmp in answer_tokens]) 18 | for new_start in range(input_start, input_end + 1): 19 | for new_end in range(input_end, new_start - 1, -1): 20 | text_span = " ".join(doc_tokens[new_start : (new_end+1)]) 21 | if text_span == tok_answer_text: 22 | if not return_subtoken_start: 23 | return (new_start, new_end) 24 | tokens = tokenizer.convert_ids_to_tokens(doc_tokens[new_start: (new_end + 1)]) 25 | if "##" not in tokens[-1]: 26 | return (new_start, new_end) 27 | else: 28 | for idx in range(len(tokens)-1, -1, -1): 29 | if "##" not in tokens[idx]: 30 | new_end = new_end - (len(tokens)-1 - idx) 31 | return (new_start, new_end) 32 | 33 | return (input_start, input_end) 34 | 35 | 36 | class MRCNERDataset(Dataset): 37 | """ 38 | MRC NER Dataset 39 | Args: 40 | json_path: path to mrc-ner style json 41 | tokenizer: BertTokenizer 42 | max_length: int, max length of query+context 43 | possible_only: if True, only use possible samples that contain answer for the query/context 44 | is_chinese: is chinese dataset 45 | """ 46 | def __init__(self, json_path, tokenizer: AutoTokenizer, max_length: int = 512, possible_only=False, is_chinese=False, 47 | pad_to_maxlen=False, negative_sampling=False, prefix="train", data_sign="zh_onto", do_lower_case=False, 48 | pred_answerable=True): 49 | self.all_data = json.load(open(json_path, encoding="utf-8")) 50 | self.tokenzier = tokenizer 51 | self.max_length = max_length 52 | self.do_lower_case = do_lower_case 53 | self.label2idx = {value:key for key, value in enumerate(MRCNERDataset.get_labels(data_sign))} 54 | 55 | if prefix == "train" and negative_sampling: 56 | neg_data_items = [x for x in self.all_data if not x["start_position"]] 57 | pos_data_items = [x for x in self.all_data if x["start_position"]] 58 | self.all_data = sample_positive_and_negative_by_ratio(pos_data_items, neg_data_items) 59 | elif prefix == "train" and possible_only: 60 | self.all_data = [ 61 | x for x in self.all_data if x["start_position"] 62 | ] 63 | else: 64 | pass 65 | 66 | self.is_chinese = is_chinese 67 | self.pad_to_maxlen = pad_to_maxlen 68 | self.pred_answerable = pred_answerable 69 | 70 | def __len__(self): 71 | return len(self.all_data) 72 | 73 | def __getitem__(self, item): 74 | """ 75 | Args: 76 | item: int, idx 77 | Returns: 78 | tokens: tokens of query + context, [seq_len] 79 | token_type_ids: token type ids, 0 for 
query, 1 for context, [seq_len] 80 | start_labels: start labels of NER in tokens, [seq_len] 81 | end_labels: end labelsof NER in tokens, [seq_len] 82 | label_mask: label mask, 1 for counting into loss, 0 for ignoring. shape of [seq_len]. 1 for no-subword context tokens. 0 for query tokens and [CLS] [SEP] tokens. 83 | match_labels: match labels, [seq_len, seq_len] 84 | sample_idx: sample id 85 | label_idx: label id 86 | 87 | """ 88 | data = self.all_data[item] 89 | tokenizer = self.tokenzier 90 | label_idx = torch.tensor(self.label2idx[data["entity_label"]], dtype=torch.long) 91 | 92 | if self.is_chinese: 93 | query = "".join(data["query"].strip().split()) 94 | context = "".join(data["context"].strip().split()) 95 | else: 96 | query = data["query"] 97 | context = data["context"] 98 | 99 | start_positions = data["start_position"] 100 | end_positions = data["end_position"] 101 | 102 | query_context_tokens = tokenizer.encode_plus(query, context, 103 | add_special_tokens=True, 104 | max_length=self.max_length, 105 | return_overflowing_tokens=True, 106 | return_token_type_ids=True) 107 | 108 | if tokenizer.pad_token_id in query_context_tokens["input_ids"]: 109 | non_padded_ids = query_context_tokens["input_ids"][: query_context_tokens["input_ids"].index(tokenizer.pad_token_id)] 110 | else: 111 | non_padded_ids = query_context_tokens["input_ids"] 112 | 113 | non_pad_tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) 114 | first_sep_token = non_pad_tokens.index("[SEP]") 115 | end_sep_token = len(non_pad_tokens) - 1 116 | new_start_positions = [] 117 | new_end_positions = [] 118 | if len(start_positions) != 0: 119 | for start_index, end_index in zip(start_positions, end_positions): 120 | if self.is_chinese: 121 | answer_text_span = " ".join(context[start_index: end_index+1]) 122 | else: 123 | answer_text_span = " ".join(context.split(" ")[start_index: end_index+1]) 124 | new_start, new_end = _improve_answer_span(query_context_tokens["input_ids"], first_sep_token, end_sep_token, self.tokenzier, answer_text_span) 125 | new_start_positions.append(new_start) 126 | new_end_positions.append(new_end) 127 | else: 128 | new_start_positions = start_positions 129 | new_end_positions = end_positions 130 | 131 | # clip out-of-boundary entity positions. 132 | new_start_positions = [start_pos for start_pos in new_start_positions if start_pos < self.max_length] 133 | new_end_positions = [end_pos for end_pos in new_end_positions if end_pos < self.max_length] 134 | 135 | tokens = query_context_tokens["input_ids"] 136 | token_type_ids = query_context_tokens['token_type_ids'] 137 | # token_type_ids -> 0 for query tokens and 1 for context tokens. 
138 | attention_mask = query_context_tokens['attention_mask'] 139 | start_labels = [(1 if idx in new_start_positions else 0) for idx in range(len(tokens))] 140 | end_labels = [(1 if idx in new_end_positions else 0) for idx in range(len(tokens))] 141 | label_mask = [1 if token_type_ids[token_idx] == 1 else 0 for token_idx in range(len(tokens))] 142 | start_label_mask = label_mask.copy() 143 | end_label_mask = label_mask.copy() 144 | 145 | seq_len = len(tokens) 146 | match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long) 147 | for start, end in zip(new_start_positions, new_end_positions): 148 | if start >= seq_len or end >= seq_len: 149 | continue 150 | match_labels[start, end] = 1 151 | 152 | if self.pred_answerable: 153 | answerable_label = 1 if len(new_start_positions) != 0 else 0 154 | return [torch.tensor(tokens, dtype=torch.long), 155 | torch.tensor(attention_mask, dtype=torch.long), 156 | torch.tensor(token_type_ids, dtype=torch.long), 157 | torch.tensor(start_labels, dtype=torch.long), 158 | torch.tensor(end_labels, dtype=torch.long), 159 | torch.tensor(start_label_mask, dtype=torch.long), 160 | torch.tensor(end_label_mask, dtype=torch.long), 161 | match_labels, 162 | label_idx, 163 | torch.tensor([answerable_label], dtype=torch.long)] 164 | 165 | return [torch.tensor(tokens, dtype=torch.long), 166 | torch.tensor(attention_mask, dtype=torch.long), 167 | torch.tensor(token_type_ids, dtype=torch.long), 168 | torch.tensor(start_labels, dtype=torch.long), 169 | torch.tensor(end_labels, dtype=torch.long), 170 | torch.tensor(start_label_mask, dtype=torch.long), 171 | torch.tensor(end_label_mask, dtype=torch.long), 172 | match_labels, 173 | label_idx] 174 | 175 | @classmethod 176 | def get_labels(cls, data_sign): 177 | """gets the list of labels for this data set.""" 178 | if data_sign == "zh_onto": 179 | return ["GPE", "LOC", "PER", "ORG"] 180 | elif data_sign == "zh_msra": 181 | return ["NS", "NR", "NT"] 182 | elif data_sign == "en_onto": 183 | return ["LAW", "EVENT", "CARDINAL", "FAC", "TIME", "DATE", "ORDINAL", "ORG", "QUANTITY", \ 184 | "PERCENT", "WORK_OF_ART", "LOC", "LANGUAGE", "NORP", "MONEY", "PERSON", "GPE", "PRODUCT"] 185 | elif data_sign == "en_conll03": 186 | return ["ORG", "PER", "LOC", "MISC"] 187 | return ["0", "1"] 188 | 189 | 190 | -------------------------------------------------------------------------------- /datasets/mrpc_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: mrpc_dataset.py 5 | # description: 6 | # dataset processor for semantic textual similarity task MRPC 7 | # train: 3669, dev: 1726, test: 1726 8 | 9 | from collections import namedtuple 10 | from typing import Dict, Optional, List, Union 11 | 12 | import torch 13 | from torch.utils.data import TensorDataset 14 | from torch.utils.data.dataset import Dataset 15 | 16 | from transformers import BertTokenizer 17 | from datasets.mrpc_processor import MRPCProcessor, MrpcDataExample 18 | 19 | MrpcDataFeature = namedtuple("MrpcDataFeature", ["input_ids", "attention_mask", "token_type_ids", "label"]) 20 | 21 | 22 | class MRPCDataset(Dataset): 23 | def __init__(self, 24 | args, 25 | tokenizer: BertTokenizer, 26 | mode: Optional[str] = "train", 27 | cache_dir: Optional[str] = None, 28 | debug: Optional[bool] = False): 29 | 30 | self.args = args 31 | self.tokenizer = tokenizer 32 | self.mode = mode 33 | self.processor = MRPCProcessor(self.args.data_dir) 34 | self.debug = debug 35 | 
self.cache_dir = cache_dir 36 | self.max_seq_length = self.args.max_seq_length 37 | 38 | if self.mode == "dev": 39 | self.examples = self.processor.get_dev_examples() 40 | elif self.mode == "test": 41 | self.examples = self.processor.get_test_examples() 42 | else: 43 | self.examples = self.processor.get_train_examples() 44 | 45 | self.features, self.dataset = mrpc_convert_examples_to_features( 46 | examples=self.examples, 47 | tokenizer=tokenizer, 48 | max_length=self.max_seq_length, 49 | label_list=MRPCProcessor.get_labels(), 50 | is_training= mode == "train",) 51 | 52 | def __len__(self): 53 | return len(self.features) 54 | 55 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 56 | # convert to Tensors and build dataset 57 | feature = self.features[i] 58 | 59 | input_ids = torch.tensor(feature.input_ids, dtype=torch.long) 60 | attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long) 61 | token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long) 62 | label = torch.tensor(feature.label, dtype=torch.long) 63 | 64 | inputs = { 65 | "input_ids": input_ids, 66 | "token_type_ids": token_type_ids, 67 | "attention_mask": attention_mask, 68 | "label": label 69 | } 70 | 71 | return inputs 72 | 73 | 74 | def mrpc_convert_examples_to_features(examples: Union[List[MrpcDataExample]], 75 | tokenizer: BertTokenizer, 76 | max_length: int = 256, 77 | label_list: Union[List[str]] = None, 78 | is_training: bool = False,): 79 | """ 80 | Description: 81 | GLUE Version 82 | - test.tsv -> index #1 ID #2 ID #1 String #2 String 83 | - train/dev.tsv -> Quality #1 ID #2 ID #1 String #2 String 84 | """ 85 | 86 | label_map = {label: i for i, label in enumerate(label_list)} 87 | 88 | labels = [label_map[example.label] for example in examples] 89 | batch_encoding = tokenizer( 90 | [(example.text_a, example.text_b) for example in examples], 91 | max_length=max_length, padding="max_length", truncation=True 92 | ) 93 | 94 | features = [] 95 | for i in range(len(examples)): 96 | inputs = {k: batch_encoding[k][i] for k in batch_encoding} 97 | 98 | feature = MrpcDataFeature(**inputs, label=labels[i]) 99 | features.append(feature) 100 | 101 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 102 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 103 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 104 | all_label = torch.tensor([f.label for f in features], dtype=torch.long) 105 | 106 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label) 107 | return features, dataset 108 | 109 | -------------------------------------------------------------------------------- /datasets/mrpc_processor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file name: mrpc_processor.py 5 | # description: 6 | # code for loading data samples from files. 7 | 8 | import os 9 | import csv 10 | from collections import namedtuple 11 | 12 | MrpcDataExample = namedtuple("DataExample", ["guid", "text_a", "text_b", "label"]) 13 | 14 | 15 | class MRPCProcessor: 16 | """ 17 | Processor for the MRPC data set. 
18 | """ 19 | def __init__(self, data_dir): 20 | self.data_dir = data_dir 21 | self.train_file = os.path.join(data_dir, "train.tsv") 22 | self.dev_file = os.path.join(data_dir, "dev.tsv") 23 | # TODO: add test.tsv processing 24 | self.test_file = os.path.join(data_dir, "msr_paraphrase_test.txt") 25 | 26 | def get_train_examples(self, ): 27 | return self._create_examples(self._read_tsv(self.train_file), "train") 28 | 29 | def get_dev_examples(self, ): 30 | return self._create_examples(self._read_tsv(self.dev_file), "dev") 31 | 32 | def get_test_examples(self, ): 33 | return self._create_examples(self._read_tsv(self.test_file), "test") 34 | 35 | def _create_examples(self, lines, set_type): 36 | """create examples for the train/dev/test datasets""" 37 | examples = [] 38 | for idx, line in enumerate(lines): 39 | if idx == 0: 40 | continue 41 | guid = "%s-%s" % (set_type, idx) 42 | text_a = line[3] 43 | text_b = line[4] 44 | label = line[0] 45 | examples.append(MrpcDataExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 46 | return examples 47 | 48 | def _read_tsv(self, input_file, quotechar=None): 49 | """reads a tab separated value file.""" 50 | with open(input_file, "r", encoding="utf-8") as f: 51 | return list(csv.reader(f, delimiter="\t", quotechar=quotechar)) 52 | 53 | @classmethod 54 | def get_labels(cls, ): 55 | """gets the list of labels for this data set.""" 56 | return ["0", "1"] 57 | 58 | 59 | 60 | if __name__ == "__main__": 61 | num_labels = MRPCProcessor.get_labels() 62 | print(num_labels) -------------------------------------------------------------------------------- /datasets/squad_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: squad_dataset.py 5 | # description: 6 | # dataset class for the squad task. 7 | # NOTICE: 8 | # https://github.com/huggingface/transformers/issues/7735 9 | # fast tokenizers don’t currently work with the QA pipeline. 
10 | 11 | import os 12 | from typing import Dict, Optional 13 | 14 | import torch 15 | from torch.utils.data.dataset import Dataset 16 | 17 | from transformers import AutoTokenizer 18 | from transformers.data.processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor 19 | from transformers.data.processors.squad import squad_convert_examples_to_features 20 | 21 | 22 | class SquadDataset(Dataset): 23 | def __init__(self, 24 | args, 25 | tokenizer: AutoTokenizer, 26 | mode: Optional[str] = "train", 27 | is_language_sensitive: Optional[bool] = False, 28 | cache_dir: Optional[str] = None, 29 | dataset_format: Optional[str] = "pt", 30 | threads: Optional[int] = 1, 31 | debug: Optional[bool] = False,): 32 | 33 | self.args = args 34 | self.tokenizer = tokenizer 35 | self.is_language_sensitive = is_language_sensitive 36 | self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor() 37 | self.mode = mode 38 | self.debug = debug 39 | self.threads = threads 40 | 41 | self.max_seq_length = self.args.max_seq_length 42 | self.doc_stride = self.args.doc_stride 43 | self.max_query_length = self.args.max_query_length 44 | 45 | # dataset format configurations 46 | self.column_names = ["id", "title", "context", "question", "answers" ] 47 | self.question_column_name = "question" if "question" in self.column_names else self.column_names[0] 48 | self.context_column_name = "context" if "context" in self.column_names else self.column_names[1] 49 | self.answer_column_name = "answers" if "answers" in self.column_names else self.column_names[2] 50 | 51 | # Padding side determines if we do (question|context) or (context|question). 52 | self.pad_on_right = tokenizer.padding_side == "right" 53 | 54 | # load data features from cache or dataset file 55 | version_tag = "v2" if args.version_2_with_negative else "v1" 56 | cached_features_file = os.path.join( 57 | cache_dir if cache_dir is not None else args.data_dir, 58 | "cached_{}_{}_{}_{}".format( 59 | mode, 60 | tokenizer.__class__.__name__, 61 | str(args.max_seq_length), 62 | version_tag 63 | ) 64 | ) 65 | 66 | self.cached_data_file = cached_features_file 67 | 68 | # Make sure only the first process in distributed training processes the dataset, 69 | # and the others will use the cache. 70 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 71 | self.old_features = torch.load(cached_features_file) 72 | 73 | # legacy cache files have only features, 74 | # which new cache files will have dataset and examples also. 
75 | self.features = self.old_features["features"] 76 | self.dataset = self.old_features.get("dataset", None) 77 | self.examples = self.old_features.get("examples", None) 78 | 79 | if self.dataset is None or self.examples is None: 80 | raise ValueError 81 | else: 82 | if self.mode == "dev": 83 | self.examples = self.processor.get_dev_examples(args.data_dir) 84 | else: 85 | self.examples = self.processor.get_train_examples(args.data_dir) 86 | 87 | if self.debug: 88 | print(f"DEBUG INFO -> already load {self.mode} data ...") 89 | print(f"DEBUG INFO -> show 2 EXAMPLES ...") 90 | for idx, data_examples in enumerate(self.examples): 91 | # data_examples should be an object of transformers.data.processors.squad.SquadExample 92 | if idx <= 2: 93 | print(f"DEBUG INFO -> {idx}, {data_examples}") 94 | print(f"{idx} qas_id -> {data_examples.qas_id}") 95 | print(f"{idx} question_text -> {data_examples.question_text}") 96 | print(f"{idx} context_text -> {data_examples.context_text}") 97 | print(f"{idx} answer_text -> {data_examples.answer_text}") 98 | print("-*-"*10) 99 | 100 | self.features, self.dataset = squad_convert_examples_to_features( 101 | examples=self.examples, 102 | tokenizer=tokenizer, 103 | max_seq_length=self.max_seq_length, 104 | doc_stride=self.doc_stride, 105 | max_query_length=self.max_query_length, 106 | is_training= mode == "train", 107 | threads=self.threads, 108 | return_dataset=dataset_format, 109 | ) 110 | 111 | torch.save( 112 | {"features": self.features, "dataset": self.dataset, "examples": self.examples}, 113 | cached_features_file,) 114 | 115 | def __len__(self): 116 | return len(self.features) 117 | 118 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 119 | # convert to Tensors and build dataset 120 | feature = self.features[i] 121 | 122 | input_ids = torch.tensor(feature.input_ids, dtype=torch.long) 123 | attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long) 124 | token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long) 125 | # only for "xlnet", "xlm" models. 
126 | # cls_index = torch.tensor(feature.cls_index, dtype=torch.long) 127 | # p_mask = torch.tensor(feature.p_mask, dtype=torch.float) 128 | # is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float) 129 | 130 | inputs = { 131 | "input_ids": input_ids, 132 | "attention_mask": attention_mask, 133 | "token_type_ids": token_type_ids 134 | } 135 | 136 | label_mask = [1] + feature.token_type_ids[1:] 137 | 138 | start_labels = torch.tensor(feature.start_position, dtype=torch.long) 139 | end_labels = torch.tensor(feature.end_position, dtype=torch.long) 140 | label_mask = torch.tensor(label_mask, dtype=torch.long) 141 | 142 | inputs.update({"start_labels": start_labels, "end_labels": end_labels, "label_mask": label_mask}) 143 | 144 | if self.mode != "train": 145 | inputs.update({"unique_id": feature.unique_id}) 146 | 147 | return inputs 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /datasets/tnews_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: tnews_dataset.py 5 | # Data Example: 6 | # {"label": "113", "label_desc": "news_world", "sentence": "日本虎视眈眈“武力夺岛”, 美军向俄后院开火,普京终不再忍!", "keywords": "普京,北方四岛,安倍,俄罗斯,黑海"} 7 | 8 | import os 9 | import json 10 | import torch 11 | from torch.utils.data import Dataset 12 | from tokenizers import BertWordPieceTokenizer 13 | 14 | class TNewsDataset(Dataset): 15 | def __init__(self, prefix: str = "train", data_dir: str = "", tokenizer: BertWordPieceTokenizer = None, max_length: int = 512): 16 | super().__init__() 17 | self.data_prefix = prefix 18 | self.max_length = max_length 19 | data_file = os.path.join(data_dir, f"{prefix}.json") 20 | with open(data_file, "r", encoding="utf-8") as f: 21 | data_items = f.readlines() 22 | self.data_items = data_items 23 | self.tokenizer = tokenizer 24 | self.labels2id = {value: key for key, value in enumerate(TNewsDataset.get_labels())} 25 | 26 | def __len__(self): 27 | return len(self.data_items) 28 | 29 | def __getitem__(self, idx): 30 | """ 31 | Description: 32 | for single sentence task, BERTWordPieceTokenizer will [CLS}++[SEP] 33 | Returns: 34 | input_token_ids, token_type_ids, attention_mask, label_id 35 | """ 36 | data_item = self.data_items[idx] 37 | data_item = json.loads(data_item) 38 | label, sentence = data_item["label"], data_item["sentence"] 39 | label_id = self.labels2id[label] 40 | sentence = sentence[: self.max_length-3] 41 | tokenizer_output = self.tokenizer.encode(sentence) 42 | 43 | tokens = tokenizer_output.ids + (self.max_length - len(tokenizer_output.ids)) * [0] 44 | token_type_ids = tokenizer_output.type_ids + (self.max_length - len(tokenizer_output.type_ids)) * [0] 45 | attention_mask = tokenizer_output.attention_mask + (self.max_length - len(tokenizer_output.attention_mask)) * [0] 46 | 47 | input_token_ids = torch.tensor(tokens, dtype=torch.long) 48 | token_type_ids = torch.tensor(token_type_ids, dtype=torch.long) 49 | attention_mask = torch.tensor(attention_mask, dtype=torch.long) 50 | label_id = torch.tensor(label_id, dtype=torch.long) 51 | 52 | return input_token_ids, token_type_ids, attention_mask, label_id 53 | 54 | @classmethod 55 | def get_labels(cls, ): 56 | return ['100', '101', '102', '103', '104', '106', '107', '108', '109', '110', '112', '113', '114', '115', '116'] 57 | 58 | 59 | -------------------------------------------------------------------------------- /datasets/truncate_dataset.py: 
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class TruncateDataset(Dataset): 7 | """Truncate dataset to certain num""" 8 | def __init__(self, dataset: Dataset, max_num: int = 100): 9 | self.dataset = dataset 10 | self.max_num = min(max_num, len(self.dataset)) 11 | 12 | def __len__(self): 13 | return self.max_num 14 | 15 | def __getitem__(self, item): 16 | return self.dataset[item] 17 | 18 | def __getattr__(self, item): 19 | """other dataset func""" 20 | return getattr(self.dataset, item) 21 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .dice_loss import DiceLoss 2 | from .focal_loss import FocalLoss -------------------------------------------------------------------------------- /loss/dice_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: dice_loss.py 5 | # description: 6 | # implementation of dice loss for NLP tasks. 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch import Tensor 12 | from typing import Optional 13 | 14 | 15 | class DiceLoss(nn.Module): 16 | """ 17 | Dice coefficient for short, is an F1-oriented statistic used to gauge the similarity of two sets. 18 | Given two sets A and B, the vanilla dice coefficient between them is given as follows: 19 | Dice(A, B) = 2 * True_Positive / (2 * True_Positive + False_Positive + False_Negative) 20 | = 2 * |A and B| / (|A| + |B|) 21 | 22 | Math Function: 23 | U-NET: https://arxiv.org/abs/1505.04597.pdf 24 | dice_loss(p, y) = 1 - numerator / denominator 25 | numerator = 2 * \sum_{1}^{t} p_i * y_i + smooth 26 | denominator = \sum_{1}^{t} p_i + \sum_{1} ^{t} y_i + smooth 27 | if square_denominator is True, the denominator is \sum_{1}^{t} (p_i ** 2) + \sum_{1} ^{t} (y_i ** 2) + smooth 28 | V-NET: https://arxiv.org/abs/1606.04797.pdf 29 | Args: 30 | smooth (float, optional): a manual smooth value for numerator and denominator. 31 | square_denominator (bool, optional): [True, False], specifies whether to square the denominator in the loss function. 32 | with_logits (bool, optional): [True, False], specifies whether the input tensor is normalized by Sigmoid/Softmax funcs. 33 | ohem_ratio: max ratio of positive/negative, defautls to 0.0, which means no ohem. 34 | alpha: dsc alpha 35 | Shape: 36 | - input: (*) 37 | - target: (*) 38 | - mask: (*) 0,1 mask for the input sequence. 
39 | - Output: Scalar loss 40 | Examples: 41 | >>> loss = DiceLoss(with_logits=True, ohem_ratio=0.1) 42 | >>> input = torch.FloatTensor([2, 1, 2, 2, 1]) 43 | >>> input.requires_grad=True 44 | >>> target = torch.LongTensor([0, 1, 0, 0, 0]) 45 | >>> output = loss(input, target) 46 | >>> output.backward() 47 | """ 48 | def __init__(self, 49 | smooth: Optional[float] = 1e-4, 50 | square_denominator: Optional[bool] = False, 51 | with_logits: Optional[bool] = True, 52 | ohem_ratio: float = 0.0, 53 | alpha: float = 0.0, 54 | reduction: Optional[str] = "mean", 55 | index_label_position=True) -> None: 56 | super(DiceLoss, self).__init__() 57 | 58 | self.reduction = reduction 59 | self.with_logits = with_logits 60 | self.smooth = smooth 61 | self.square_denominator = square_denominator 62 | self.ohem_ratio = ohem_ratio 63 | self.alpha = alpha 64 | self.index_label_position = index_label_position 65 | 66 | def forward(self, input: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor: 67 | logits_size = input.shape[-1] 68 | 69 | if logits_size != 1: 70 | loss = self._multiple_class(input, target, logits_size, mask=mask) 71 | else: 72 | loss = self._binary_class(input, target, mask=mask) 73 | 74 | if self.reduction == "mean": 75 | return loss.mean() 76 | if self.reduction == "sum": 77 | return loss.sum() 78 | return loss 79 | 80 | def _compute_dice_loss(self, flat_input, flat_target): 81 | flat_input = ((1 - flat_input) ** self.alpha) * flat_input 82 | interection = torch.sum(flat_input * flat_target, -1) 83 | if not self.square_denominator: 84 | loss = 1 - ((2 * interection + self.smooth) / 85 | (flat_input.sum() + flat_target.sum() + self.smooth)) 86 | else: 87 | loss = 1 - ((2 * interection + self.smooth) / 88 | (torch.sum(torch.square(flat_input, ), -1) + torch.sum(torch.square(flat_target), -1) + self.smooth)) 89 | 90 | return loss 91 | 92 | def _multiple_class(self, input, target, logits_size, mask=None): 93 | flat_input = input 94 | flat_target = F.one_hot(target, num_classes=logits_size).float() if self.index_label_position else target.float() 95 | flat_input = torch.nn.Softmax(dim=1)(flat_input) if self.with_logits else flat_input 96 | 97 | if mask is not None: 98 | mask = mask.float() 99 | flat_input = flat_input * mask 100 | flat_target = flat_target * mask 101 | else: 102 | mask = torch.ones_like(target) 103 | 104 | loss = None 105 | if self.ohem_ratio > 0 : 106 | mask_neg = torch.logical_not(mask) 107 | for label_idx in range(logits_size): 108 | pos_example = target == label_idx 109 | neg_example = target != label_idx 110 | 111 | pos_num = pos_example.sum() 112 | neg_num = mask.sum() - (pos_num - (mask_neg & pos_example).sum()) 113 | keep_num = min(int(pos_num * self.ohem_ratio / logits_size), neg_num) 114 | 115 | if keep_num > 0: 116 | neg_scores = torch.masked_select(flat_input, neg_example.view(-1, 1).bool()).view(-1, logits_size) 117 | neg_scores_idx = neg_scores[:, label_idx] 118 | neg_scores_sort, _ = torch.sort(neg_scores_idx, ) 119 | threshold = neg_scores_sort[-keep_num + 1] 120 | cond = (torch.argmax(flat_input, dim=1) == label_idx & flat_input[:, label_idx] >= threshold) | pos_example.view(-1) 121 | ohem_mask_idx = torch.where(cond, 1, 0) 122 | 123 | flat_input_idx = flat_input[:, label_idx] 124 | flat_target_idx = flat_target[:, label_idx] 125 | 126 | flat_input_idx = flat_input_idx * ohem_mask_idx 127 | flat_target_idx = flat_target_idx * ohem_mask_idx 128 | else: 129 | flat_input_idx = flat_input[:, label_idx] 130 | flat_target_idx = flat_target[:, label_idx] 131 
| 132 | loss_idx = self._compute_dice_loss(flat_input_idx.view(-1, 1), flat_target_idx.view(-1, 1)) 133 | if loss is None: 134 | loss = loss_idx 135 | else: 136 | loss += loss_idx 137 | return loss 138 | 139 | else: 140 | for label_idx in range(logits_size): 141 | pos_example = target == label_idx 142 | flat_input_idx = flat_input[:, label_idx] 143 | flat_target_idx = flat_target[:, label_idx] 144 | 145 | loss_idx = self._compute_dice_loss(flat_input_idx.view(-1, 1), flat_target_idx.view(-1, 1)) 146 | if loss is None: 147 | loss = loss_idx 148 | else: 149 | loss += loss_idx 150 | return loss 151 | 152 | def _binary_class(self, input, target, mask=None): 153 | flat_input = input.view(-1) 154 | flat_target = target.view(-1).float() 155 | flat_input = torch.sigmoid(flat_input) if self.with_logits else flat_input 156 | 157 | if mask is not None: 158 | mask = mask.float() 159 | flat_input = flat_input * mask 160 | flat_target = flat_target * mask 161 | else: 162 | mask = torch.ones_like(target) 163 | 164 | if self.ohem_ratio > 0: 165 | pos_example = target > 0.5 166 | neg_example = target <= 0.5 167 | mask_neg_num = mask <= 0.5 168 | 169 | pos_num = pos_example.sum() - (pos_example & mask_neg_num).sum() 170 | neg_num = neg_example.sum() 171 | keep_num = min(int(pos_num * self.ohem_ratio), neg_num) 172 | 173 | neg_scores = torch.masked_select(flat_input, neg_example.bool()) 174 | neg_scores_sort, _ = torch.sort(neg_scores, ) 175 | threshold = neg_scores_sort[-keep_num+1] 176 | cond = (flat_input > threshold) | pos_example.view(-1) 177 | ohem_mask = torch.where(cond, 1, 0) 178 | flat_input = flat_input * ohem_mask 179 | flat_target = flat_target * ohem_mask 180 | 181 | return self._compute_dice_loss(flat_input, flat_target) 182 | 183 | def __str__(self): 184 | return f"Dice Loss smooth:{self.smooth}, ohem: {self.ohem_ratio}, alpha: {self.alpha}" 185 | 186 | def __repr__(self): 187 | return str(self) 188 | -------------------------------------------------------------------------------- /loss/focal_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | from typing import List 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class FocalLoss(nn.Module): 13 | """ 14 | Focal loss(https://arxiv.org/pdf/1708.02002.pdf) 15 | Shape: 16 | - input: (N, C) 17 | - target: (N) 18 | - Output: Scalar loss 19 | Examples: 20 | >>> loss = FocalLoss(gamma=2, alpha=[1.0]*7) 21 | >>> input = torch.randn(3, 7, requires_grad=True) 22 | >>> target = torch.empty(3, dtype=torch.long).random_(7) 23 | >>> output = loss(input, target) 24 | >>> output.backward() 25 | """ 26 | def __init__(self, gamma=0, alpha: List[float] = None, reduction="none"): 27 | super(FocalLoss, self).__init__() 28 | self.gamma = gamma 29 | self.alpha = alpha 30 | if alpha is not None: 31 | self.alpha = torch.FloatTensor(alpha) 32 | self.reduction = reduction 33 | 34 | def forward(self, input, target): 35 | # [N, 1] 36 | target = target.unsqueeze(-1) 37 | # [N, C] 38 | pt = F.softmax(input, dim=-1) 39 | logpt = F.log_softmax(input, dim=-1) 40 | # [N] 41 | pt = pt.gather(1, target).squeeze(-1) 42 | logpt = logpt.gather(1, target).squeeze(-1) 43 | 44 | if self.alpha is not None: 45 | # [N] at[i] = alpha[target[i]] 46 | at = self.alpha.gather(0, target.squeeze(-1)) 47 | logpt = logpt * at 48 | 49 | loss = -1 * (1 - pt) ** self.gamma * logpt 50 | if self.reduction == "none": 51 | return loss 52 | if 
self.reduction == "mean": 53 | return loss.mean() 54 | return loss.sum() 55 | 56 | @staticmethod 57 | def convert_binary_pred_to_two_dimension(x, is_logits=True): 58 | """ 59 | Args: 60 | x: (*): (log) prob of some instance has label 1 61 | is_logits: if True, x represents log prob; otherwhise presents prob 62 | Returns: 63 | y: (*, 2), where y[*, 1] == log prob of some instance has label 0, 64 | y[*, 0] = log prob of some instance has label 1 65 | """ 66 | probs = torch.sigmoid(x) if is_logits else x 67 | probs = probs.unsqueeze(-1) 68 | probs = torch.cat([1-probs, probs], dim=-1) 69 | logprob = torch.log(probs+1e-4) # 1e-4 to prevent being rounded to 0 in fp16 70 | return logprob 71 | 72 | def __str__(self): 73 | return f"Focal Loss gamma:{self.gamma}" 74 | 75 | def __repr__(self): 76 | return str(self) 77 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/classification_acc_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: classification_acc_f1.py 5 | # description: 6 | # compute acc & f1 scores for classification tasks. 7 | 8 | import torch 9 | from pytorch_lightning.metrics.metric import TensorMetric 10 | from metrics.functional.cls_acc_f1 import collect_confusion_matrix, compute_precision_recall_f1_scores 11 | 12 | 13 | class ClassificationF1Metric(TensorMetric): 14 | """ 15 | compute acc and f1 scores for text classification tasks. 16 | """ 17 | def __init__(self, reduce_group=None, reduce_op=None, num_classes=2, f1_type="micro"): 18 | super(ClassificationF1Metric, self).__init__(name="classification_f1_metric", reduce_group=reduce_group, reduce_op=reduce_op) 19 | self.num_classes = num_classes 20 | self.f1_type = f1_type 21 | 22 | def forward(self, pred_labels, gold_labels): 23 | """ 24 | Description: 25 | collect confusion matrix for one batch. 26 | Args: 27 | pred_labels: a tensor in shape of [eval_batch_size] 28 | gold_labels: a tensor in shape if [eval_batch_size] 29 | Returns: 30 | a tensor of [true_positive, false_positive, true_negative, false_negative] 31 | """ 32 | confusion_matrix = collect_confusion_matrix(pred_labels, gold_labels, num_classes=self.num_classes) 33 | 34 | return confusion_matrix 35 | 36 | 37 | def compute_f1(self, all_confusion_matrix): 38 | """ 39 | Args: 40 | true_positive, false_positive, true_negative, false_negative in ALL CORPUS. 41 | Returns: 42 | four tensors -> acc, precision, recall, f1 43 | """ 44 | precision, recall, f1 = compute_precision_recall_f1_scores(all_confusion_matrix, num_classes=self.num_classes, f1_type=self.f1_type) 45 | precision, recall, f1 = torch.tensor(precision, dtype=torch.float), torch.tensor(recall, dtype=torch.float), torch.tensor(f1, dtype=torch.float) 46 | # The metric you returned including Precision, Recall, F1 (e.g., 0.91638) must be a `torch.Tensor` instance. 
47 | return precision, recall, f1 -------------------------------------------------------------------------------- /metrics/functional/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/metrics/functional/__init__.py -------------------------------------------------------------------------------- /metrics/functional/cls_acc_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: cls_acc_f1.py 5 | # description: 6 | # compute acc and f1 scores for text classification task. 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def collect_confusion_matrix(y_pred_labels, y_gold_labels, num_classes=2): 13 | """ 14 | compute accuracy and f1 scores for text classification task. 15 | Args: 16 | pred_labels: [batch_size] index of labels. 17 | gold_labels: [batch_size] index of labels. 18 | Returns: 19 | A LongTensor composed by [true_positive, false_positive, false_negative] 20 | """ 21 | if num_classes <= 0: 22 | raise ValueError 23 | 24 | if num_classes == 1 or num_classes == 2: 25 | num_classes = 1 26 | y_true_onehot = y_gold_labels.bool() 27 | y_pred_onehot = y_pred_labels.bool() 28 | else: 29 | y_true_onehot = F.one_hot(y_gold_labels, num_classes=num_classes) 30 | y_pred_onehot = F.one_hot(y_pred_labels, num_classes=num_classes) 31 | 32 | if num_classes == 1: 33 | y_true_onehot = y_true_onehot.bool() 34 | y_pred_onehot = y_pred_onehot.bool() 35 | 36 | true_positive = (y_true_onehot & y_pred_onehot).long().sum() 37 | false_positive = (y_pred_onehot & ~ y_true_onehot).long().sum() 38 | false_negative = (~ y_pred_onehot & y_true_onehot).long().sum() 39 | 40 | stack_confusion_matrix = torch.stack([true_positive, false_positive, false_negative]) 41 | return stack_confusion_matrix 42 | 43 | multi_label_confusion_matrix = [] 44 | 45 | for idx in range(num_classes): 46 | index_item = torch.tensor([idx], dtype=torch.long).cuda() 47 | y_true_item_onehot = torch.index_select(y_true_onehot, 1, index_item) 48 | y_pred_item_onehot = torch.index_select(y_pred_onehot, 1, index_item) 49 | 50 | true_sum_item = torch.sum(y_true_item_onehot) 51 | pred_sum_item = torch.sum(y_pred_item_onehot) 52 | 53 | true_positive_item = torch.sum(y_true_item_onehot.multiply(y_pred_item_onehot)) 54 | 55 | false_positive_item = pred_sum_item - true_positive_item 56 | false_negative_item = true_sum_item - true_positive_item 57 | 58 | confusion_matrix_item = torch.tensor([true_positive_item, false_positive_item, false_negative_item], 59 | dtype=torch.long) 60 | 61 | multi_label_confusion_matrix.append(confusion_matrix_item) 62 | 63 | stack_confusion_matrix = torch.stack(multi_label_confusion_matrix, dim=0) 64 | 65 | return stack_confusion_matrix 66 | 67 | def compute_precision_recall_f1_scores(confusion_matrix, num_classes=2, f1_type="micro"): 68 | """ 69 | compute precision, recall and f1 scores. 
70 | Description: 71 | f1: 2 * precision * recall / (precision + recall) 72 | - precision = true_positive / true_positive + false_positive 73 | - recall = true_positive / true_positive + false_negative 74 | Returns: 75 | precision, recall, f1 76 | """ 77 | 78 | if num_classes == 2 or num_classes == 1: 79 | confusion_matrix = confusion_matrix.to("cpu").numpy().tolist() 80 | true_positive, false_positive, false_negative = tuple(confusion_matrix) 81 | precision = true_positive / (true_positive + false_positive + 1e-10) 82 | recall = true_positive / (true_positive + false_negative + 1e-10) 83 | f1 = 2 * precision * recall / (precision + recall + 1e-10) 84 | precision, recall, f1 = round(precision, 5), round(recall, 5), round(f1, 5) 85 | return precision, recall, f1 86 | 87 | if f1_type == "micro": 88 | precision, recall, f1 = micro_precision_recall_f1(confusion_matrix, num_classes) 89 | elif f1_type == "macro": 90 | precision, recall, f1 = macro_precision_recall_f1(confusion_matrix) 91 | else: 92 | raise ValueError 93 | 94 | return precision, recall, f1 95 | 96 | 97 | def micro_precision_recall_f1(all_confusion_matrix, num_classes): 98 | precision_lst = [] 99 | recall_lst = [] 100 | 101 | all_confusion_matrix_lst = all_confusion_matrix.to("cpu").numpy().tolist() 102 | for idx in range(num_classes): 103 | matrix_item = all_confusion_matrix_lst[idx] 104 | true_positive_item, false_positive_item, false_negative_item = tuple(matrix_item) 105 | 106 | precision_item = true_positive_item / (true_positive_item + false_positive_item + 1e-10) 107 | recall_item = true_positive_item / (true_positive_item + false_negative_item + 1e-10) 108 | 109 | precision_lst.append(precision_item) 110 | recall_lst.append(recall_item) 111 | 112 | avg_precision = sum(precision_lst) / num_classes 113 | avg_recall = sum(recall_lst) / num_classes 114 | avg_f1 = 2 * avg_recall * avg_precision / (avg_recall + avg_precision + 1e-10) 115 | 116 | avg_precision, avg_recall, avg_f1 = round(avg_precision, 5), round(avg_recall, 5), round(avg_f1, 5) 117 | 118 | return avg_precision, avg_recall, avg_f1 119 | 120 | 121 | def macro_precision_recall_f1(all_confusion_matrix, ): 122 | confusion_matrix = torch.sum(all_confusion_matrix, 1, keepdim=False) 123 | confusion_matrix_lst = confusion_matrix.to("cpu").numpy().tolist() 124 | true_positive, false_positive, false_negative = tuple(confusion_matrix_lst) 125 | 126 | precision = true_positive / (true_positive + false_positive + 1e-10) 127 | recall = true_positive / (true_positive + false_negative + 1e-10) 128 | f1 = 2 * precision * recall / (precision + recall + 1e-10) 129 | 130 | precision, recall, f1 = round(precision, 5), round(recall, 5), round(f1, 5) 131 | 132 | return precision, recall, f1 -------------------------------------------------------------------------------- /metrics/functional/ner_span_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: ner_span_f1.py 5 | 6 | 7 | import torch 8 | from typing import Tuple, List 9 | 10 | 11 | class Tag(object): 12 | def __init__(self, term, tag, begin, end): 13 | self.term = term 14 | self.tag = tag 15 | self.begin = begin 16 | self.end = end 17 | 18 | def to_tuple(self): 19 | return tuple([self.term, self.begin, self.end]) 20 | 21 | def __str__(self): 22 | return str({key: value for key, value in self.__dict__.items()}) 23 | 24 | def __repr__(self): 25 | return str({key: value for key, value in self.__dict__.items()}) 26 | 27 | 28 | def 
bmes_decode(char_label_list: List[Tuple[str, str]]) -> List[Tag]: 29 | """ 30 | decode inputs to tags 31 | Args: 32 | char_label_list: list of tuple (word, bmes-tag) 33 | Returns: 34 | tags 35 | Examples: 36 | >>> x = [("Hi", "O"), ("Beijing", "S-LOC")] 37 | >>> bmes_decode(x) 38 | [{'term': 'Beijing', 'tag': 'LOC', 'begin': 1, 'end': 2}] 39 | [(1, 2), (4, 6), (7, 8), (9, 10)] 40 | """ 41 | idx = 0 42 | length = len(char_label_list) 43 | tags = [] 44 | while idx < length: 45 | term, label = char_label_list[idx] 46 | current_label = label[0] 47 | 48 | if current_label == "O": 49 | idx += 1 50 | continue 51 | if current_label == "S": 52 | tags.append(Tag(term, label[2:], idx, idx + 1)) 53 | idx += 1 54 | continue 55 | if current_label == "B": 56 | end = idx + 1 57 | while end + 1 < length and char_label_list[end][1][0] == "M": 58 | end += 1 59 | 60 | if end == len(char_label_list): 61 | entity = "".join(char_label_list[i][0] for i in range(idx, end)) 62 | tags.append(Tag(entity, label[2:], idx, end)) 63 | idx = end 64 | continue 65 | 66 | if char_label_list[end][1][0] == "E": # end with E 67 | entity = "".join(char_label_list[i][0] for i in range(idx, end + 1)) 68 | tags.append(Tag(entity, label[2:], idx, end + 1)) 69 | idx = end + 1 70 | else: # end with M/B 71 | entity = "".join(char_label_list[i][0] for i in range(idx, end)) 72 | tags.append(Tag(entity, label[2:], idx, end)) 73 | idx = end 74 | continue 75 | else: 76 | idx += 1 77 | continue 78 | return tags 79 | 80 | 81 | def bmes_decode_flat_query_span_f1(start_preds, end_preds, match_logits, start_end_label_mask, start_labels, end_labels, match_labels, answerable_pred=None): 82 | sum_true_positive, sum_false_positive, sum_false_negative = 0, 0, 0 83 | start_preds, end_preds, match_logits, start_end_label_mask = start_preds.to("cpu").numpy().tolist(), end_preds.to("cpu").numpy().tolist(), match_logits.to("cpu").numpy().tolist(), start_end_label_mask.to("cpu").numpy().tolist() 84 | start_labels, end_labels, match_labels = start_labels.to("cpu").numpy().tolist(), end_labels.to("cpu").numpy().tolist(), match_labels.to("cpu").numpy().tolist() 85 | batch_size, seq_len = len(start_labels), len(start_labels[0]) 86 | 87 | if answerable_pred is not None: 88 | answerable_pred = answerable_pred.to("cpu").numpy().tolist() 89 | else: 90 | answerable_pred = [1] * batch_size 91 | 92 | for start_pred_item, end_pred_item, match_logits_item, start_end_label_mask_item, start_label_item, end_label_item, match_label_item, answerable_item in \ 93 | zip(start_preds, end_preds, match_logits, start_end_label_mask, start_labels, end_labels, match_labels, answerable_pred): 94 | if answerable_item == 0: 95 | start_pred_item = [0] * len(start_pred_item) 96 | end_pred_item = [0] * len(end_pred_item) 97 | 98 | pred_entity_lst = extract_flat_spans(start_pred_item, end_pred_item, match_logits_item, start_end_label_mask_item,) 99 | gold_entity_lst = extract_flat_spans(start_label_item, end_label_item, match_label_item, start_end_label_mask_item) 100 | 101 | true_positive_item, false_positive_item, false_negative_item = count_confusion_matrix(pred_entity_lst, gold_entity_lst) 102 | sum_true_positive += true_positive_item 103 | sum_false_negative += false_negative_item 104 | sum_false_positive += false_positive_item 105 | 106 | batch_confusion_matrix = torch.tensor([sum_true_positive, sum_false_positive, sum_false_negative], dtype=torch.long) 107 | return batch_confusion_matrix 108 | 109 | def count_confusion_matrix(pred_entities, gold_entities): 110 | true_positive, 
false_positive, false_negative = 0, 0, 0 111 | for span_item in pred_entities: 112 | if span_item in gold_entities: 113 | true_positive += 1 114 | gold_entities.remove(span_item) 115 | else: 116 | false_positive += 1 117 | 118 | # these entities are not predicted. 119 | for span_item in gold_entities: 120 | false_negative += 1 121 | 122 | return true_positive, false_positive, false_negative 123 | 124 | def extract_flat_spans(start_pred, end_pred, match_pred, label_mask): 125 | """ 126 | Extract flat-ner spans from start/end/match logits 127 | Args: 128 | start_pred: [seq_len], 1/True for start, 0/False for non-start 129 | end_pred: [seq_len], 1/True for end, 0/False for non-end 130 | match_pred: [seq_len, seq_len], 1/True for match, 0/False for non-match 131 | label_mask: [seq_len], 1 for valid boundary. 132 | Returns: 133 | tags: list of tuple (start, end) 134 | Examples: 135 | >>> start_pred = [0, 1] 136 | >>> end_pred = [0, 1] 137 | >>> match_pred = [[0, 0], [0, 1]] 138 | >>> label_mask = [1, 1] 139 | >>> extract_flat_spans(start_pred, end_pred, match_pred, label_mask) 140 | """ 141 | pseudo_tag = "TAG" 142 | pseudo_input = "a" 143 | 144 | bmes_labels = ["O"] * len(start_pred) 145 | start_positions = [idx for idx, tmp in enumerate(start_pred) if tmp and label_mask[idx]] 146 | end_positions = [idx for idx, tmp in enumerate(end_pred) if tmp and label_mask[idx]] 147 | 148 | for start_item in start_positions: 149 | bmes_labels[start_item] = f"B-{pseudo_tag}" 150 | for end_item in end_positions: 151 | if end_item in start_positions: 152 | bmes_labels[end_item] = f"B-{pseudo_tag}" 153 | # bmes_labels[end_item] = f"S-{pseudo_tag}" 154 | else: 155 | bmes_labels[end_item] = f"E-{pseudo_tag}" 156 | 157 | for tmp_start in start_positions: 158 | tmp_end = [tmp for tmp in end_positions if tmp >= tmp_start] 159 | if len(tmp_end) == 0: 160 | continue 161 | else: 162 | tmp_end = min(tmp_end) 163 | # if match_pred[tmp_start][tmp_end]: 164 | if tmp_start != tmp_end: 165 | for i in range(tmp_start+1, tmp_end): 166 | bmes_labels[i] = f"M-{pseudo_tag}" 167 | else: 168 | bmes_labels[tmp_end] = f"S-{pseudo_tag}" 169 | 170 | tags = bmes_decode([(pseudo_input, label) for label in bmes_labels]) 171 | 172 | return [(tag.begin, tag.end) for tag in tags] 173 | 174 | 175 | def remove_overlap(spans): 176 | """ 177 | remove overlapped spans greedily for flat-ner 178 | Args: 179 | spans: list of tuple (start, end), which means [start, end] is a ner-span 180 | Returns: 181 | spans without overlap 182 | """ 183 | output = [] 184 | occupied = set() 185 | for start, end in spans: 186 | if any(x in occupied for x in range(start, end+1)): 187 | continue 188 | output.append((start, end)) 189 | for x in range(start, end + 1): 190 | occupied.add(x) 191 | return output 192 | -------------------------------------------------------------------------------- /metrics/functional/squad/README.md: -------------------------------------------------------------------------------- 1 | ## Evaluate the saved models using the SQuAD2.0 official evaluation. 
2 | - To run the evaluation, use `python3 evaluate_v2.py <data_file> <pred_file>` 3 | - A sample prediction file (on Dev 2.0) can be found at `./tests/data/sample_prediction_v2.0.json` -------------------------------------------------------------------------------- /metrics/functional/squad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/metrics/functional/squad/__init__.py -------------------------------------------------------------------------------- /metrics/functional/squad/eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: eval.sh 5 | # description: 6 | # bash script for evaluating prediction files. 7 | # Example: 8 | # bash eval.sh mrc-with-dice-loss/metrics/functional/squad/evaluate_v1.py /data/dev-v1.1.json predictions_10_10.json 9 | 10 | EVAL_SCRIPT=$1 11 | DATA_FILE=$2 12 | PRED_FILE=$3 13 | 14 | python3 ${EVAL_SCRIPT} ${DATA_FILE} ${PRED_FILE} 1 -------------------------------------------------------------------------------- /metrics/functional/squad/evaluate_v1.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. """ 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import json 7 | import sys 8 | 9 | 10 | def normalize_answer(s): 11 | """Lower text and remove punctuation, articles and extra whitespace.""" 12 | def remove_articles(text): 13 | return re.sub(r'\b(a|an|the)\b', ' ', text) 14 | 15 | def white_space_fix(text): 16 | return ' '.join(text.split()) 17 | 18 | def remove_punc(text): 19 | exclude = set(string.punctuation) 20 | return ''.join(ch for ch in text if ch not in exclude) 21 | 22 | def lower(text): 23 | return text.lower() 24 | 25 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 26 | 27 | 28 | def f1_score(prediction, ground_truth): 29 | prediction_tokens = normalize_answer(prediction).split() 30 | ground_truth_tokens = normalize_answer(ground_truth).split() 31 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 32 | num_same = sum(common.values()) 33 | if num_same == 0: 34 | return 0 35 | precision = 1.0 * num_same / len(prediction_tokens) 36 | recall = 1.0 * num_same / len(ground_truth_tokens) 37 | f1 = (2 * precision * recall) / (precision + recall) 38 | return f1 39 | 40 | 41 | def exact_match_score(prediction, ground_truth): 42 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 43 | 44 | 45 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 46 | scores_for_ground_truths = [] 47 | for ground_truth in ground_truths: 48 | score = metric_fn(prediction, ground_truth) 49 | scores_for_ground_truths.append(score) 50 | return max(scores_for_ground_truths) 51 | 52 | 53 | def evaluate(dataset, predictions, stdout=False): 54 | f1 = exact_match = total = 0 55 | for article in dataset: 56 | for paragraph in article['paragraphs']: 57 | for qa in paragraph['qas']: 58 | total += 1 59 | if qa['id'] not in predictions: 60 | message = 'Unanswered question ' + qa['id'] + \ 61 | ' will receive score 0.'
62 | print(message, file=sys.stderr) 63 | continue 64 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 65 | prediction = predictions[qa['id']] 66 | exact_match += metric_max_over_ground_truths( 67 | exact_match_score, prediction, ground_truths) 68 | f1 += metric_max_over_ground_truths( 69 | f1_score, prediction, ground_truths) 70 | 71 | exact_match = 100.0 * exact_match / total 72 | f1 = 100.0 * f1 / total 73 | 74 | result_dict = {"exact_match": exact_match, "f1": f1} 75 | if stdout: 76 | print(json.dumps(result_dict)) 77 | return result_dict 78 | 79 | 80 | def main(data_file, prediction_file, stdout=False): 81 | with open(data_file, "r") as datafile: 82 | dataset_json = json.load(datafile) 83 | dataset = dataset_json["data"] 84 | 85 | with open(prediction_file, "r") as predfile: 86 | predictions = json.load(predfile) 87 | 88 | eval_result = evaluate(dataset, predictions, stdout=stdout) 89 | return eval_result 90 | 91 | 92 | if __name__ == '__main__': 93 | data_file = sys.argv[1] 94 | prediction_file = sys.argv[2] 95 | stdout_sign = sys.argv[3] 96 | if stdout_sign == "1": 97 | stdout = True 98 | else: 99 | stdout = False 100 | main(data_file, prediction_file, stdout=stdout) 101 | -------------------------------------------------------------------------------- /metrics/functional/squad/evaluate_v2.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for SQuAD version 2.0. 2 | 3 | In addition to basic functionality, we also compute additional statistics and 4 | plot precision-recall curves if an additional na_prob.json file is provided. 5 | This file is expected to map question ID's to the model's predicted probability 6 | that a question is unanswerable. 7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import os 12 | import re 13 | import string 14 | import sys 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 19 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 20 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 21 | parser.add_argument('--out-file', '-o', metavar='eval.json', 22 | help='Write accuracy metrics to file (default is stdout).') 23 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 24 | help='Model estimates of probability of no answer.') 25 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 26 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 27 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 28 | help='Save precision-recall curves to directory.') 29 | parser.add_argument('--verbose', '-v', action='store_true') 30 | if len(sys.argv) == 1: 31 | parser.print_help() 32 | sys.exit(1) 33 | return parser.parse_args() 34 | 35 | 36 | def make_qid_to_has_ans(dataset): 37 | qid_to_has_ans = {} 38 | for article in dataset: 39 | for p in article['paragraphs']: 40 | for qa in p['qas']: 41 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 42 | return qid_to_has_ans 43 | 44 | 45 | def normalize_answer(s): 46 | """Lower text and remove punctuation, articles and extra whitespace.""" 47 | 48 | def remove_articles(text): 49 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 50 | return re.sub(regex, ' ', text) 51 | 52 | def white_space_fix(text): 53 | return ' '.join(text.split()) 54 | 55 | def remove_punc(text): 56 | exclude = set(string.punctuation) 
57 | return ''.join(ch for ch in text if ch not in exclude) 58 | 59 | def lower(text): 60 | return text.lower() 61 | 62 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 63 | 64 | 65 | def get_tokens(s): 66 | if not s: return [] 67 | return normalize_answer(s).split() 68 | 69 | 70 | def compute_exact(a_gold, a_pred): 71 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 72 | 73 | 74 | def compute_f1(a_gold, a_pred): 75 | gold_toks = get_tokens(a_gold) 76 | pred_toks = get_tokens(a_pred) 77 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 78 | num_same = sum(common.values()) 79 | if len(gold_toks) == 0 or len(pred_toks) == 0: 80 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 81 | return int(gold_toks == pred_toks) 82 | if num_same == 0: 83 | return 0 84 | precision = 1.0 * num_same / len(pred_toks) 85 | recall = 1.0 * num_same / len(gold_toks) 86 | f1 = (2 * precision * recall) / (precision + recall) 87 | return f1 88 | 89 | 90 | def get_raw_scores(dataset, preds): 91 | exact_scores = {} 92 | f1_scores = {} 93 | for article in dataset: 94 | for p in article['paragraphs']: 95 | for qa in p['qas']: 96 | qid = qa['id'] 97 | gold_answers = [a['text'] for a in qa['answers'] 98 | if normalize_answer(a['text'])] 99 | if not gold_answers: 100 | # For unanswerable questions, only correct answer is empty string 101 | gold_answers = [''] 102 | if qid not in preds: 103 | print('Missing prediction for %s' % qid) 104 | continue 105 | a_pred = preds[qid] 106 | # Take max over all gold answers 107 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) 108 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 109 | return exact_scores, f1_scores 110 | 111 | 112 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 113 | new_scores = {} 114 | for qid, s in scores.items(): 115 | pred_na = na_probs[qid] > na_prob_thresh 116 | if pred_na: 117 | new_scores[qid] = float(not qid_to_has_ans[qid]) 118 | else: 119 | new_scores[qid] = s 120 | return new_scores 121 | 122 | 123 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 124 | if not qid_list: 125 | total = len(exact_scores) 126 | return collections.OrderedDict([ 127 | ('exact', 100.0 * sum(exact_scores.values()) / total), 128 | ('f1', 100.0 * sum(f1_scores.values()) / total), 129 | ('total', total), 130 | ]) 131 | else: 132 | total = len(qid_list) 133 | return collections.OrderedDict([ 134 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 135 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 136 | ('total', total), 137 | ]) 138 | 139 | 140 | def merge_eval(main_eval, new_eval, prefix): 141 | for k in new_eval: 142 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 143 | 144 | 145 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None): 146 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 147 | true_pos = 0.0 148 | cur_p = 1.0 149 | cur_r = 0.0 150 | precisions = [1.0] 151 | recalls = [0.0] 152 | avg_prec = 0.0 153 | for i, qid in enumerate(qid_list): 154 | if qid_to_has_ans[qid]: 155 | true_pos += scores[qid] 156 | cur_p = true_pos / float(i + 1) 157 | cur_r = true_pos / float(num_true_pos) 158 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]: 159 | # i.e., if we can put a threshold after this point 160 | avg_prec += cur_p * (cur_r - recalls[-1]) 161 | precisions.append(cur_p) 162 | 
recalls.append(cur_r) 163 | return {'ap': 100.0 * avg_prec} 164 | 165 | 166 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 167 | qid_to_has_ans, out_image_dir): 168 | if out_image_dir and not os.path.exists(out_image_dir): 169 | os.makedirs(out_image_dir) 170 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 171 | if num_true_pos == 0: 172 | return 173 | pr_exact = make_precision_recall_eval( 174 | exact_raw, na_probs, num_true_pos, qid_to_has_ans, 175 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 176 | title='Precision-Recall curve for Exact Match score') 177 | pr_f1 = make_precision_recall_eval( 178 | f1_raw, na_probs, num_true_pos, qid_to_has_ans, 179 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 180 | title='Precision-Recall curve for F1 score') 181 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 182 | pr_oracle = make_precision_recall_eval( 183 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans, 184 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 185 | title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)') 186 | merge_eval(main_eval, pr_exact, 'pr_exact') 187 | merge_eval(main_eval, pr_f1, 'pr_f1') 188 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 189 | 190 | 191 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 192 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 193 | cur_score = num_no_ans 194 | best_score = cur_score 195 | best_thresh = 0.0 196 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 197 | for i, qid in enumerate(qid_list): 198 | if qid not in scores: continue 199 | if qid_to_has_ans[qid]: 200 | diff = scores[qid] 201 | else: 202 | if preds[qid]: 203 | diff = -1 204 | else: 205 | diff = 0 206 | cur_score += diff 207 | if cur_score > best_score: 208 | best_score = cur_score 209 | best_thresh = na_probs[qid] 210 | return 100.0 * best_score / len(scores), best_thresh 211 | 212 | 213 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 214 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 215 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 216 | main_eval['best_exact'] = best_exact 217 | main_eval['best_exact_thresh'] = exact_thresh 218 | main_eval['best_f1'] = best_f1 219 | main_eval['best_f1_thresh'] = f1_thresh 220 | 221 | 222 | def evaluate(data_file, pred_file, na_prob_file, na_prob_thresh, out_file): 223 | with open(data_file) as f: 224 | dataset_json = json.load(f) 225 | dataset = dataset_json['data'] 226 | with open(pred_file) as f: 227 | preds = json.load(f) 228 | if na_prob_file: 229 | with open(na_prob_file) as f: 230 | na_probs = json.load(f) 231 | else: 232 | na_probs = {k: 0.0 for k in preds} 233 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 234 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 235 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 236 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 237 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, na_prob_thresh) 238 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, na_prob_thresh) 239 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 240 | if has_ans_qids: 241 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 242 | merge_eval(out_eval, has_ans_eval, 'HasAns') 243 | if no_ans_qids: 244 | no_ans_eval = 
make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 245 | merge_eval(out_eval, no_ans_eval, 'NoAns') 246 | if na_prob_file: 247 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 248 | if out_file: 249 | with open(out_file, 'w') as f: 250 | json.dump(out_eval, f) 251 | else: 252 | print(json.dumps(out_eval, indent=2)) 253 | 254 | 255 | if __name__ == '__main__': 256 | args = parse_args() 257 | data_file = args.data_file 258 | pred_file = args.pred_file 259 | na_prob_file = args.na_prob_file 260 | na_prob_thresh = args.na_prob_thresh 261 | out_file = args.out_file 262 | evaluate(data_file, pred_file, na_prob_file, na_prob_thresh, out_file) -------------------------------------------------------------------------------- /metrics/mrc_ner_span_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | from pytorch_lightning.metrics.metric import TensorMetric 6 | from metrics.functional.ner_span_f1 import bmes_decode_flat_query_span_f1 7 | 8 | 9 | class MRCNERSpanF1(TensorMetric): 10 | """ 11 | Query Span F1 12 | Args: 13 | flat: is flat-ner 14 | """ 15 | def __init__(self, reduce_group=None, reduce_op=None, flat=False): 16 | super(MRCNERSpanF1, self).__init__(name="query_span_f1", 17 | reduce_group=reduce_group, 18 | reduce_op=reduce_op) 19 | self.flat = flat 20 | 21 | def forward(self, start_preds, end_preds, match_logits, start_end_label_mask, start_labels, end_labels, match_labels, answerable_pred=None): 22 | return bmes_decode_flat_query_span_f1(start_preds, end_preds, match_logits, start_end_label_mask, start_labels, 23 | end_labels, match_labels, answerable_pred=answerable_pred) 24 | 25 | -------------------------------------------------------------------------------- /metrics/squad_em_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: 5 | # squad_em_f1.py 6 | # description: 7 | # compute exact match / f1-score for SQuAD task. 
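# Worked example of the two metrics wrapped below (illustrative strings, not taken
# from the repository): for a prediction "the cat" against the gold answer
# "the black cat", exact match is 0 because the normalized strings differ, while
# the overlapping tokens are {"the", "cat"}, giving token-level precision = 2/2 = 1.0,
# recall = 2/3, and F1 = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8.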
8 | 9 | import os 10 | import json 11 | from metrics.functional.squad.postprocess_predication import compute_predictions_logits 12 | from metrics.functional.squad.evaluate_v1 import evaluate as evaluate_squad_v1 13 | 14 | class SquadEvalMetric: 15 | def __init__(self, 16 | n_best_size: int = 20, 17 | max_answer_length: int = 20, 18 | do_lower_case: bool = False, 19 | verbose_logging: bool = False, 20 | version_2_with_negative: bool = False, 21 | null_score_diff_threshold: float = 0, 22 | data_dir: str = "", 23 | output_dir: str = ""): 24 | 25 | self.n_best_size = n_best_size 26 | self.max_answer_length = max_answer_length 27 | self.do_lower_case = do_lower_case 28 | self.verbose_logging = verbose_logging 29 | self.version_2_with_negative = version_2_with_negative 30 | self.null_score_diff_threshold = null_score_diff_threshold 31 | 32 | self.data_dir = data_dir 33 | self.output_dir = output_dir 34 | 35 | 36 | def forward(self, all_examples, all_features, all_results, tokenizer, prefix = "dev", sigmoid=True): 37 | if not self.version_2_with_negative: 38 | with open(os.path.join(self.data_dir, "dev-v1.1.json"), "r") as f: 39 | text_dataset = json.load(f)["data"] 40 | else: 41 | with open(os.path.join(self.data_dir, "dev-v2.0.json"), "r") as f: 42 | text_dataset = json.load(f)["data"] 43 | 44 | output_prediction_file = os.path.join(self.output_dir, "predictions_{}.json".format(prefix)) 45 | output_nbest_file = os.path.join(self.output_dir, "nbest_predictions_{}.json".format(prefix)) 46 | 47 | if self.version_2_with_negative: 48 | output_null_log_odds_file = os.path.join(self.output_dir, "null_odds_{}.json".format(prefix)) 49 | else: 50 | output_null_log_odds_file = None 51 | 52 | all_predictions = compute_predictions_logits(all_examples, all_features, all_results, 53 | self.n_best_size, 54 | self.max_answer_length, 55 | self.do_lower_case, 56 | output_prediction_file, 57 | output_nbest_file, 58 | output_null_log_odds_file, 59 | self.verbose_logging, 60 | self.version_2_with_negative, 61 | self.null_score_diff_threshold, 62 | tokenizer, 63 | sigmoid=sigmoid) 64 | if not self.version_2_with_negative: 65 | eval_results = evaluate_squad_v1(text_dataset, all_predictions) 66 | exact_match, f1 = eval_results["exact_match"], eval_results["f1"] 67 | else: 68 | raise ValueError("Evaluation for SQuAD 2.0 is not Implementation yet") 69 | 70 | return exact_match, f1 71 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/models/__init__.py -------------------------------------------------------------------------------- /models/bert_classification.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: bert_classification.py 5 | # description: 6 | # model for fine-tuning BERT on text classification tasks. 
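# Shape sketch (illustrative, assuming a batch of 8 MRPC sentence pairs padded to
# 128 tokens and the default num_labels=2 from BertForSequenceClassificationConfig):
#   input_ids, token_type_ids, attention_mask: LongTensors of shape [8, 128]
#   cls_logits = model(input_ids, token_type_ids, attention_mask)  # -> [8, 2]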
7 | 8 | import torch.nn as nn 9 | from torch import Tensor 10 | from transformers import BertModel, BertPreTrainedModel 11 | 12 | from models.classifier import truncated_normal_ 13 | from models.model_config import BertForSequenceClassificationConfig 14 | 15 | class BertForSequenceClassification(BertPreTrainedModel): 16 | """Fine-tune BERT model for text classification.""" 17 | def __init__(self, config: BertForSequenceClassificationConfig,): 18 | super(BertForSequenceClassification, self).__init__(config) 19 | self.bert = BertModel(config,) 20 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 21 | self.cls_classifier = nn.Linear(config.hidden_size, config.num_labels) 22 | self.cls_classifier.weight = truncated_normal_(self.cls_classifier.weight, mean=0, std=0.02) 23 | self.init_weights() 24 | 25 | def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor): 26 | """ 27 | Args: 28 | inputs_ids: input tokens, tensor of shape [batch_size, seq_len]. 29 | token_type_ids: 1 for text_b tokens and 0 for text_a tokens. tensor of shape [batch_size, seq_len]. 30 | attention_mask: 1 for non-[PAD] tokens and 0 for [PAD] tokens. tensor of shape [batch_size, seq_len]. 31 | Returns: 32 | cls_outputs: output logits for the [CLS] token. tensor of shape [batch_size, num_labels]. 33 | """ 34 | bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 35 | bert_cls_output = bert_outputs[1] 36 | bert_cls_output = self.dropout(bert_cls_output) 37 | cls_logits = self.cls_classifier(bert_cls_output) 38 | 39 | return cls_logits -------------------------------------------------------------------------------- /models/bert_qa.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: bert_qa.py 5 | # description: 6 | # BERT for question answering task. 
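# Shape sketch (illustrative, assuming a batch of 8 question+context pairs padded
# to 384 tokens as in the scripts under scripts/mrc_squad1):
#   input_ids, token_type_ids, attention_mask: LongTensors of shape [8, 384]
#   start_logits, end_logits = model(input_ids, token_type_ids, attention_mask)
#   # each output has shape [8, 384]: one start score and one end score per token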
7 | 8 | 9 | import torch.nn as nn 10 | from torch import Tensor 11 | from transformers import BertModel, BertPreTrainedModel 12 | from models.classifier import truncated_normal_, BertMLP 13 | 14 | 15 | class BertForQuestionAnswering(BertPreTrainedModel): 16 | """Finetuning Bert Model for the question answering task.""" 17 | def __init__(self, config): 18 | super(BertForQuestionAnswering, self).__init__(config) 19 | 20 | self.bert = BertModel(config, add_pooling_layer=False) 21 | if config.multi_layer_classifier: 22 | self.qa_classifier = BertMLP(config) 23 | else: 24 | self.qa_classifier = nn.Linear(config.hidden_size, 2) 25 | self.qa_classifier.weight = truncated_normal_(self.qa_classifier.weight, mean=0, std=0.02) 26 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 27 | self.init_weights() 28 | 29 | def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor): 30 | """ 31 | Args: 32 | input_ids: Bert input tokens, tensor of shape [batch, seq_len] 33 | token_type_ids: 0 for query, 1 for context, tensor of shape [batch, seq_len] 34 | attention_mask: attention mask, tensor of shape [batch, seq_len] 35 | Returns: 36 | start_logits: start/non-start logits of shape [batch, seq_len] 37 | end_logits: end/non-end logits of shape [batch, seq_len] 38 | """ 39 | 40 | bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 41 | 42 | sequence_heatmap = self.dropout(bert_outputs[0]) # [batch, seq_len, hidden] 43 | logits = self.qa_classifier(sequence_heatmap) 44 | start_logits, end_logits = logits.split(1, dim=-1) 45 | start_logits = start_logits.squeeze(-1) 46 | end_logits = end_logits.squeeze(-1) 47 | 48 | return start_logits, end_logits 49 | 50 | 51 | -------------------------------------------------------------------------------- /models/bert_query_ner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: bert_mrc_ner.py 5 | 6 | import torch 7 | import torch.nn as nn 8 | from transformers import BertModel, BertPreTrainedModel 9 | 10 | from models.classifier import SpanClassifier, MultiLayerPerceptronClassifier 11 | 12 | 13 | class BertForQueryNER(BertPreTrainedModel): 14 | def __init__(self, config): 15 | super(BertForQueryNER, self).__init__(config) 16 | self.bert = BertModel(config) 17 | 18 | self.construct_entity_span = config.construct_entity_span 19 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 20 | 21 | if self.construct_entity_span == "start_end_match": 22 | self.start_outputs = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=config.num_labels, activate_func=config.activate_func) 23 | self.end_outputs = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=config.num_labels, activate_func=config.activate_func) 24 | self.span_embedding = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size*2, num_labels=1, activate_func=config.activate_func) 25 | elif self.construct_entity_span == "match": 26 | self.span_nn = SpanClassifier(config.hidden_size, config.hidden_dropout_prob) 27 | elif self.construct_entity_span == "start_and_end": 28 | self.start_outputs = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=config.num_labels, activate_func=config.activate_func) 29 | self.end_outputs = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=config.num_labels, activate_func=config.activate_func) 30 | elif 
self.construct_entity_span == "start_end": 31 | self.start_end_outputs = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=2, activate_func=config.activate_func) 32 | else: 33 | raise ValueError 34 | 35 | self.pred_answerable = config.pred_answerable 36 | if self.pred_answerable: 37 | self.answerable_cls_output = MultiLayerPerceptronClassifier(hidden_size=config.hidden_size, num_labels=1, activate_func=config.activate_func) 38 | 39 | self.init_weights() 40 | 41 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 42 | """ 43 | Args: 44 | input_ids: bert input tokens, tensor of shape [batch, seq_len] 45 | token_type_ids: 0 for query, 1 for context, tensor of shape [batch, seq_len] 46 | attention_mask: attention mask, tensor of shape [batch, seq_len] 47 | Returns: 48 | start_logits: start/non-start probs of shape [batch, seq_len] 49 | end_logits: end/non-end probs of shape [batch, seq_len] 50 | match_logits: start-end-match probs of shape [batch, seq_len, seq_len] 51 | """ 52 | 53 | bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 54 | 55 | sequence_heatmap = bert_outputs[0] # [batch, seq_len, hidden] 56 | sequence_cls = bert_outputs[1] 57 | 58 | batch_size, seq_len, hid_size = sequence_heatmap.size() 59 | if self.construct_entity_span == "match" : 60 | start_logits = end_logits = torch.ones_like(input_ids).float() 61 | span_logits = self.span_nn(sequence_heatmap) 62 | elif self.construct_entity_span == "start_end_match": 63 | sequence_heatmap = self.dropout(sequence_heatmap) 64 | start_logits = self.start_outputs(sequence_heatmap).view(batch_size, seq_len, -1) # [batch, seq_len, 1] 65 | end_logits = self.end_outputs(sequence_heatmap).view(batch_size, seq_len, -1) # [batch, seq_len, 1] 66 | 67 | # for every position $i$ in sequence, should concate $j$ to 68 | # predict if $i$ and $j$ are start_pos and end_pos for an entity. 
69 | # [batch, seq_len, seq_len, hidden] 70 | start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1) 71 | # [batch, seq_len, seq_len, hidden] 72 | end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1) 73 | # [batch, seq_len, seq_len, hidden*2] 74 | span_matrix = torch.cat([start_extend, end_extend], 3) 75 | # [batch, seq_len, seq_len] 76 | span_logits = self.span_embedding(span_matrix).squeeze(-1) 77 | elif self.construct_entity_span == "start_and_end": 78 | sequence_heatmap = self.dropout(sequence_heatmap) 79 | start_logits = self.start_outputs(sequence_heatmap).view(batch_size, seq_len, -1) # [batch, seq_len, 1] 80 | end_logits = self.end_outputs(sequence_heatmap).view(batch_size, seq_len, -1) # [batch, seq_len, 1] 81 | 82 | span_logits = None 83 | elif self.construct_entity_span == "start_end": 84 | sequence_heatmap = self.dropout(sequence_heatmap) 85 | start_end_logits = self.start_end_outputs(sequence_heatmap) 86 | start_logits, end_logits = start_end_logits.split(1, dim=-1) 87 | 88 | span_logits = None 89 | else: 90 | raise ValueError 91 | 92 | if self.pred_answerable: 93 | cls_logits = self.answerable_cls_output(sequence_cls) 94 | return start_logits, end_logits, span_logits, cls_logits 95 | 96 | return start_logits, end_logits, span_logits 97 | -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: modules.py 5 | # description: 6 | # modules for building models. 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import init 12 | from torch.nn import functional as F 13 | 14 | 15 | def truncated_normal_(tensor, mean=0, std=1): 16 | size = tensor.shape 17 | tmp = tensor.new_empty(size + (4,)).normal_() 18 | valid = (tmp < 2) & (tmp > -2) 19 | ind = valid.max(-1, keepdim=True)[1] 20 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 21 | tensor.data.mul_(std).add_(mean) 22 | return tensor 23 | 24 | 25 | class BertMLP(nn.Module): 26 | def __init__(self, config): 27 | super().__init__() 28 | self.dense_layer = nn.Linear(config.hidden_size, config.hidden_size) 29 | self.dense_to_labels_layer = nn.Linear(config.hidden_size, config.num_labels) 30 | self.activation = nn.Tanh() 31 | if config.truncated_normal: 32 | self.dense_layer.weight = truncated_normal_(self.dense_layer.weight, mean=0, std=0.02) 33 | self.dense_to_labels_layer.weight = truncated_normal_(self.dense_to_labels_layer.weight, mean=0, std=0.02) 34 | 35 | def forward(self, sequence_hidden_states): 36 | sequence_output = self.dense_layer(sequence_hidden_states) 37 | sequence_output = self.activation(sequence_output) 38 | sequence_output = self.dense_to_labels_layer(sequence_output) 39 | return sequence_output 40 | 41 | 42 | class MultiLayerPerceptronClassifier(nn.Module): 43 | def __init__(self, hidden_size=None, num_labels=None, activate_func="gelu"): 44 | super().__init__() 45 | self.dense_layer = nn.Linear(hidden_size, hidden_size) 46 | self.dense_to_labels_layer = nn.Linear(hidden_size, num_labels) 47 | if activate_func == "tanh": 48 | self.activation = nn.Tanh() 49 | elif activate_func == "relu": 50 | self.activation = nn.ReLU() 51 | elif activate_func == "gelu": 52 | self.activation = nn.GELU() 53 | else: 54 | raise ValueError 55 | 56 | def forward(self, sequence_hidden_states): 57 | sequence_output = self.dense_layer(sequence_hidden_states) 58 | 
sequence_output = self.activation(sequence_output) 59 | sequence_output = self.dense_to_labels_layer(sequence_output) 60 | return sequence_output 61 | 62 | 63 | class SpanClassifier(nn.Module): 64 | def __init__(self, hidden_size: int, dropout_rate: float): 65 | super(SpanClassifier, self).__init__() 66 | self.start_proj = nn.Linear(hidden_size, hidden_size) 67 | self.end_proj = nn.Linear(hidden_size, hidden_size) 68 | self.biaffine = nn.Parameter(torch.Tensor(hidden_size, hidden_size)) 69 | self.concat_proj = nn.Linear(hidden_size * 2, 1) 70 | self.dropout = nn.Dropout(dropout_rate) 71 | self.reset_parameters() 72 | 73 | def forward(self, input_features): 74 | bsz, seq_len, dim = input_features.size() 75 | # B, L, h 76 | start_feature = self.dropout(F.gelu(self.start_proj(input_features))) 77 | # B, L, h 78 | end_feature = self.dropout(F.gelu(self.end_proj(input_features))) 79 | # B, L, L 80 | biaffine_logits = torch.bmm(torch.matmul(start_feature, self.biaffine), end_feature.transpose(1, 2)) 81 | 82 | start_extend = start_feature.unsqueeze(2).expand(-1, -1, seq_len, -1) 83 | # [B, L, L, h] 84 | end_extend = end_feature.unsqueeze(1).expand(-1, seq_len, -1, -1) 85 | # [B, L, L, h] 86 | span_matrix = torch.cat([start_extend, end_extend], 3) 87 | # [B, L, L] 88 | concat_logits = self.concat_proj(span_matrix).squeeze(-1) 89 | # B, L, L 90 | return biaffine_logits + concat_logits 91 | 92 | def reset_parameters(self) -> None: 93 | init.kaiming_uniform_(self.biaffine, a=math.sqrt(5)) 94 | 95 | 96 | -------------------------------------------------------------------------------- /models/model_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: model_config.py 5 | # description: 6 | # user defined configuration class for NLP tasks. 
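# Construction sketch (illustrative): each config class below reads its task-specific
# fields from **kwargs and falls back to the defaults defined in its __init__, e.g.
#   config = BertForQueryNERConfig(construct_entity_span="start_end_match", pred_answerable=True)
# leaves hidden_dropout_prob at 0.1 and activate_func at "gelu".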
7 | 8 | from transformers import BertConfig 9 | 10 | 11 | class BertForQAConfig(BertConfig): 12 | def __init__(self, **kwargs): 13 | super(BertForQAConfig, self).__init__(**kwargs) 14 | self.hidden_size = kwargs.get("hidden_size", 768) 15 | self.multi_layer_classifier = kwargs.get("multi_layer_classifier", True) 16 | self.truncated_normal = kwargs.get("truncated_normal", True) 17 | 18 | class BertForSequenceClassificationConfig(BertConfig): 19 | def __init__(self, **kwargs): 20 | super(BertForSequenceClassificationConfig, self).__init__(**kwargs) 21 | self.hidden_dropout_prob = kwargs.get("hidden_dropout_prob", 0.0) 22 | self.num_labels = kwargs.get("num_labels", 2) 23 | self.hidden_size = kwargs.get("hidden_size", 768) 24 | self.truncated_normal = kwargs.get("truncated_normal", False) 25 | 26 | class BertForQueryNERConfig(BertConfig): 27 | def __init__(self, **kwargs): 28 | super(BertForQueryNERConfig, self).__init__(**kwargs) 29 | self.hidden_dropout_prob = kwargs.get("hidden_dropout_prob", 0.1) 30 | self.truncated_normal = kwargs.get("truncated_normal", False) 31 | self.construct_entity_span = kwargs.get("construct_entity_span", "start_end_match") 32 | self.pred_answerable = kwargs.get("pred_answerable", True) 33 | self.num_labels = kwargs.get("num_labels", 2) 34 | self.activate_func = kwargs.get("activate_func", "gelu") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==0.9.0 2 | tokenizers==0.9.3 3 | transformers==3.5.1 -------------------------------------------------------------------------------- /scripts/download_ckpt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # description: 5 | # download pretrained model ckpt 6 | 7 | BERT_PRETRAIN_CKPT=$1 8 | MODEL_NAME=$2 9 | 10 | if [[ $MODEL_NAME == "bert_cased_base" ]]; then 11 | mkdir -p $BERT_PRETRAIN_CKPT 12 | echo "DownLoad BERT Cased Base" 13 | wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip -P $BERT_PRETRAIN_CKPT 14 | unzip $BERT_PRETRAIN_CKPT/cased_L-12_H-768_A-12.zip -d $BERT_PRETRAIN_CKPT 15 | rm $BERT_PRETRAIN_CKPT/cased_L-12_H-768_A-12.zip 16 | mv $BERT_PRETRAIN_CKPT/cased_L-12_H-768_A-12 $BERT_PRETRAIN_CKPT/bert_cased_base 17 | elif [[ $MODEL_NAME == "bert_cased_large" ]]; then 18 | echo "DownLoad BERT Cased Large" 19 | wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-24_H-1024_A-16.zip -P $BERT_PRETRAIN_CKPT 20 | unzip $BERT_PRETRAIN_CKPT/cased_L-24_H-1024_A-16.zip -d $BERT_PRETRAIN_CKPT 21 | rm $BERT_PRETRAIN_CKPT/cased_L-24_H-1024_A-16.zip 22 | mv $BERT_PRETRAIN_CKPT/cased_L-24_H-1024_A-16 $BERT_PRETRAIN_CKPT/bert_cased_large 23 | elif [[ $MODEL_NAME == "bert_uncased_base" ]]; then 24 | echo "DownLoad BERT Uncased Base" 25 | wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip -P $BERT_PRETRAIN_CKPT 26 | unzip $BERT_PRETRAIN_CKPT/uncased_L-12_H-768_A-12.zip -d $BERT_PRETRAIN_CKPT 27 | rm $BERT_PRETRAIN_CKPT/uncased_L-12_H-768_A-12.zip 28 | mv $BERT_PRETRAIN_CKPT/uncased_L-12_H-768_A-12 $BERT_PRETRAIN_CKPT/bert_uncased_base 29 | elif [[ $MODEL_NAME == "bert_uncased_large" ]]; then 30 | echo "DownLoad BERT Uncased Large" 31 | wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip -P $BERT_PRETRAIN_CKPT 32 | unzip 
$BERT_PRETRAIN_CKPT/uncased_L-24_H-1024_A-16.zip -d $BERT_PRETRAIN_CKPT 33 | rm $BERT_PRETRAIN_CKPT/uncased_L-24_H-1024_A-16.zip 34 | mv $BERT_PRETRAIN_CKPT/uncased_L-24_H-1024_A-16 $BERT_PRETRAIN_CKPT/bert_uncased_large 35 | elif [[ $MODEL_NAME == "bert_tiny" ]]; then 36 | echo "Download BERT Uncased Tiny; helps to debug on GPU." 37 | wget https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip -P $BERT_PRETRAIN_CKPT 38 | unzip $BERT_PRETRAIN_CKPT/uncased_L-2_H-128_A-2.zip -d $BERT_PRETRAIN_CKPT 39 | rm $BERT_PRETRAIN_CKPT/uncased_L-2_H-128_A-2.zip 40 | mv $BERT_PRETRAIN_CKPT/uncased_L-2_H-128_A-2 $BERT_PRETRAIN_CKPT/bert_uncased_tiny 41 | else 42 | echo 'Unknown argument 2 (Model Sign)' 43 | fi -------------------------------------------------------------------------------- /scripts/glue_mrpc/bert_base_ce.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: glue/bert_base_ce.sh 5 | # result: 6 | # dev f1/acc: 90.03/85.78 7 | # test f1/acc: 87.36/82.72 8 | 9 | FILE=bert_base_ce 10 | MODEL_SCALE=base 11 | TASK=mrpc 12 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc 13 | DATA_DIR=/data/xiaoya/datasets/mrpc 14 | BERT_DIR=/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12 15 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP 16 | 17 | # NEED CHANGE 18 | LR=2e-5 19 | LR_SCHEDULER=linear 20 | TRAIN_BATCH_SIZE=24 21 | ACC_GRAD=1 22 | DROPOUT=0.1 23 | WEIGHT_DECAY=0.01 24 | WARMUP_PROPORTION=0.1 25 | LOSS_TYPE=ce 26 | DICE_SMOOTH=1e-4 27 | DICE_OHEM=0.1 28 | DICE_ALPHA=0.01 29 | FOCAL_GAMMA=0.1 30 | 31 | # DONOT NEED CHANGE 32 | PRECISION=32 33 | MAX_SEQ_LEN=128 34 | MAX_EPOCH=3 35 | PROGRESS_BAR=1 36 | VAL_CHECK_INTERVAL=0.25 37 | DISTRIBUTE=ddp 38 | GRAD_CLIP=1 39 | 40 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY} 41 | mkdir -p ${OUTPUT_DIR} 42 | CACHE_DIR=${OUTPUT_DIR}/cache 43 | mkdir -p ${CACHE_DIR} 44 | 45 | 46 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 47 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/tasks/glue/train.py \ 48 | --gpus="1" \ 49 | --task_name ${TASK} \ 50 | --max_seq_length ${MAX_SEQ_LEN} \ 51 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 52 | --precision=${PRECISION} \ 53 | --default_root_dir ${OUTPUT_DIR} \ 54 | --output_dir ${OUTPUT_DIR} \ 55 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 56 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 57 | --loss_type ${LOSS_TYPE} \ 58 | --data_dir ${DATA_DIR} \ 59 | --bert_config_dir ${BERT_DIR} \ 60 | --bert_hidden_dropout ${DROPOUT} \ 61 | --lr ${LR} \ 62 | --lr_scheduler ${LR_SCHEDULER} \ 63 | --accumulate_grad_batches ${ACC_GRAD} \ 64 | --output_dir ${OUTPUT_DIR} \ 65 | --max_epochs ${MAX_EPOCH} \ 66 | --gradient_clip_val ${GRAD_CLIP} \ 67 | --pad_to_max_length \ 68 | --weight_decay ${WEIGHT_DECAY} \ 69 | --warmup_proportion ${WARMUP_PROPORTION} \ 70 | --overwrite_cache 71 | -------------------------------------------------------------------------------- /scripts/glue_mrpc/bert_base_dice.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # result: 5 | # Dev f1/acc: 90.42/86.03 6 | # Test f1/acc: 88.23/83.59 7 | # gpu4: /data/xiaoya/outputs/dice_loss/glue_mrpc/2021.01.24/dice_night8_base_dice_128_20_1_5_3e-5_linear_0.2_0.002_0.003 8 | 9 | 10 | 
FILE=reproduce_dice_bertbase 11 | MODEL_SCALE=base 12 | TASK=mrpc 13 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc 14 | DATA_DIR=/data/xiaoya/datasets/mrpc 15 | BERT_DIR=/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12 16 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP 17 | 18 | # NEED CHANGE 19 | LR=3e-5 20 | LR_SCHEDULER=linear 21 | TRAIN_BATCH_SIZE=20 22 | ACC_GRAD=1 23 | DROPOUT=0.2 24 | WEIGHT_DECAY=0.003 25 | WARMUP_PROPORTION=0.002 26 | LOSS_TYPE=dice 27 | DICE_SMOOTH=1 28 | DICE_OHEM=0 29 | DICE_ALPHA=0.01 30 | 31 | # DONOT NEED CHANGE 32 | PRECISION=32 33 | MAX_SEQ_LEN=128 34 | MAX_EPOCH=5 35 | PROGRESS_BAR=1 36 | VAL_CHECK_INTERVAL=0.25 37 | DISTRIBUTE=ddp 38 | GRAD_CLIP=1 39 | 40 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY} 41 | mkdir -p ${OUTPUT_DIR} 42 | CACHE_DIR=${OUTPUT_DIR}/cache 43 | mkdir -p ${CACHE_DIR} 44 | 45 | 46 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 47 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/tasks/glue/train.py \ 48 | --gpus="1" \ 49 | --task_name ${TASK} \ 50 | --max_seq_length ${MAX_SEQ_LEN} \ 51 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 52 | --precision=${PRECISION} \ 53 | --default_root_dir ${OUTPUT_DIR} \ 54 | --output_dir ${OUTPUT_DIR} \ 55 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 56 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 57 | --loss_type ${LOSS_TYPE} \ 58 | --data_dir ${DATA_DIR} \ 59 | --bert_config_dir ${BERT_DIR} \ 60 | --bert_hidden_dropout ${DROPOUT} \ 61 | --lr ${LR} \ 62 | --lr_scheduler ${LR_SCHEDULER} \ 63 | --accumulate_grad_batches ${ACC_GRAD} \ 64 | --output_dir ${OUTPUT_DIR} \ 65 | --max_epochs ${MAX_EPOCH} \ 66 | --gradient_clip_val ${GRAD_CLIP} \ 67 | --pad_to_max_length \ 68 | --weight_decay ${WEIGHT_DECAY} \ 69 | --warmup_proportion ${WARMUP_PROPORTION} \ 70 | --overwrite_cache \ 71 | --dice_square \ 72 | --dice_smooth ${DICE_SMOOTH} \ 73 | --dice_ohem ${DICE_OHEM} \ 74 | --dice_alpha ${DICE_ALPHA} 75 | 76 | -------------------------------------------------------------------------------- /scripts/glue_mrpc/bert_base_focal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # Result: 5 | # - Dev f1/acc: 89.31/84.80 6 | # - Test f1/acc: 88.06/83.59 7 | # gpu3: /data/xiaoya/outputs/dice_loss/glue_mrpc/2021.01.25/mrpc_focal_base4_base_focal_128_38_1_3_3e-5_linear_0.1_0.1_0.002 8 | 9 | 10 | FILE=mrpc_focal_base4 11 | MODEL_SCALE=base 12 | TASK=mrpc 13 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc 14 | DATA_DIR=/data/xiaoya/datasets/mrpc 15 | BERT_DIR=/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12 16 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP 17 | 18 | # NEED CHANGE 19 | LR=3e-5 20 | LR_SCHEDULER=linear 21 | TRAIN_BATCH_SIZE=38 22 | ACC_GRAD=1 23 | DROPOUT=0.1 24 | WEIGHT_DECAY=0.002 25 | WARMUP_PROPORTION=0.1 26 | LOSS_TYPE=focal 27 | DICE_SMOOTH=1e-4 28 | DICE_OHEM=0.1 29 | DICE_ALPHA=0.01 30 | FOCAL_GAMMA=2 31 | OPTIMIZER=debias 32 | 33 | # DONOT NEED CHANGE 34 | PRECISION=32 35 | MAX_SEQ_LEN=128 36 | MAX_EPOCH=3 37 | PROGRESS_BAR=1 38 | VAL_CHECK_INTERVAL=0.25 39 | DISTRIBUTE=ddp 40 | GRAD_CLIP=1 41 | 42 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY} 43 | mkdir -p ${OUTPUT_DIR} 44 | 
CACHE_DIR=${OUTPUT_DIR}/cache 45 | mkdir -p ${CACHE_DIR} 46 | 47 | 48 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 49 | CUDA_VISIBLE_DEVICES=3 python3 ${REPO_PATH}/tasks/glue/train.py \ 50 | --gpus="1" \ 51 | --task_name ${TASK} \ 52 | --max_seq_length ${MAX_SEQ_LEN} \ 53 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 54 | --precision=${PRECISION} \ 55 | --default_root_dir ${OUTPUT_DIR} \ 56 | --output_dir ${OUTPUT_DIR} \ 57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 58 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 59 | --loss_type ${LOSS_TYPE} \ 60 | --data_dir ${DATA_DIR} \ 61 | --bert_config_dir ${BERT_DIR} \ 62 | --bert_hidden_dropout ${DROPOUT} \ 63 | --lr ${LR} \ 64 | --lr_scheduler ${LR_SCHEDULER} \ 65 | --accumulate_grad_batches ${ACC_GRAD} \ 66 | --output_dir ${OUTPUT_DIR} \ 67 | --max_epochs ${MAX_EPOCH} \ 68 | --gradient_clip_val ${GRAD_CLIP} \ 69 | --pad_to_max_length \ 70 | --weight_decay ${WEIGHT_DECAY} \ 71 | --warmup_proportion ${WARMUP_PROPORTION} \ 72 | --overwrite_cache \ 73 | --optimizer ${OPTIMIZER} \ 74 | --focal_gamma ${FOCAL_GAMMA} 75 | 76 | 77 | -------------------------------------------------------------------------------- /scripts/glue_mrpc/bert_large_ce.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # result: 5 | # - Dev f1/acc: 91.53/87.75 6 | # - Test f1/acc: 87.98/83.13 7 | # gpu4: /data/xiaoya/outputs/dice_loss/glue_mrpc/2021.01.25/ce_large_large_ce_128_16_2_3_2e-5_linear_0.1_0.1_0.002 8 | 9 | 10 | FILE=ce_large 11 | MODEL_SCALE=large 12 | TASK=mrpc 13 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc 14 | DATA_DIR=/data/xiaoya/datasets/mrpc 15 | BERT_DIR=/data/xiaoya/pretrain_lm/bert_cased_large 16 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP 17 | 18 | # NEED CHANGE 19 | LR=2e-5 20 | LR_SCHEDULER=linear 21 | TRAIN_BATCH_SIZE=16 22 | ACC_GRAD=2 23 | DROPOUT=0.1 24 | WEIGHT_DECAY=0.002 25 | WARMUP_PROPORTION=0.1 26 | LOSS_TYPE=ce 27 | DICE_SMOOTH=1e-4 28 | DICE_OHEM=0.1 29 | DICE_ALPHA=0.01 30 | FOCAL_GAMMA=0.1 31 | OPTIMIZER=debias 32 | 33 | # DONOT NEED CHANGE 34 | PRECISION=32 35 | MAX_SEQ_LEN=128 36 | MAX_EPOCH=3 37 | PROGRESS_BAR=1 38 | VAL_CHECK_INTERVAL=0.25 39 | DISTRIBUTE=ddp 40 | GRAD_CLIP=1 41 | 42 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY} 43 | mkdir -p ${OUTPUT_DIR} 44 | CACHE_DIR=${OUTPUT_DIR}/cache 45 | mkdir -p ${CACHE_DIR} 46 | 47 | 48 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 49 | CUDA_VISIBLE_DEVICES=2 python3 ${REPO_PATH}/tasks/glue/train.py \ 50 | --gpus="1" \ 51 | --task_name ${TASK} \ 52 | --max_seq_length ${MAX_SEQ_LEN} \ 53 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 54 | --precision=${PRECISION} \ 55 | --default_root_dir ${OUTPUT_DIR} \ 56 | --output_dir ${OUTPUT_DIR} \ 57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 58 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 59 | --loss_type ${LOSS_TYPE} \ 60 | --data_dir ${DATA_DIR} \ 61 | --bert_config_dir ${BERT_DIR} \ 62 | --bert_hidden_dropout ${DROPOUT} \ 63 | --lr ${LR} \ 64 | --lr_scheduler ${LR_SCHEDULER} \ 65 | --accumulate_grad_batches ${ACC_GRAD} \ 66 | --output_dir ${OUTPUT_DIR} \ 67 | --max_epochs ${MAX_EPOCH} \ 68 | --gradient_clip_val ${GRAD_CLIP} \ 69 | --pad_to_max_length \ 70 | --weight_decay ${WEIGHT_DECAY} \ 71 | --warmup_proportion ${WARMUP_PROPORTION} \ 72 | --overwrite_cache \ 73 | --optimizer 
${OPTIMIZER} 74 | 75 | -------------------------------------------------------------------------------- /scripts/glue_mrpc/bert_large_dice.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # result: 5 | # - Dev f1/acc: 91.31/87.50 6 | # - Test f1/acc: 88.98/84.64 7 | # gpu6: /data/xiaoya/outputs/dice_loss/glue_mrpc/2021.02.19/dice_large4_large_dice_128_12_1_5_2e-5_linear_0.1_0.06_0.001 8 | 9 | 10 | FILE=dice_large4 11 | MODEL_SCALE=large 12 | TASK=mrpc 13 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc 14 | DATA_DIR=/data/xiaoya/datasets/mrpc 15 | BERT_DIR=/data/xiaoya/pretrain_lm/bert_cased_large 16 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP 17 | 18 | # NEED CHANGE 19 | LR=2e-5 20 | LR_SCHEDULER=linear 21 | TRAIN_BATCH_SIZE=12 22 | ACC_GRAD=1 23 | DROPOUT=0.1 24 | WEIGHT_DECAY=0.001 25 | WARMUP_PROPORTION=0.06 26 | LOSS_TYPE=dice 27 | DICE_SMOOTH=1 28 | DICE_OHEM=0 29 | DICE_ALPHA=0.01 30 | FOCAL_GAMMA=1 31 | OPTIMIZER=debias 32 | 33 | # DONOT NEED CHANGE 34 | PRECISION=32 35 | MAX_SEQ_LEN=128 36 | MAX_EPOCH=5 37 | PROGRESS_BAR=1 38 | VAL_CHECK_INTERVAL=0.25 39 | DISTRIBUTE=ddp 40 | GRAD_CLIP=1 41 | 42 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY} 43 | mkdir -p ${OUTPUT_DIR} 44 | CACHE_DIR=${OUTPUT_DIR}/cache 45 | mkdir -p ${CACHE_DIR} 46 | 47 | 48 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 49 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/tasks/glue/train.py \ 50 | --gpus="1" \ 51 | --task_name ${TASK} \ 52 | --max_seq_length ${MAX_SEQ_LEN} \ 53 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 54 | --precision=${PRECISION} \ 55 | --default_root_dir ${OUTPUT_DIR} \ 56 | --output_dir ${OUTPUT_DIR} \ 57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 58 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 59 | --loss_type ${LOSS_TYPE} \ 60 | --data_dir ${DATA_DIR} \ 61 | --bert_config_dir ${BERT_DIR} \ 62 | --bert_hidden_dropout ${DROPOUT} \ 63 | --lr ${LR} \ 64 | --lr_scheduler ${LR_SCHEDULER} \ 65 | --accumulate_grad_batches ${ACC_GRAD} \ 66 | --output_dir ${OUTPUT_DIR} \ 67 | --max_epochs ${MAX_EPOCH} \ 68 | --gradient_clip_val ${GRAD_CLIP} \ 69 | --pad_to_max_length \ 70 | --weight_decay ${WEIGHT_DECAY} \ 71 | --warmup_proportion ${WARMUP_PROPORTION} \ 72 | --overwrite_cache \ 73 | --optimizer ${OPTIMIZER} \ 74 | --dice_square \ 75 | --dice_smooth ${DICE_SMOOTH} \ 76 | --dice_ohem ${DICE_OHEM} \ 77 | --dice_alpha ${DICE_ALPHA} 78 | 79 | -------------------------------------------------------------------------------- /scripts/glue_mrpc/bert_large_focal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # Result: 5 | # - Dev f1/acc: 90.91/86.76 6 | # - Test f1/acc: 88.35/84.06 7 | # gpu3: /data/xiaoya/outputs/dice_loss/glue_mrpc/2021.01.25/mrpc_focal_large2_large_focal_128_16_2_3_2e-5_linear_0.1_0.1_0.002 8 | 9 | 10 | FILE=mrpc_focal_large2 11 | MODEL_SCALE=large 12 | TASK=mrpc 13 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/glue_mrpc 14 | DATA_DIR=/data/xiaoya/datasets/mrpc 15 | BERT_DIR=/data/xiaoya/pretrain_lm/bert_cased_large 16 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP 17 | 18 | # NEED CHANGE 19 | LR=2e-5 20 | LR_SCHEDULER=linear 21 | TRAIN_BATCH_SIZE=16 22 | ACC_GRAD=2 23 | DROPOUT=0.1 24 | WEIGHT_DECAY=0.002 25 | 
WARMUP_PROPORTION=0.1 26 | LOSS_TYPE=focal 27 | DICE_SMOOTH=1e-4 28 | DICE_OHEM=0.1 29 | DICE_ALPHA=0.01 30 | FOCAL_GAMMA=3 31 | OPTIMIZER=debias 32 | 33 | # DONOT NEED CHANGE 34 | PRECISION=32 35 | MAX_SEQ_LEN=128 36 | MAX_EPOCH=3 37 | PROGRESS_BAR=1 38 | VAL_CHECK_INTERVAL=0.25 39 | DISTRIBUTE=ddp 40 | GRAD_CLIP=1 41 | 42 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE}_${MODEL_SCALE}_${LOSS_TYPE}_${MAX_SEQ_LEN}_${TRAIN_BATCH_SIZE}_${ACC_GRAD}_${MAX_EPOCH}_${LR}_${LR_SCHEDULER}_${DROPOUT}_${WARMUP_PROPORTION}_${WEIGHT_DECAY} 43 | mkdir -p ${OUTPUT_DIR} 44 | CACHE_DIR=${OUTPUT_DIR}/cache 45 | mkdir -p ${CACHE_DIR} 46 | 47 | 48 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 49 | CUDA_VISIBLE_DEVICES=2 python3 ${REPO_PATH}/tasks/glue/train.py \ 50 | --gpus="1" \ 51 | --task_name ${TASK} \ 52 | --max_seq_length ${MAX_SEQ_LEN} \ 53 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 54 | --precision=${PRECISION} \ 55 | --default_root_dir ${OUTPUT_DIR} \ 56 | --output_dir ${OUTPUT_DIR} \ 57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 58 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 59 | --loss_type ${LOSS_TYPE} \ 60 | --data_dir ${DATA_DIR} \ 61 | --bert_config_dir ${BERT_DIR} \ 62 | --bert_hidden_dropout ${DROPOUT} \ 63 | --lr ${LR} \ 64 | --lr_scheduler ${LR_SCHEDULER} \ 65 | --accumulate_grad_batches ${ACC_GRAD} \ 66 | --output_dir ${OUTPUT_DIR} \ 67 | --max_epochs ${MAX_EPOCH} \ 68 | --gradient_clip_val ${GRAD_CLIP} \ 69 | --pad_to_max_length \ 70 | --weight_decay ${WEIGHT_DECAY} \ 71 | --warmup_proportion ${WARMUP_PROPORTION} \ 72 | --overwrite_cache \ 73 | --optimizer ${OPTIMIZER} \ 74 | --focal_gamma ${FOCAL_GAMMA} 75 | -------------------------------------------------------------------------------- /scripts/glue_mrpc/eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: scripts/glue/eval.sh 5 | 6 | 7 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP 8 | 9 | EVAL_DIR=$1 10 | CKPT_PATH=${EVAL_DIR}/$2 11 | 12 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 13 | CUDA_VISIBLE_DEVICES=0 python3 ${REPO_PATH}/tasks/glue/evaluate_models.py \ 14 | --gpus="1" \ 15 | --path_to_model_checkpoint ${CKPT_PATH} -------------------------------------------------------------------------------- /scripts/mrc_squad1/bert_base_ce.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 6 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 7 | 8 | DATA_DIR=/userhome/xiaoya/dataset/squad1 9 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-12_H-768_A-12 10 | 11 | LOSS_TYPE=ce 12 | LR=3e-5 13 | LR_SCHEDULE=linear 14 | OPTIMIZER=adamw 15 | WARMUP_PROPORTION=0.002 16 | 17 | GRAD_CLIP=1.0 18 | ACC_GRAD=1 19 | MAX_EPOCH=2 20 | BERT_DROPOUT=0.1 21 | WEIGHT_DECAY=0.002 22 | 23 | TRAIN_BATCH_SIZE=12 24 | MAX_QUERY_LEN=64 25 | MAX_SEQ_LEN=384 26 | DOC_STRIDE=128 27 | 28 | PRECISION=16 29 | PROGRESS_BAR=1 30 | VAL_CHECK_INTERVAL=0.125 31 | DISTRIBUTE=ddp 32 | 33 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad 34 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/reproduce_bert_base_ce 35 | 36 | mkdir -p ${OUTPUT_DIR} 37 | CACHE_DIR=${OUTPUT_DIR}/cache 38 | mkdir -p ${CACHE_DIR} 39 | 40 | python ${REPO_PATH}/tasks/squad/train.py \ 41 | --gpus="1" \ 42 | --precision=${PRECISION} \ 43 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 44 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 45 | --val_check_interval ${VAL_CHECK_INTERVAL} 
\ 46 | --max_query_length ${MAX_QUERY_LEN} \ 47 | --max_seq_length ${MAX_SEQ_LEN} \ 48 | --doc_stride ${DOC_STRIDE} \ 49 | --optimizer ${OPTIMIZER} \ 50 | --loss_type ${LOSS_TYPE} \ 51 | --data_dir ${DATA_DIR} \ 52 | --bert_hidden_dropout ${BERT_DROPOUT} \ 53 | --bert_config_dir ${BERT_DIR} \ 54 | --lr ${LR} \ 55 | --lr_scheduler ${LR_SCHEDULE} \ 56 | --accumulate_grad_batches ${ACC_GRAD} \ 57 | --default_root_dir ${OUTPUT_DIR} \ 58 | --output_dir ${OUTPUT_DIR} \ 59 | --max_epochs ${MAX_EPOCH} \ 60 | --gradient_clip_val ${GRAD_CLIP} \ 61 | --weight_decay ${WEIGHT_DECAY} \ 62 | --do_lower_case \ 63 | --warmup_proportion ${WARMUP_PROPORTION} 64 | -------------------------------------------------------------------------------- /scripts/mrc_squad1/bert_base_dice.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 6 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 7 | 8 | DATA_DIR=/userhome/xiaoya/dataset/squad1 9 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-12_H-768_A-12 10 | 11 | LOSS_TYPE=dice 12 | LR=3e-5 13 | LR_SCHEDULE=linear 14 | GRAD_CLIP=1.0 15 | OPTIMIZER=adamw 16 | WARMUP=0 17 | WARMUP_PROPORTION=0 18 | ACC_GRAD=5 19 | MAX_EPOCH=10 20 | BERT_DROPOUT=0.1 21 | WEIGHT_DECAY=0 22 | TRAIN_BATCH_SIZE=4 23 | MAX_QUERY_LEN=64 24 | MAX_SEQ_LEN=384 25 | DOC_STRIDE=128 26 | 27 | DICE_SMOOTH=1 28 | DICE_OHEM=0.3 29 | DICE_ALPHA=0.01 30 | 31 | PRECISION=16 32 | PROGRESS_BAR=1 33 | VAL_CHECK_INTERVAL=0.125 34 | 35 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad 36 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/reproduce_bert_base_dice 37 | 38 | echo "INFO -> OUTPUT_DIR is ${OUTPUT_DIR}" 39 | mkdir -p ${OUTPUT_DIR} 40 | CACHE_DIR=${OUTPUT_DIR}/cache 41 | mkdir -p ${CACHE_DIR} 42 | 43 | python ${REPO_PATH}/tasks/squad/train.py \ 44 | --gpus="1" \ 45 | --precision=${PRECISION} \ 46 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 47 | --dice_smooth ${DICE_SMOOTH} \ 48 | --dice_ohem ${DICE_OHEM} \ 49 | --dice_alpha ${DICE_ALPHA} \ 50 | --dice_square \ 51 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 52 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 53 | --max_query_length ${MAX_QUERY_LEN} \ 54 | --max_seq_length ${MAX_SEQ_LEN} \ 55 | --doc_stride ${DOC_STRIDE} \ 56 | --optimizer ${OPTIMIZER} \ 57 | --loss_type ${LOSS_TYPE} \ 58 | --data_dir ${DATA_DIR} \ 59 | --bert_hidden_dropout ${BERT_DROPOUT} \ 60 | --bert_config_dir ${BERT_DIR} \ 61 | --lr ${LR} \ 62 | --lr_scheduler ${LR_SCHEDULE} \ 63 | --accumulate_grad_batches ${ACC_GRAD} \ 64 | --default_root_dir ${OUTPUT_DIR} \ 65 | --output_dir ${OUTPUT_DIR} \ 66 | --max_epochs ${MAX_EPOCH} \ 67 | --gradient_clip_val ${GRAD_CLIP} \ 68 | --weight_decay ${WEIGHT_DECAY} \ 69 | --do_lower_case \ 70 | --warmup_proportion ${WARMUP_PROPORTION} 71 | 72 | 73 | -------------------------------------------------------------------------------- /scripts/mrc_squad1/bert_base_focal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 6 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 7 | 8 | DATA_DIR=/userhome/xiaoya/dataset/squad1 9 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-12_H-768_A-12 10 | 11 | LR=3e-5 12 | LR_SCHEDULE=onecycle 13 | OPTIMIZER=adamw 14 | WARMUP_PROPORTION=0 15 | GRAD_CLIP=1.0 16 | MAX_EPOCH=3 17 | ACC_GRAD=6 18 | BERT_DROPOUT=0.1 19 | WEIGHT_DECAY=0.01 20 | BATCH_SIZE=2 21 | 
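# with gradient accumulation, the effective batch size here is BATCH_SIZE x ACC_GRAD = 2 x 6 = 12,
# i.e. the same effective batch as TRAIN_BATCH_SIZE=12 with ACC_GRAD=1 in bert_base_ce.sh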
MAX_QUERY_LEN=64 22 | MAX_SEQ_LEN=384 23 | DOC_STRIDE=128 24 | 25 | LOSS_TYPE=focal 26 | FOCAL_GAMMA=2 27 | 28 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad 29 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/reproduce_bert_base_focal 30 | 31 | mkdir -p ${OUTPUT_DIR} 32 | CACHE_DIR=${OUTPUT_DIR}/cache 33 | mkdir -p ${CACHE_DIR} 34 | 35 | 36 | PRECISION=16 37 | PROGRESS_BAR=1 38 | VAL_CHECK_INTERVAL=0.5 39 | 40 | python ${REPO_PATH}/tasks/squad/train.py \ 41 | --gpus="1" \ 42 | --train_batch_size ${BATCH_SIZE} \ 43 | --precision=${PRECISION} \ 44 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 45 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 46 | --max_query_length ${MAX_QUERY_LEN} \ 47 | --max_seq_length ${MAX_SEQ_LEN} \ 48 | --doc_stride ${DOC_STRIDE} \ 49 | --optimizer ${OPTIMIZER} \ 50 | --loss_type ${LOSS_TYPE} \ 51 | --data_dir ${DATA_DIR} \ 52 | --bert_hidden_dropout ${BERT_DROPOUT} \ 53 | --bert_config_dir ${BERT_DIR} \ 54 | --lr ${LR} \ 55 | --lr_scheduler ${LR_SCHEDULE} \ 56 | --warmup_proportion ${WARMUP_PROPORTION} \ 57 | --accumulate_grad_batches ${ACC_GRAD} \ 58 | --default_root_dir ${OUTPUT_DIR} \ 59 | --output_dir ${OUTPUT_DIR} \ 60 | --max_epochs ${MAX_EPOCH} \ 61 | --gradient_clip_val ${GRAD_CLIP} \ 62 | --do_lower_case \ 63 | --weight_decay ${WEIGHT_DECAY} \ 64 | --focal_gamma ${FOCAL_GAMMA} \ 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /scripts/mrc_squad1/bert_large_ce.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # description: 5 | # predictions_4_7387.json 6 | # EM -> 83.98; F1 -> 90.89 7 | 8 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 9 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 10 | 11 | MODEL_SCALE=large 12 | DATA_DIR=/userhome/xiaoya/dataset/squad1 13 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-24_H-1024_A-16 14 | 15 | LOSS_TYPE=ce 16 | LR=3e-5 17 | LR_SCHEDULE=linear 18 | OPTIMIZER=adamw 19 | WARMUP_PROPORTION=0.002 20 | GRAD_CLIP=1.0 21 | ACC_GRAD=6 22 | MAX_EPOCH=2 23 | 24 | BERT_DROPOUT=0.1 25 | WEIGHT_DECAY=0.002 26 | TRAIN_BATCH_SIZE=4 27 | MAX_QUERY_LEN=64 28 | MAX_SEQ_LEN=384 29 | DOC_STRIDE=128 30 | 31 | PRECISION=16 32 | PROGRESS_BAR=1 33 | VAL_CHECK_INTERVAL=0.125 34 | DISTRIBUTE=ddp 35 | 36 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad 37 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/reproduce_bert_large_ce 38 | 39 | mkdir -p ${OUTPUT_DIR} 40 | CACHE_DIR=${OUTPUT_DIR}/cache 41 | mkdir -p ${CACHE_DIR} 42 | 43 | python ${REPO_PATH}/tasks/squad/train.py \ 44 | --gpus="1" \ 45 | --precision=${PRECISION} \ 46 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 47 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 48 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 49 | --max_query_length ${MAX_QUERY_LEN} \ 50 | --max_seq_length ${MAX_SEQ_LEN} \ 51 | --doc_stride ${DOC_STRIDE} \ 52 | --optimizer ${OPTIMIZER} \ 53 | --loss_type ${LOSS_TYPE} \ 54 | --data_dir ${DATA_DIR} \ 55 | --bert_hidden_dropout ${BERT_DROPOUT} \ 56 | --bert_config_dir ${BERT_DIR} \ 57 | --lr ${LR} \ 58 | --lr_scheduler ${LR_SCHEDULE} \ 59 | --accumulate_grad_batches ${ACC_GRAD} \ 60 | --default_root_dir ${OUTPUT_DIR} \ 61 | --output_dir ${OUTPUT_DIR} \ 62 | --max_epochs ${MAX_EPOCH} \ 63 | --gradient_clip_val ${GRAD_CLIP} \ 64 | --weight_decay ${WEIGHT_DECAY} \ 65 | --do_lower_case \ 66 | --warmup_proportion ${WARMUP_PROPORTION} 67 | 68 | -------------------------------------------------------------------------------- 
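Illustrative sketch (not a file from this repository): the dice scripts in this directory pass --dice_smooth, --dice_ohem, --dice_alpha and --dice_square through to the training entry point, which builds the self-adjusting dice loss implemented in loss/dice_loss.py. One common formulation of such a loss for binary targets is sketched below; the function name and defaults are assumptions made for illustration, and the hard-example-mining step controlled by --dice_ohem is omitted.

import torch

def self_adjusting_dice_loss(logits: torch.Tensor, targets: torch.Tensor,
                             alpha: float = 0.01, smooth: float = 1.0,
                             square_denominator: bool = True) -> torch.Tensor:
    # logits: [N] raw scores, targets: [N] binary labels in {0, 1}.
    probs = torch.sigmoid(logits)
    targets = targets.float()
    # the (1 - p)^alpha factor down-weights easy, well-classified examples (DICE_ALPHA in the scripts)
    probs = ((1.0 - probs) ** alpha) * probs
    numerator = 2.0 * (probs * targets).sum() + smooth            # smooth plays the role of DICE_SMOOTH
    if square_denominator:                                        # mirrors the --dice_square flag
        denominator = (probs * probs).sum() + (targets * targets).sum() + smooth
    else:
        denominator = probs.sum() + targets.sum() + smooth
    return 1.0 - numerator / denominator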
/scripts/mrc_squad1/bert_large_dice.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # description: 5 | # predictions_1_7387.json 6 | # EM -> 83.98; F1 -> 90.89 7 | 8 | 9 | FILE_NAME=dice_large 10 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 11 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 12 | echo "DEBUG INFO -> repo dir is : ${REPO_PATH} " 13 | 14 | MODEL_SCALE=large 15 | DATA_DIR=/userhome/xiaoya/dataset/squad1 16 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-24_H-1024_A-16 17 | 18 | LOSS_TYPE=dice 19 | LR=2e-5 20 | LR_SCHEDULE=polydecay 21 | OPTIMIZER=torch.adam 22 | WARMUP_PROPORTION=0.06 23 | 24 | GRAD_CLIP=1.0 25 | ACC_GRAD=20 26 | MAX_EPOCH=5 27 | 28 | BERT_DROPOUT=0.1 29 | WEIGHT_DECAY=0.002 30 | 31 | # data 32 | TRAIN_BATCH_SIZE=2 33 | MAX_QUERY_LEN=64 34 | MAX_SEQ_LEN=384 35 | DOC_STRIDE=128 36 | 37 | DICE_SMOOTH=1 38 | DICE_OHEM=0.1 39 | DICE_ALPHA=0.01 40 | 41 | PRECISION=16 # default AMP configuration in pytorch-lightning==0.9.0 is 'O2' 42 | PROGRESS_BAR=1 43 | VAL_CHECK_INTERVAL=0.125 44 | DISTRIBUTE=ddp 45 | 46 | 47 | # define OUTPUT directory 48 | OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad 49 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/${FILE_NAME}_${MODEL_SCALE}_${MAX_EPOCH}_${GRAD_CLIP}_${ACC_GRAD}_${WARMUP_PROPORTION}_${OPTIMIZER}_${LR}_${BERT_DROPOUT}_${TRAIN_BATCH_SIZE}_${MAX_QUERY_LEN}_${MAX_SEQ_LEN}_${DOC_STRIDE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 50 | 51 | 52 | echo "INFO -> OUTPUT_DIR is ${OUTPUT_DIR}" 53 | mkdir -p ${OUTPUT_DIR} 54 | CACHE_DIR=${OUTPUT_DIR}/cache 55 | mkdir -p ${CACHE_DIR} 56 | 57 | 58 | python ${REPO_PATH}/tasks/squad/train.py \ 59 | --gpus="1" \ 60 | --precision=${PRECISION} \ 61 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 62 | --dice_smooth ${DICE_SMOOTH} \ 63 | --dice_ohem ${DICE_OHEM} \ 64 | --dice_alpha ${DICE_ALPHA} \ 65 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 66 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 67 | --max_query_length ${MAX_QUERY_LEN} \ 68 | --max_seq_length ${MAX_SEQ_LEN} \ 69 | --doc_stride ${DOC_STRIDE} \ 70 | --optimizer ${OPTIMIZER} \ 71 | --loss_type ${LOSS_TYPE} \ 72 | --data_dir ${DATA_DIR} \ 73 | --bert_hidden_dropout ${BERT_DROPOUT} \ 74 | --bert_config_dir ${BERT_DIR} \ 75 | --lr ${LR} \ 76 | --lr_scheduler ${LR_SCHEDULE} \ 77 | --accumulate_grad_batches ${ACC_GRAD} \ 78 | --default_root_dir ${OUTPUT_DIR} \ 79 | --output_dir ${OUTPUT_DIR} \ 80 | --max_epochs ${MAX_EPOCH} \ 81 | --gradient_clip_val ${GRAD_CLIP} \ 82 | --weight_decay ${WEIGHT_DECAY} \ 83 | --do_lower_case \ 84 | --dice_square \ 85 | --warmup_proportion ${WARMUP_PROPORTION} 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /scripts/mrc_squad1/bert_large_focal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | FILE_NAME=focal_large 6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 7 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 8 | 9 | DATA_DIR=/userhome/xiaoya/dataset/squad1 10 | BERT_DIR=/userhome/xiaoya/bert/uncased_L-12_H-768_A-12 11 | 12 | LR=3e-5 13 | LR_SCHEDULE=onecycle 14 | OPTIMIZER=adamw 15 | WARMUP_PROPORTION=0.002 16 | GRAD_CLIP=1.0 17 | MAX_EPOCH=2 18 | ACC_GRAD=6 19 | 20 | BERT_DROPOUT=0.1 21 | WEIGHT_DECAY=0.002 22 | 23 | TRAIN_BATCH_SIZE=4 24 | MAX_QUERY_LEN=64 25 | MAX_SEQ_LEN=384 26 | DOC_STRIDE=128 27 | 28 | LOSS_TYPE=focal 29 | FOCAL_GAMMA=2 30 | 31 | 
OUTPUT_DIR_BASE=/userhome/xiaoya/outputs/dice_loss/squad 32 | OUTPUT_DIR=${OUTPUT_DIR_BASE}/${FILE_NAME}_${MAX_EPOCH}_${GRAD_CLIP}_${ACC_GRAD}_${WARMUP_PROPORTION}_${OPTIMIZER}_${LR}_${BERT_DROPOUT}_${WEIGHT_DECAY}_${TRAIN_BATCH_SIZE}_${MAX_QUERY_LEN}_${MAX_SEQ_LEN}_${DOC_STRIDE}_${FOCAL_GAMMA} 33 | 34 | echo "INFO -> OUTPUT_DIR is ${OUTPUT_DIR}" 35 | mkdir -p ${OUTPUT_DIR} 36 | CACHE_DIR=${OUTPUT_DIR}/cache 37 | mkdir -p ${CACHE_DIR} 38 | 39 | PRECISION=16 40 | PROGRESS_BAR=1 41 | VAL_CHECK_INTERVAL=0.125 42 | 43 | python ${REPO_PATH}/tasks/squad/train.py \ 44 | --gpus="0,1,2" \ 45 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 46 | --precision=${PRECISION} \ 47 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 48 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 49 | --max_query_length ${MAX_QUERY_LEN} \ 50 | --max_seq_length ${MAX_SEQ_LEN} \ 51 | --doc_stride ${DOC_STRIDE} \ 52 | --optimizer ${OPTIMIZER} \ 53 | --loss_type ${LOSS_TYPE} \ 54 | --data_dir ${DATA_DIR} \ 55 | --bert_hidden_dropout ${BERT_DROPOUT} \ 56 | --bert_config_dir ${BERT_DIR} \ 57 | --lr ${LR} \ 58 | --lr_scheduler ${LR_SCHEDULE} \ 59 | --warmup_proportion ${WARMUP_PROPORTION} \ 60 | --accumulate_grad_batches ${ACC_GRAD} \ 61 | --default_root_dir ${OUTPUT_DIR} \ 62 | --output_dir ${OUTPUT_DIR} \ 63 | --max_epochs ${MAX_EPOCH} \ 64 | --gradient_clip_val ${GRAD_CLIP} \ 65 | --do_lower_case \ 66 | --weight_decay ${WEIGHT_DECAY} \ 67 | --focal_gamma ${FOCAL_GAMMA} 68 | 69 | -------------------------------------------------------------------------------- /scripts/mrc_squad1/eval_pred_file.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 6 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 7 | 8 | DATA_DIR=$1 9 | OUTPUT_DIR=$2 10 | DATA_FILE=${DATA_DIR}/dev-v1.1.json 11 | 12 | python3 ${REPO_PATH}/tasks/squad/evaluate_predictions.py ${DATA_FILE} ${OUTPUT_DIR} 13 | 14 | -------------------------------------------------------------------------------- /scripts/mrc_squad1/eval_saved_ckpt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP 6 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 7 | 8 | OUTPUT_DIR=/data/xiaoya/outputs/dice_loss/squad/gpu4_ce_base_2_1.0_1__adamw_3e-5_0.1_12_64_384_128 9 | MODEL_CKPT=${OUTPUT_DIR}/epoch=0_v2.ckpt 10 | HPARAMS_PATH=${OUTPUT_DIR}/lightning_logs/version_0/hparams.yaml 11 | 12 | CUDA_VISIBLE_DEVICES=3 python ${REPO_PATH}/tasks/squad/evaluate_models.py \ 13 | --gpus="1" \ 14 | --path_to_model_checkpoint ${MODEL_CKPT} \ 15 | --path_to_model_hparams_file ${HPARAMS_PATH} -------------------------------------------------------------------------------- /scripts/ner_enconll03/bert_dice.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # author: xiaoya li 5 | # first create: 2021.02.02 6 | # file: train.sh 7 | 8 | 9 | TIME=2021.09.06 10 | FILE_NAME=enconll_dice 11 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 12 | MODEL_SCALE=large 13 | DATA_DIR=/userhome/xiaoya/dataset/en_conll03 14 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large 15 | 16 | TRAIN_BATCH_SIZE=36 17 | EVAL_BATCH_SIZE=1 18 | MAX_LENGTH=256 19 | 20 | OPTIMIZER=torch.adam 21 | LR_SCHEDULE=polydecay 22 | LR=3e-5 23 | 24 | BERT_DROPOUT=0.2 25 | ACC_GRAD=8 26 | MAX_EPOCH=10 27 | 
GRAD_CLIP=1.0 28 | WEIGHT_DECAY=0.01 29 | WARMUP_PROPORTION=0.01 30 | 31 | LOSS_TYPE=dice 32 | W_START=1 33 | W_END=1 34 | W_SPAN=0.3 35 | DICE_SMOOTH=1 36 | DICE_OHEM=0.0 37 | DICE_ALPHA=0.01 38 | FOCAL_GAMMA=2 39 | 40 | PRECISION=16 41 | PROGRESS_BAR=1 42 | VAL_CHECK_INTERVAL=0.25 43 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 44 | 45 | if [[ ${LOSS_TYPE} == "bce" ]]; then 46 | LOSS_SIGN=${LOSS_TYPE} 47 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 48 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 49 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 50 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 51 | fi 52 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 53 | 54 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner/${TIME} 55 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN} 56 | 57 | mkdir -p ${OUTPUT_DIR} 58 | 59 | CUDA_VISIBLE_DEVICES=0 python ${REPO_PATH}/tasks/mrc_ner/train.py \ 60 | --gpus="1" \ 61 | --precision=${PRECISION} \ 62 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 63 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 64 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 65 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 66 | --max_length ${MAX_LENGTH} \ 67 | --optimizer ${OPTIMIZER} \ 68 | --data_dir ${DATA_DIR} \ 69 | --bert_hidden_dropout ${BERT_DROPOUT} \ 70 | --bert_config_dir ${BERT_DIR} \ 71 | --lr ${LR} \ 72 | --lr_scheduler ${LR_SCHEDULE} \ 73 | --accumulate_grad_batches ${ACC_GRAD} \ 74 | --default_root_dir ${OUTPUT_DIR} \ 75 | --output_dir ${OUTPUT_DIR} \ 76 | --max_epochs ${MAX_EPOCH} \ 77 | --gradient_clip_val ${GRAD_CLIP} \ 78 | --weight_decay ${WEIGHT_DECAY} \ 79 | --loss_type ${LOSS_TYPE} \ 80 | --weight_start ${W_START} \ 81 | --weight_end ${W_END} \ 82 | --weight_span ${W_SPAN} \ 83 | --dice_smooth ${DICE_SMOOTH} \ 84 | --dice_ohem ${DICE_OHEM} \ 85 | --dice_alpha ${DICE_ALPHA} \ 86 | --dice_square \ 87 | --warmup_proportion ${WARMUP_PROPORTION} \ 88 | --span_loss_candidates gold_pred_random \ 89 | --construct_entity_span start_and_end \ 90 | --flat_ner \ 91 | --pred_answerable "train_infer" \ 92 | --answerable_task_ratio 0.4 \ 93 | --activate_func relu \ 94 | --data_sign en_conll03 -------------------------------------------------------------------------------- /scripts/ner_enconll03/bert_focal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # author: xiaoya li 5 | # first create: 2021.02.02 6 | # file: train.sh 7 | 8 | TIME=2021.07.23 9 | FILE_NAME=enconll03_focal 10 | REPO_PATH=/userhome/xiaoya/mrc-with-dice-loss 11 | MODEL_SCALE=large 12 | DATA_DIR=/userhome/xiaoya/dataset/en_conll03 13 | BERT_DIR=/userhome/xiaoya/bert/bert-large-cased 14 | 15 | TRAIN_BATCH_SIZE=18 16 | EVAL_BATCH_SIZE=1 17 | MAX_LENGTH=256 18 | 19 | OPTIMIZER=torch.adam 20 | LR_SCHEDULE=polydecay 21 | LR=2e-5 22 | 23 | BERT_DROPOUT=0.2 24 | ACC_GRAD=4 25 | MAX_EPOCH=10 26 | GRAD_CLIP=1.0 27 | WEIGHT_DECAY=0.002 28 | WARMUP_PROPORTION=0.06 29 | 30 | LOSS_TYPE=focal 31 | W_START=1 32 | W_END=1 33 | W_SPAN=0.3 34 | DICE_SMOOTH=1 35 | DICE_OHEM=0.0 36 | DICE_ALPHA=0.01 37 | FOCAL_GAMMA=3 38 | 39 | PRECISION=16 40 | PROGRESS_BAR=1 41 | VAL_CHECK_INTERVAL=0.25 42 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 43 | 44 | if [[ ${LOSS_TYPE} == "bce" ]]; then 45 | LOSS_SIGN=${LOSS_TYPE} 46 | elif [[ 
${LOSS_TYPE} == "focal" ]]; then 47 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 48 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 49 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 50 | fi 51 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 52 | 53 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner/${TIME} 54 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN} 55 | 56 | mkdir -p ${OUTPUT_DIR} 57 | 58 | CUDA_VISIBLE_DEVICES=0 python ${REPO_PATH}/tasks/mrc_ner/train.py \ 59 | --gpus="1" \ 60 | --precision=${PRECISION} \ 61 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 62 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 63 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 64 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 65 | --max_length ${MAX_LENGTH} \ 66 | --optimizer ${OPTIMIZER} \ 67 | --data_dir ${DATA_DIR} \ 68 | --bert_hidden_dropout ${BERT_DROPOUT} \ 69 | --bert_config_dir ${BERT_DIR} \ 70 | --lr ${LR} \ 71 | --lr_scheduler ${LR_SCHEDULE} \ 72 | --accumulate_grad_batches ${ACC_GRAD} \ 73 | --default_root_dir ${OUTPUT_DIR} \ 74 | --output_dir ${OUTPUT_DIR} \ 75 | --max_epochs ${MAX_EPOCH} \ 76 | --gradient_clip_val ${GRAD_CLIP} \ 77 | --weight_decay ${WEIGHT_DECAY} \ 78 | --loss_type ${LOSS_TYPE} \ 79 | --weight_start ${W_START} \ 80 | --weight_end ${W_END} \ 81 | --weight_span ${W_SPAN} \ 82 | --dice_smooth ${DICE_SMOOTH} \ 83 | --dice_ohem ${DICE_OHEM} \ 84 | --dice_alpha ${DICE_ALPHA} \ 85 | --dice_square \ 86 | --warmup_proportion ${WARMUP_PROPORTION} \ 87 | --span_loss_candidates gold_pred_random \ 88 | --construct_entity_span start_and_end \ 89 | --flat_ner \ 90 | --pred_answerable "train_infer" \ 91 | --answerable_task_ratio 0.2 \ 92 | --activate_func relu \ 93 | --data_sign en_conll03 -------------------------------------------------------------------------------- /scripts/ner_enontonotes5/bert_dice.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | FILE_NAME=brain_enonto_dice 6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 7 | MODEL_SCALE=base 8 | DATA_DIR=/userhome/xiaoya/dataset/new_mrc_ner/new_en_onto5 9 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large 10 | 11 | TRAIN_BATCH_SIZE=12 12 | EVAL_BATCH_SIZE=1 13 | MAX_LENGTH=300 14 | 15 | OPTIMIZER=torch.adam 16 | LR_SCHEDULE=linear 17 | LR=2e-5 18 | 19 | BERT_DROPOUT=0.1 20 | ACC_GRAD=6 21 | MAX_EPOCH=5 22 | GRAD_CLIP=1.0 23 | WEIGHT_DECAY=0.002 24 | WARMUP_PROPORTION=0.1 25 | 26 | LOSS_TYPE=dice 27 | W_START=1 28 | W_END=1 29 | W_SPAN=0.3 30 | DICE_SMOOTH=1 31 | DICE_OHEM=0.3 32 | DICE_ALPHA=0.01 33 | FOCAL_GAMMA=2 34 | 35 | PRECISION=16 36 | PROGRESS_BAR=1 37 | VAL_CHECK_INTERVAL=0.25 38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 39 | 40 | if [[ ${LOSS_TYPE} == "bce" ]]; then 41 | LOSS_SIGN=${LOSS_TYPE} 42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 46 | fi 47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 48 | 49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner 50 | 
OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN} 51 | 52 | mkdir -p ${OUTPUT_DIR} 53 | 54 | CUDA_VISIBLE_DEVICES=1 python ${REPO_PATH}/tasks/mrc_ner/train.py \ 55 | --gpus="1" \ 56 | --precision=${PRECISION} \ 57 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 58 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 60 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 61 | --max_length ${MAX_LENGTH} \ 62 | --optimizer ${OPTIMIZER} \ 63 | --data_dir ${DATA_DIR} \ 64 | --bert_hidden_dropout ${BERT_DROPOUT} \ 65 | --bert_config_dir ${BERT_DIR} \ 66 | --lr ${LR} \ 67 | --lr_scheduler ${LR_SCHEDULE} \ 68 | --accumulate_grad_batches ${ACC_GRAD} \ 69 | --default_root_dir ${OUTPUT_DIR} \ 70 | --output_dir ${OUTPUT_DIR} \ 71 | --max_epochs ${MAX_EPOCH} \ 72 | --gradient_clip_val ${GRAD_CLIP} \ 73 | --weight_decay ${WEIGHT_DECAY} \ 74 | --loss_type ${LOSS_TYPE} \ 75 | --weight_start ${W_START} \ 76 | --weight_end ${W_END} \ 77 | --weight_span ${W_SPAN} \ 78 | --dice_smooth ${DICE_SMOOTH} \ 79 | --dice_ohem ${DICE_OHEM} \ 80 | --dice_alpha ${DICE_ALPHA} \ 81 | --dice_square \ 82 | --warmup_proportion ${WARMUP_PROPORTION} \ 83 | --span_loss_candidates gold_pred_random \ 84 | --construct_entity_span start_and_end \ 85 | --num_labels 1 \ 86 | --flat_ner \ 87 | --pred_answerable "train_infer" \ 88 | --answerable_task_ratio 0.2 \ 89 | --activate_func relu \ 90 | --data_sign en_onto -------------------------------------------------------------------------------- /scripts/ner_enontonotes5/bert_focal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | FILE_NAME=brain_enonto_focal 6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 7 | MODEL_SCALE=base 8 | DATA_DIR=/userhome/xiaoya/dataset/mrc_ner/new_enonto5 9 | BERT_DIR=/userhome/xiaoya/bert/english_bert_large 10 | 11 | TRAIN_BATCH_SIZE=4 12 | EVAL_BATCH_SIZE=1 13 | MAX_LENGTH=300 14 | 15 | OPTIMIZER=torch.adam 16 | LR_SCHEDULE=linear 17 | LR=1e-5 18 | 19 | BERT_DROPOUT=0.1 20 | ACC_GRAD=12 21 | MAX_EPOCH=5 22 | GRAD_CLIP=1.0 23 | WEIGHT_DECAY=0.002 24 | WARMUP_PROPORTION=0.1 25 | 26 | LOSS_TYPE=focal 27 | W_START=1 28 | W_END=1 29 | W_SPAN=0.3 30 | DICE_SMOOTH=1 31 | DICE_OHEM=0.4 32 | DICE_ALPHA=0.01 33 | FOCAL_GAMMA=4 34 | 35 | PRECISION=16 36 | PROGRESS_BAR=1 37 | VAL_CHECK_INTERVAL=0.25 38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 39 | 40 | if [[ ${LOSS_TYPE} == "bce" ]]; then 41 | LOSS_SIGN=${LOSS_TYPE} 42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 46 | fi 47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 48 | 49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner 50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN} 51 | 52 | mkdir -p ${OUTPUT_DIR} 53 | 54 | CUDA_VISIBLE_DEVICES=1 python ${REPO_PATH}/tasks/mrc_ner/train.py \ 55 | --gpus="1" \ 56 | --precision=${PRECISION} \ 57 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 58 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 59 | 
--progress_bar_refresh_rate ${PROGRESS_BAR} \ 60 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 61 | --max_length ${MAX_LENGTH} \ 62 | --optimizer ${OPTIMIZER} \ 63 | --data_dir ${DATA_DIR} \ 64 | --bert_hidden_dropout ${BERT_DROPOUT} \ 65 | --bert_config_dir ${BERT_DIR} \ 66 | --lr ${LR} \ 67 | --lr_scheduler ${LR_SCHEDULE} \ 68 | --accumulate_grad_batches ${ACC_GRAD} \ 69 | --default_root_dir ${OUTPUT_DIR} \ 70 | --output_dir ${OUTPUT_DIR} \ 71 | --max_epochs ${MAX_EPOCH} \ 72 | --gradient_clip_val ${GRAD_CLIP} \ 73 | --weight_decay ${WEIGHT_DECAY} \ 74 | --loss_type ${LOSS_TYPE} \ 75 | --weight_start ${W_START} \ 76 | --weight_end ${W_END} \ 77 | --weight_span ${W_SPAN} \ 78 | --warmup_proportion ${WARMUP_PROPORTION} \ 79 | --span_loss_candidates gold_pred_random \ 80 | --construct_entity_span start_and_end \ 81 | --num_labels 1 \ 82 | --flat_ner \ 83 | --focal_gamma ${FOCAL_GAMMA} \ 84 | --pred_answerable "train_infer" \ 85 | --answerable_task_ratio 0.2 \ 86 | --activate_func relu \ 87 | --data_sign en_onto -------------------------------------------------------------------------------- /scripts/ner_zhmsra/bert_dice.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | FILE_NAME=reproduce_zhmsra_dice 6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 7 | MODEL_SCALE=base 8 | DATA_DIR=/userhome/xiaoya/dataset/new_mrc_ner/new_zh_msra 9 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert 10 | 11 | TRAIN_BATCH_SIZE=10 12 | EVAL_BATCH_SIZE=1 13 | MAX_LENGTH=275 14 | 15 | OPTIMIZER=torch.adam 16 | LR_SCHEDULE=polydecay 17 | LR=2e-5 18 | 19 | BERT_DROPOUT=0.1 20 | ACC_GRAD=3 21 | MAX_EPOCH=5 22 | GRAD_CLIP=1.0 23 | WEIGHT_DECAY=0.002 24 | WARMUP_PROPORTION=0.02 25 | 26 | LOSS_TYPE=dice 27 | W_START=1 28 | W_END=1 29 | W_SPAN=0.2 30 | DICE_SMOOTH=1 31 | DICE_OHEM=0.3 32 | DICE_ALPHA=0.01 33 | FOCAL_GAMMA=2 34 | 35 | PRECISION=16 36 | PROGRESS_BAR=1 37 | VAL_CHECK_INTERVAL=0.25 38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 39 | 40 | if [[ ${LOSS_TYPE} == "bce" ]]; then 41 | LOSS_SIGN=${LOSS_TYPE} 42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 46 | fi 47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 48 | 49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner 50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN} 51 | 52 | mkdir -p ${OUTPUT_DIR} 53 | 54 | CUDA_VISIBLE_DEVICES=1 python ${REPO_PATH}/tasks/mrc_ner/train.py \ 55 | --gpus="1" \ 56 | --precision=${PRECISION} \ 57 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 58 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 60 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 61 | --max_length ${MAX_LENGTH} \ 62 | --optimizer ${OPTIMIZER} \ 63 | --data_dir ${DATA_DIR} \ 64 | --bert_hidden_dropout ${BERT_DROPOUT} \ 65 | --bert_config_dir ${BERT_DIR} \ 66 | --lr ${LR} \ 67 | --lr_scheduler ${LR_SCHEDULE} \ 68 | --accumulate_grad_batches ${ACC_GRAD} \ 69 | --default_root_dir ${OUTPUT_DIR} \ 70 | --output_dir ${OUTPUT_DIR} \ 71 | --max_epochs ${MAX_EPOCH} \ 72 | --gradient_clip_val ${GRAD_CLIP} \ 73 | --weight_decay ${WEIGHT_DECAY} \ 74 | --loss_type ${LOSS_TYPE} \ 75 | 
--weight_start ${W_START} \ 76 | --weight_end ${W_END} \ 77 | --weight_span ${W_SPAN} \ 78 | --dice_smooth ${DICE_SMOOTH} \ 79 | --dice_ohem ${DICE_OHEM} \ 80 | --dice_alpha ${DICE_ALPHA} \ 81 | --dice_square \ 82 | --warmup_proportion ${WARMUP_PROPORTION} \ 83 | --span_loss_candidates gold_pred_random \ 84 | --construct_entity_span start_and_end \ 85 | --num_labels 1 \ 86 | --flat_ner \ 87 | --is_chinese \ 88 | --pred_answerable "train_infer" \ 89 | --answerable_task_ratio 0.3 \ 90 | --activate_func relu \ 91 | --data_sign zh_msra -------------------------------------------------------------------------------- /scripts/ner_zhmsra/bert_focal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | FILE_NAME=brain_zhmsra_focal 6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 7 | MODEL_SCALE=base 8 | DATA_DIR=/userhome/xiaoya/dataset/mrc_ner/new_msra 9 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert 10 | 11 | TRAIN_BATCH_SIZE=9 12 | EVAL_BATCH_SIZE=1 13 | MAX_LENGTH=275 14 | 15 | OPTIMIZER=torch.adam 16 | LR_SCHEDULE=linear 17 | LR=1e-5 18 | 19 | BERT_DROPOUT=0.1 20 | ACC_GRAD=3 21 | MAX_EPOCH=5 22 | GRAD_CLIP=1.0 23 | WEIGHT_DECAY=0.002 24 | WARMUP_PROPORTION=0.08 25 | 26 | LOSS_TYPE=focal 27 | W_START=1 28 | W_END=1 29 | W_SPAN=0.3 30 | DICE_SMOOTH=1 31 | DICE_OHEM=0.4 32 | DICE_ALPHA=0.01 33 | FOCAL_GAMMA=3 34 | 35 | PRECISION=16 36 | PROGRESS_BAR=1 37 | VAL_CHECK_INTERVAL=0.25 38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 39 | 40 | if [[ ${LOSS_TYPE} == "bce" ]]; then 41 | LOSS_SIGN=${LOSS_TYPE} 42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 46 | fi 47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 48 | 49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner 50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN} 51 | 52 | mkdir -p ${OUTPUT_DIR} 53 | 54 | CUDA_VISIBLE_DEVICES=1 python ${REPO_PATH}/tasks/mrc_ner/train.py \ 55 | --gpus="1" \ 56 | --precision=${PRECISION} \ 57 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 58 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 60 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 61 | --max_length ${MAX_LENGTH} \ 62 | --optimizer ${OPTIMIZER} \ 63 | --data_dir ${DATA_DIR} \ 64 | --bert_hidden_dropout ${BERT_DROPOUT} \ 65 | --bert_config_dir ${BERT_DIR} \ 66 | --lr ${LR} \ 67 | --lr_scheduler ${LR_SCHEDULE} \ 68 | --accumulate_grad_batches ${ACC_GRAD} \ 69 | --default_root_dir ${OUTPUT_DIR} \ 70 | --output_dir ${OUTPUT_DIR} \ 71 | --max_epochs ${MAX_EPOCH} \ 72 | --gradient_clip_val ${GRAD_CLIP} \ 73 | --weight_decay ${WEIGHT_DECAY} \ 74 | --loss_type ${LOSS_TYPE} \ 75 | --weight_start ${W_START} \ 76 | --weight_end ${W_END} \ 77 | --weight_span ${W_SPAN} \ 78 | --focal_gamma ${FOCAL_GAMMA} \ 79 | --warmup_proportion ${WARMUP_PROPORTION} \ 80 | --span_loss_candidates gold_pred_random \ 81 | --construct_entity_span start_end_match \ 82 | --num_labels 1 \ 83 | --flat_ner \ 84 | --is_chinese \ 85 | --pred_answerable "train_infer" \ 86 | --answerable_task_ratio 0.2 \ 87 | --activate_func relu \ 88 | --data_sign zh_msra 
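Illustrative sketch (not a file from this repository): the *_focal.sh scripts set LOSS_TYPE=focal and pass --focal_gamma, which presumably feeds the focal loss in loss/focal_loss.py. A minimal, hedged sketch of a binary focal loss with that gamma parameter is shown below; it follows the generic formulation and is not guaranteed to match the repository's exact implementation.

import torch
import torch.nn.functional as F

def binary_focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    # logits: [N] raw scores, targets: [N] binary labels in {0, 1}; gamma plays the role of FOCAL_GAMMA.
    targets = targets.float()
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1.0 - p) * (1.0 - targets)   # probability assigned to the true class
    return ((1.0 - p_t) ** gamma * ce).mean()         # (1 - p_t)^gamma down-weights easy examples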
-------------------------------------------------------------------------------- /scripts/ner_zhonto4/bert_dice.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | FILE_NAME=reproduce_zhonto_dice 6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 7 | MODEL_SCALE=base 8 | DATA_DIR=/userhome/xiaoya/dataset/new_mrc_ner/new_zh_onto4 9 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert 10 | 11 | TRAIN_BATCH_SIZE=8 12 | EVAL_BATCH_SIZE=1 13 | MAX_LENGTH=300 14 | 15 | OPTIMIZER=torch.adam 16 | LR_SCHEDULE=polydecay 17 | LR=2e-5 18 | 19 | BERT_DROPOUT=0.1 20 | ACC_GRAD=2 21 | MAX_EPOCH=5 22 | GRAD_CLIP=1.0 23 | WEIGHT_DECAY=0.002 24 | WARMUP_PROPORTION=0.1 25 | 26 | LOSS_TYPE=dice 27 | W_START=1 28 | W_END=1 29 | W_SPAN=0.3 30 | DICE_SMOOTH=1 31 | DICE_OHEM=0.3 32 | DICE_ALPHA=0.01 33 | FOCAL_GAMMA=2 34 | 35 | PRECISION=16 36 | PROGRESS_BAR=1 37 | VAL_CHECK_INTERVAL=0.25 38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 39 | 40 | if [[ ${LOSS_TYPE} == "bce" ]]; then 41 | LOSS_SIGN=${LOSS_TYPE} 42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 46 | fi 47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 48 | 49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner 50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN} 51 | 52 | mkdir -p ${OUTPUT_DIR} 53 | 54 | CUDA_VISIBLE_DEVICES=0 python ${REPO_PATH}/tasks/mrc_ner/train.py \ 55 | --gpus="1" \ 56 | --precision=${PRECISION} \ 57 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 58 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 60 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 61 | --max_length ${MAX_LENGTH} \ 62 | --optimizer ${OPTIMIZER} \ 63 | --data_dir ${DATA_DIR} \ 64 | --bert_hidden_dropout ${BERT_DROPOUT} \ 65 | --bert_config_dir ${BERT_DIR} \ 66 | --lr ${LR} \ 67 | --lr_scheduler ${LR_SCHEDULE} \ 68 | --accumulate_grad_batches ${ACC_GRAD} \ 69 | --default_root_dir ${OUTPUT_DIR} \ 70 | --output_dir ${OUTPUT_DIR} \ 71 | --max_epochs ${MAX_EPOCH} \ 72 | --gradient_clip_val ${GRAD_CLIP} \ 73 | --weight_decay ${WEIGHT_DECAY} \ 74 | --loss_type ${LOSS_TYPE} \ 75 | --weight_start ${W_START} \ 76 | --weight_end ${W_END} \ 77 | --weight_span ${W_SPAN} \ 78 | --dice_smooth ${DICE_SMOOTH} \ 79 | --dice_ohem ${DICE_OHEM} \ 80 | --dice_alpha ${DICE_ALPHA} \ 81 | --dice_square \ 82 | --warmup_proportion ${WARMUP_PROPORTION} \ 83 | --span_loss_candidates gold_pred_random \ 84 | --construct_entity_span start_and_end \ 85 | --num_labels 1 \ 86 | --flat_ner \ 87 | --is_chinese \ 88 | --pred_answerable "train_infer" \ 89 | --answerable_task_ratio 0.2 \ 90 | --activate_func relu -------------------------------------------------------------------------------- /scripts/ner_zhonto4/bert_focal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | FILE_NAME=brain_zhonto_focal 6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 7 | MODEL_SCALE=base 8 | DATA_DIR=/userhome/xiaoya/dataset/mrc_ner/new_onto4 9 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert 10 | 11 | TRAIN_BATCH_SIZE=8 12 | EVAL_BATCH_SIZE=1 
13 | MAX_LENGTH=300 14 | 15 | OPTIMIZER=torch.adam 16 | LR_SCHEDULE=polydecay 17 | LR=2e-5 18 | 19 | BERT_DROPOUT=0.1 20 | ACC_GRAD=2 21 | MAX_EPOCH=5 22 | GRAD_CLIP=1.0 23 | WEIGHT_DECAY=0.002 24 | WARMUP_PROPORTION=0.1 25 | 26 | LOSS_TYPE=focal 27 | W_START=1 28 | W_END=1 29 | W_SPAN=0.3 30 | DICE_SMOOTH=1 31 | DICE_OHEM=0.4 32 | DICE_ALPHA=0.01 33 | FOCAL_GAMMA=4 34 | 35 | PRECISION=16 36 | PROGRESS_BAR=1 37 | VAL_CHECK_INTERVAL=0.25 38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 39 | 40 | if [[ ${LOSS_TYPE} == "bce" ]]; then 41 | LOSS_SIGN=${LOSS_TYPE} 42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 46 | fi 47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 48 | 49 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/mrc_ner 50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN} 51 | 52 | mkdir -p ${OUTPUT_DIR} 53 | 54 | CUDA_VISIBLE_DEVICES=1 python ${REPO_PATH}/tasks/mrc_ner/train.py \ 55 | --gpus="1" \ 56 | --precision=${PRECISION} \ 57 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 58 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 60 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 61 | --max_length ${MAX_LENGTH} \ 62 | --optimizer ${OPTIMIZER} \ 63 | --data_dir ${DATA_DIR} \ 64 | --bert_hidden_dropout ${BERT_DROPOUT} \ 65 | --bert_config_dir ${BERT_DIR} \ 66 | --lr ${LR} \ 67 | --lr_scheduler ${LR_SCHEDULE} \ 68 | --accumulate_grad_batches ${ACC_GRAD} \ 69 | --default_root_dir ${OUTPUT_DIR} \ 70 | --output_dir ${OUTPUT_DIR} \ 71 | --max_epochs ${MAX_EPOCH} \ 72 | --gradient_clip_val ${GRAD_CLIP} \ 73 | --weight_decay ${WEIGHT_DECAY} \ 74 | --loss_type ${LOSS_TYPE} \ 75 | --weight_start ${W_START} \ 76 | --weight_end ${W_END} \ 77 | --weight_span ${W_SPAN} \ 78 | --focal_gamma ${FOCAL_GAMMA} \ 79 | --warmup_proportion ${WARMUP_PROPORTION} \ 80 | --span_loss_candidates gold_pred_random \ 81 | --construct_entity_span start_and_end \ 82 | --num_labels 1 \ 83 | --flat_ner \ 84 | --is_chinese \ 85 | --pred_answerable "train_infer" \ 86 | --answerable_task_ratio 0.3 \ 87 | --activate_func relu -------------------------------------------------------------------------------- /scripts/prepare_ckpt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # description: 5 | # NOTICE: 6 | # Please make sure tensorflow 7 | # 8 | 9 | # should install tensorflow for loading parameters in Pretrained Models. 
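# example (illustrative; the checkpoint directory is a placeholder):
#   bash scripts/prepare_ckpt.sh /path/to/uncased_L-12_H-768_A-12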
10 | pip install tensorflow 11 | 12 | 13 | BERT_BASE_DIR=$1 14 | 15 | transformers-cli convert --model_type bert \ 16 | --tf_checkpoint ${BERT_BASE_DIR}/bert_model.ckpt \ 17 | --config ${BERT_BASE_DIR}/bert_config.json \ 18 | --pytorch_dump_output ${BERT_BASE_DIR}/pytorch_model.bin 19 | 20 | cp ${BERT_BASE_DIR}/bert_config.json ${BERT_BASE_DIR}/config.json -------------------------------------------------------------------------------- /scripts/prepare_mrpc_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # example: 5 | # bash scripts/prepare_mrpc_data.sh /data/xiaoya/outputs/debug 6 | 7 | SAVE_DATA_DIR=$1 8 | DEV_IDS_FILE=$PWD/tasks/glue/mrpc_dev_ids.tsv 9 | 10 | echo "***** INFO ***** -> data dir is : $SAVE_DATA_DIR" 11 | echo "***** INFO ***** dev ids file is : $DEV_IDS_FILE" 12 | 13 | # download mrpc data files 14 | wget https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt -P ${SAVE_DATA_DIR} 15 | wget https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt -P ${SAVE_DATA_DIR} 16 | 17 | # process mrpc data 18 | python3 $PWD/tasks/glue/process_mrpc.py ${SAVE_DATA_DIR} ${DEV_IDS_FILE} -------------------------------------------------------------------------------- /scripts/textcl_tnews/bert_dice.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | FILE_NAME=tnews_dice 6 | REPO_PATH=/data/xiaoya/workspace/dice_loss_for_NLP 7 | MODEL_SCALE=base 8 | DATA_DIR=/data/xiaoya/datasets/tnews_public_data 9 | BERT_DIR=/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12 10 | 11 | TRAIN_BATCH_SIZE=36 12 | EVAL_BATCH_SIZE=12 13 | MAX_LENGTH=128 14 | 15 | 16 | OPTIMIZER=torch.adam 17 | LR_SCHEDULE=polydecay 18 | LR=3e-5 19 | 20 | BERT_DROPOUT=0.1 21 | ACC_GRAD=1 22 | MAX_EPOCH=5 23 | GRAD_CLIP=1.0 24 | WEIGHT_DECAY=0.002 25 | WARMUP_PROPORTION=0.1 26 | 27 | LOSS_TYPE=dice 28 | # ce, focal, dice 29 | DICE_SMOOTH=1 30 | DICE_OHEM=0 31 | DICE_ALPHA=0.01 32 | FOCAL_GAMMA=2 33 | 34 | PRECISION=32 35 | PROGRESS_BAR=1 36 | VAL_CHECK_INTERVAL=0.25 37 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 38 | 39 | 40 | if [[ ${LOSS_TYPE} == "bce" ]]; then 41 | LOSS_SIGN=${LOSS_TYPE} 42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 46 | fi 47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 48 | 49 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/tnews 50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${LOSS_SIGN} 51 | 52 | mkdir -p ${OUTPUT_DIR} 53 | 54 | CUDA_VISIBLE_DEVICES=0 python ${REPO_PATH}/tasks/tnews/train.py \ 55 | --gpus="1" \ 56 | --precision=${PRECISION} \ 57 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 58 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 60 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 61 | --max_length ${MAX_LENGTH} \ 62 | --optimizer ${OPTIMIZER} \ 63 | --data_dir ${DATA_DIR} \ 64 | --bert_hidden_dropout ${BERT_DROPOUT} \ 65 | --bert_config_dir ${BERT_DIR} \ 66 | --lr ${LR} \ 67 | --lr_scheduler ${LR_SCHEDULE} \ 68 | --accumulate_grad_batches ${ACC_GRAD} \ 69 | --default_root_dir ${OUTPUT_DIR} \ 70 | 
--output_dir ${OUTPUT_DIR} \ 71 | --max_epochs ${MAX_EPOCH} \ 72 | --gradient_clip_val ${GRAD_CLIP} \ 73 | --weight_decay ${WEIGHT_DECAY} \ 74 | --loss_type ${LOSS_TYPE} \ 75 | --dice_smooth ${DICE_SMOOTH} \ 76 | --dice_ohem ${DICE_OHEM} \ 77 | --dice_alpha ${DICE_ALPHA} \ 78 | --dice_square \ 79 | --warmup_proportion ${WARMUP_PROPORTION} -------------------------------------------------------------------------------- /scripts/textcl_tnews/bert_focal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | FILE_NAME=tnews_focal 6 | REPO_PATH=/userhome/xiaoya/dice_loss_for_NLP 7 | MODEL_SCALE=base 8 | DATA_DIR=/userhome/xiaoya/dataset/tnews 9 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert 10 | 11 | TRAIN_BATCH_SIZE=18 12 | EVAL_BATCH_SIZE=12 13 | MAX_LENGTH=128 14 | 15 | OPTIMIZER=torch.adam 16 | LR_SCHEDULE=linear 17 | LR=3e-5 18 | 19 | BERT_DROPOUT=0.2 20 | ACC_GRAD=1 21 | MAX_EPOCH=5 22 | GRAD_CLIP=1.0 23 | WEIGHT_DECAY=0.002 24 | WARMUP_PROPORTION=0.02 25 | 26 | LOSS_TYPE=focal 27 | # ce, focal, dice 28 | DICE_SMOOTH=1 29 | DICE_OHEM=1 30 | DICE_ALPHA=0.01 31 | FOCAL_GAMMA=4 32 | 33 | PRECISION=16 34 | PROGRESS_BAR=1 35 | VAL_CHECK_INTERVAL=0.25 36 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 37 | 38 | if [[ ${LOSS_TYPE} == "ce" ]]; then 39 | LOSS_SIGN=${LOSS_TYPE} 40 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 41 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 42 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 43 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 44 | fi 45 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 46 | 47 | OUTPUT_BASE_DIR=/userhome/xiaoya/outputs/dice_loss/tnews 48 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${LOSS_SIGN} 49 | 50 | mkdir -p ${OUTPUT_DIR} 51 | 52 | CUDA_VISIBLE_DEVICES=0 python ${REPO_PATH}/tasks/tnews/train.py \ 53 | --gpus="1" \ 54 | --precision=${PRECISION} \ 55 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 56 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 58 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 59 | --max_length ${MAX_LENGTH} \ 60 | --optimizer ${OPTIMIZER} \ 61 | --data_dir ${DATA_DIR} \ 62 | --bert_hidden_dropout ${BERT_DROPOUT} \ 63 | --bert_config_dir ${BERT_DIR} \ 64 | --lr ${LR} \ 65 | --lr_scheduler ${LR_SCHEDULE} \ 66 | --accumulate_grad_batches ${ACC_GRAD} \ 67 | --default_root_dir ${OUTPUT_DIR} \ 68 | --output_dir ${OUTPUT_DIR} \ 69 | --max_epochs ${MAX_EPOCH} \ 70 | --gradient_clip_val ${GRAD_CLIP} \ 71 | --weight_decay ${WEIGHT_DECAY} \ 72 | --loss_type ${LOSS_TYPE} \ 73 | --focal_gamma ${FOCAL_GAMMA} \ 74 | --warmup_proportion ${WARMUP_PROPORTION} -------------------------------------------------------------------------------- /tasks/glue/download_glue_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' Script for downloading all GLUE data. 4 | Note: for legal reasons, we are unable to host MRPC. 5 | You can either use the version hosted by the SentEval team, which is already tokenized, 6 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually. 7 | For Windows users, you can run the .msi file. 
For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example). 8 | You should then rename and place specific files in a folder (see below for an example). 9 | mkdir MRPC 10 | cabextract MSRParaphraseCorpus.msi -d MRPC 11 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt 12 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt 13 | rm MRPC/_* 14 | rm MSRParaphraseCorpus.msi 15 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now. 16 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray! 17 | ''' 18 | 19 | import os 20 | import sys 21 | import shutil 22 | import argparse 23 | import tempfile 24 | import urllib.request 25 | import zipfile 26 | 27 | import os 28 | import sys 29 | import shutil 30 | import argparse 31 | import tempfile 32 | import urllib.request 33 | import zipfile 34 | 35 | 36 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 37 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 38 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 39 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 40 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 41 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 42 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 43 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 44 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601', 45 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 46 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 47 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 48 | 49 | MRPC_TRAIN = 
'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' 50 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' 51 | 52 | def download_and_extract(task, data_dir): 53 | print("Downloading and extracting %s..." % task) 54 | data_file = "%s.zip" % task 55 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 56 | with zipfile.ZipFile(data_file) as zip_ref: 57 | zip_ref.extractall(data_dir) 58 | os.remove(data_file) 59 | print("\tCompleted!") 60 | 61 | def format_mrpc(data_dir, path_to_data): 62 | print("Processing MRPC...") 63 | mrpc_dir = os.path.join(data_dir, "MRPC") 64 | if not os.path.isdir(mrpc_dir): 65 | os.mkdir(mrpc_dir) 66 | if path_to_data: 67 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 68 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 69 | else: 70 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN) 71 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 72 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 73 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 74 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 75 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 76 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 77 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 78 | 79 | dev_ids = [] 80 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: 81 | for row in ids_fh: 82 | dev_ids.append(row.strip().split('\t')) 83 | 84 | with open(mrpc_train_file, encoding="utf8") as data_fh, \ 85 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ 86 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: 87 | header = data_fh.readline() 88 | train_fh.write(header) 89 | dev_fh.write(header) 90 | for row in data_fh: 91 | label, id1, id2, s1, s2 = row.strip().split('\t') 92 | if [id1, id2] in dev_ids: 93 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 94 | else: 95 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 96 | 97 | with open(mrpc_test_file, encoding="utf8") as data_fh, \ 98 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: 99 | header = data_fh.readline() 100 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 101 | for idx, row in enumerate(data_fh): 102 | label, id1, id2, s1, s2 = row.strip().split('\t') 103 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 104 | print("\tCompleted!") 105 | 106 | def download_diagnostic(data_dir): 107 | print("Downloading and extracting diagnostic...") 108 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 109 | os.mkdir(os.path.join(data_dir, "diagnostic")) 110 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 111 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 112 | print("\tCompleted!") 113 | return 114 | 115 | def get_tasks(task_names): 116 | task_names = task_names.split(',') 117 | if "all" in task_names: 118 | tasks = TASKS 119 | else: 120 | tasks = [] 121 | for task_name in task_names: 122 | assert task_name in TASKS, "Task %s not found!" 
% task_name 123 | tasks.append(task_name) 124 | return tasks 125 | 126 | def main(arguments): 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') 129 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 130 | type=str, default='all') 131 | parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt', 132 | type=str, default='') 133 | args = parser.parse_args(arguments) 134 | 135 | if not os.path.isdir(args.data_dir): 136 | os.mkdir(args.data_dir) 137 | tasks = get_tasks(args.tasks) 138 | 139 | for task in tasks: 140 | print(task) 141 | if task == 'MRPC': 142 | format_mrpc(args.data_dir, args.path_to_mrpc) 143 | elif task == 'diagnostic': 144 | download_diagnostic(args.data_dir) 145 | else: 146 | download_and_extract(task, args.data_dir) 147 | 148 | 149 | if __name__ == '__main__': 150 | sys.exit(main(sys.argv[1:])) 151 | # python3 download_glue_data.py --data_dir /data/xiaoya/datasets/glue --tasks QQP 152 | -------------------------------------------------------------------------------- /tasks/glue/evaluate_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: glue/evaluate.py 5 | # description: 6 | # code for evaluating saved model checkpoints. 7 | 8 | import os 9 | import argparse 10 | from utils.random_seed import set_random_seed 11 | set_random_seed(0) 12 | 13 | from pytorch_lightning import Trainer 14 | from tasks.glue.train import BertForGLUETask 15 | from utils.get_parser import get_parser 16 | 17 | 18 | def init_evaluate_parser(parser) -> argparse.ArgumentParser: 19 | parser.add_argument("--path_to_model_checkpoint", type=str, help="") 20 | parser.add_argument("--path_to_model_hparams_file", type=str, default="") 21 | return parser 22 | 23 | 24 | def evaluate(args): 25 | trainer = Trainer(gpus=args.gpus, 26 | distributed_backend=args.distributed_backend, 27 | deterministic=True) 28 | model = BertForGLUETask.load_from_checkpoint( 29 | checkpoint_path=args.path_to_model_checkpoint, 30 | hparams_file=args.path_to_model_hparams_file, 31 | map_location=None, 32 | batch_size=args.eval_batch_size 33 | ) 34 | trainer.test(model=model) 35 | 36 | 37 | def main(): 38 | eval_parser = get_parser() 39 | eval_parser = init_evaluate_parser(eval_parser) 40 | eval_parser = BertForGLUETask.add_model_specific_args(eval_parser) 41 | eval_parser = Trainer.add_argparse_args(eval_parser) 42 | args = eval_parser.parse_args() 43 | 44 | if len(args.path_to_model_hparams_file) == 0: 45 | eval_output_dir = "/".join(args.path_to_model_checkpoint.split("/")[:-1]) 46 | args.path_to_model_hparams_file = os.path.join(eval_output_dir, "lightning_logs", "version_0", "hparams.yaml") 47 | 48 | evaluate(args) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | 54 | -------------------------------------------------------------------------------- /tasks/glue/evaluate_predictions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: evaluate_predictions.py 5 | # description: 6 | # 7 | 8 | import os 9 | import sys 10 | import csv 11 | from glob import glob 12 | from metrics.functional.cls_acc_f1 import compute_acc_f1_from_list 13 | 14 | def eval_single_file(file_path, quotechar=None, 
num_labels=2): 15 | with open(file_path, "r") as r_f: 16 | data_lines = list(csv.reader(r_f, delimiter="\t", quotechar=quotechar)) 17 | # id \t pred \t gold 18 | 19 | pred_labels = [] 20 | gold_labels = [] 21 | for idx, data_line in enumerate(data_lines): 22 | if idx == 0: 23 | continue 24 | pred_labels.append(data_line[1]) 25 | gold_labels.append(data_line[2]) 26 | acc, f1, precision, recall = compute_acc_f1_from_list(pred_labels, gold_labels, num_labels=num_labels) 27 | print(f"acc: {acc}, f1: {f1}, precision: {precision}, recall: {recall}") 28 | 29 | 30 | def eval_files_in_folder(folder_path, prefix="dev", quotechar=None): 31 | file_lst = glob(os.path.join(folder_path, f"{prefix}-*.txt")) 32 | file_lst = sorted(file_lst) 33 | 34 | best_f1 = 0 35 | acc_when_best_f1 = 0 36 | best_file = "" 37 | for file_item in file_lst: 38 | acc, f1, precision, recall = eval_single_file(file_item) 39 | if f1 > best_f1: 40 | best_f1 = f1 41 | acc_when_best_f1 = acc 42 | best_file = file_item 43 | print(f"INFO -> {file_item}") 44 | print(f"INFO -> acc: {acc}, f1: {f1}, precision: {precision}, recall: {recall}") 45 | 46 | print(f"Summary INFO -> Best f1: {best_f1}, acc: {acc_when_best_f1}") 47 | print(f"Summary INFO -> Best file: {best_file}") 48 | 49 | if __name__ == "__main__": 50 | eval_folder_or_file = sys.argv[1] 51 | if eval_folder_or_file.endswith('.txt'): 52 | eval_single_file(eval_folder_or_file) 53 | else: 54 | eval_files_in_folder(eval_folder_or_file) 55 | -------------------------------------------------------------------------------- /tasks/glue/mrpc_dev_ids.tsv: -------------------------------------------------------------------------------- 1 | 1606495 1606619 2 | 3444633 3444733 3 | 1196962 1197061 4 | 1720166 1720115 5 | 2546175 2546198 6 | 2167771 2167744 7 | 2688145 2688162 8 | 2095803 2095786 9 | 1629064 1629043 10 | 116294 116332 11 | 722278 722383 12 | 1912524 1912648 13 | 726966 726945 14 | 936978 937500 15 | 2906104 2906322 16 | 1095780 1095652 17 | 487951 488007 18 | 3294207 3294290 19 | 1744257 1744378 20 | 1490044 1489975 21 | 2896308 2896334 22 | 2907762 2907649 23 | 101747 101777 24 | 1490811 1490840 25 | 2065523 2065836 26 | 886904 887158 27 | 2204592 2204588 28 | 3084554 3084612 29 | 917965 918315 30 | 2396937 2396818 31 | 3147370 3147525 32 | 71501 71627 33 | 1166473 1166857 34 | 1353356 1353174 35 | 1876120 1876059 36 | 1708040 1708062 37 | 2265271 2265152 38 | 2224884 2224819 39 | 2551891 2551563 40 | 360875 360943 41 | 2320654 2320666 42 | 3013483 3013540 43 | 1224743 1225510 44 | 58540 58567 45 | 1015249 1015204 46 | 3039310 3039413 47 | 726399 726078 48 | 2029631 2029565 49 | 1461629 1461781 50 | 1465073 1464854 51 | 283751 283290 52 | 799346 799268 53 | 533823 533909 54 | 801552 801516 55 | 2829648 2829613 56 | 2074182 2074668 57 | 1958079 1958143 58 | 3140260 3140288 59 | 969381 969512 60 | 2307064 2307235 61 | 271891 271839 62 | 1010655 1010430 63 | 1954 2142 64 | 985015 984975 65 | 218848 218851 66 | 753928 753890 67 | 555617 555528 68 | 487993 487952 69 | 3428298 3428362 70 | 389239 389299 71 | 2249237 2249305 72 | 970740 971209 73 | 2988297 2988555 74 | 2204353 2204418 75 | 544217 544325 76 | 2083598 2083810 77 | 263690 263819 78 | 2587300 2587243 79 | 3070979 3070949 80 | 953733 953537 81 | 2561999 2561941 82 | 684848 684557 83 | 162203 162101 84 | 958161 957782 85 | 2926039 2925982 86 | 2112330 2112376 87 | 3329379 3329416 88 | 961836 962243 89 | 1808166 1808434 90 | 2632692 2632767 91 | 2652187 2652218 92 | 1430357 1430425 93 | 144089 143697 94 | 661390 661218 
95 | 325763 325928 96 | 3320577 3320553 97 | 2111629 2111786 98 | 2222998 2223097 99 | 205100 205145 100 | 533903 533818 101 | 629316 629289 102 | 1033204 1033365 103 | 3180014 3179967 104 | 2673104 2673130 105 | 349215 349241 106 | 129995 129864 107 | 368067 368018 108 | 655498 655391 109 | 2139506 2139427 110 | 1661381 1661317 111 | 426112 426210 112 | 1990975 1991132 113 | 162632 162653 114 | 1889954 1889847 115 | 347017 347002 116 | 3150803 3150839 117 | 173879 173832 118 | 431076 431242 119 | 1261116 1261234 120 | 1636060 1635946 121 | 2252795 2252970 122 | 1128884 1128865 123 | 3214517 3214483 124 | 1014983 1014963 125 | 1057995 1057778 126 | 2587767 2587673 127 | 539585 539355 128 | 1756329 1756394 129 | 2963943 2963880 130 | 3242051 3241897 131 | 977772 977804 132 | 2131318 2131372 133 | 103280 103431 134 | 581592 581570 135 | 2083612 2083810 136 | 2155514 2155377 137 | 723557 724115 138 | 1351550 1351155 139 | 611663 611716 140 | 886618 886456 141 | 1989515 1989458 142 | 224932 224868 143 | 751520 751373 144 | 396041 396188 145 | 55187 54831 146 | 132553 132725 147 | 1673112 1673068 148 | 1638813 1639087 149 | 2208376 2208198 150 | 849291 849442 151 | 2638861 2638982 152 | 1967578 1967664 153 | 781439 781461 154 | 1119721 1119714 155 | 34513 34742 156 | 644788 644816 157 | 2495223 2495307 158 | 954526 954607 159 | 195728 196099 160 | 2010705 2010779 161 | 1277539 1277527 162 | 2638975 2638855 163 | 2357324 2357271 164 | 2763381 2763517 165 | 1597193 1597119 166 | 555553 555528 167 | 2796658 2796682 168 | 101746 101775 169 | 2116843 2116883 170 | 1100998 1100441 171 | 3255597 3255668 172 | 1909579 1909408 173 | 2919853 2919804 174 | 315785 315653 175 | 1264509 1264471 176 | 3439114 3439084 177 | 3062202 3062308 178 | 2614947 2614904 179 | 1462409 1462504 180 | 2199097 2199072 181 | 331980 332110 182 | 3267026 3266930 183 | 698948 698933 184 | 461779 461815 185 | 1910610 1910455 186 | 389117 389052 187 | 789691 789665 188 | 1348909 1348954 189 | 261202 260995 190 | 2820371 2820525 191 | 696677 696932 192 | 54181 53570 193 | 589579 589557 194 | 2128530 2128455 195 | 3113791 3113782 196 | 637168 637447 197 | 490355 490378 198 | 780604 780466 199 | 219064 218969 200 | 2823575 2823513 201 | 3181118 3181443 202 | 485999 486011 203 | 2304696 2304863 204 | 2916199 2916164 205 | 2829194 2829229 206 | 1167835 1167651 207 | 1438666 1438643 208 | 98432 98657 209 | 249699 249623 210 | 347022 347003 211 | 2749410 2749625 212 | 2517014 2516995 213 | 2766112 2766084 214 | 2198694 2198937 215 | 548867 548785 216 | 2758265 2758282 217 | 981185 981234 218 | 1354501 1354476 219 | 2758944 2758975 220 | 1865364 1865251 221 | 131979 131957 222 | 490376 490490 223 | 146112 146127 224 | 2763517 2763576 225 | 327839 327748 226 | 3111452 3111428 227 | 1831696 1831660 228 | 515581 515752 229 | 315647 315778 230 | 1783137 1782659 231 | 1393764 1393984 232 | 1980654 1980641 233 | 1989213 1989116 234 | 3022833 3023029 235 | 86007 86373 236 | 1685339 1685429 237 | 1592037 1592076 238 | 2493369 2493428 239 | 1726935 1726879 240 | 3389318 3389271 241 | 3394891 3394775 242 | 2324704 2325023 243 | 2455942 2455978 244 | 192285 192327 245 | 3400796 3400822 246 | 4733 4557 247 | 1050307 1050144 248 | 1112021 1111925 249 | 3376093 3376101 250 | 816867 816831 251 | 3218713 3218830 252 | 1864253 1863810 253 | 3107137 3107119 254 | 3039165 3039036 255 | 3039007 3038845 256 | 2015389 2015410 257 | 1605818 1605806 258 | 2796978 2797024 259 | 1201306 1201329 260 | 2339738 2339771 261 | 3300040 3299992 262 | 2749322 2749663 263 
| 2745055 2745022 264 | 3046488 3046824 265 | 2241925 2242066 266 | 86020 86007 267 | 69773 69792 268 | 1057876 1057778 269 | 2965576 2965701 270 | 577854 578500 271 | 221515 221509 272 | 587009 586969 273 | 3264732 3264648 274 | 3023029 3023229 275 | 2523564 2523358 276 | 1552068 1551928 277 | 1439663 1439808 278 | 2377289 2377259 279 | 2283737 2283794 280 | 588637 588864 281 | 1825432 1825301 282 | 2748287 2748550 283 | 1704987 1705268 284 | 54142 53641 285 | 2259788 2259747 286 | 2090911 2091154 287 | 3093023 3092996 288 | 3122429 3122305 289 | 1521034 1520582 290 | 2324708 2325028 291 | 3261484 3261306 292 | 1675025 1675047 293 | 2317018 2317252 294 | 3448488 3448449 295 | 1762569 1762526 296 | 1602860 1602844 297 | 1824224 1824209 298 | 2640607 2640576 299 | 2697659 2697747 300 | 2440680 2440474 301 | 818091 817811 302 | 853475 853342 303 | 2175939 2176090 304 | 314997 315030 305 | 1220668 1220801 306 | 554905 554627 307 | 3035788 3035918 308 | 383417 383558 309 | 1089053 1089297 310 | 1831453 1831491 311 | 2274844 2274714 312 | 2706154 2706185 313 | 2889005 2888954 314 | 1355540 1355592 315 | 2380695 2380822 316 | 1616174 1616206 317 | 1528383 1528083 318 | 635783 635802 319 | 1580638 1580663 320 | 1549586 1549609 321 | 2826681 2826474 322 | 221079 221003 323 | 720572 720486 324 | 3311600 3311633 325 | 460211 460445 326 | 2385288 2385256 327 | 1908763 1908744 328 | 2996241 2996734 329 | 2691044 2691264 330 | 1386884 1386857 331 | 2977500 2977547 332 | 1330643 1330622 333 | 2240399 2240149 334 | 2931098 2931144 335 | 919683 919782 336 | 60122 60445 337 | 805457 805985 338 | 3435735 3435717 339 | 110731 110648 340 | 524136 524119 341 | 3439854 3439874 342 | 2008984 2009175 343 | 260952 260924 344 | 844421 844679 345 | 872784 872834 346 | 1423836 1423708 347 | 2079200 2079131 348 | 753858 753890 349 | 787432 787464 350 | 2110220 2110199 351 | 1186754 1187056 352 | 2110775 2110924 353 | 780408 780363 354 | 52758 52343 355 | 763948 763991 356 | 2810634 2810670 357 | 2584416 2584653 358 | 2268396 2268480 359 | 447728 447699 360 | 2573262 2573319 361 | 1550897 1550977 362 | 941617 941673 363 | 3310210 3310286 364 | 2494149 2494073 365 | 1619244 1619274 366 | 2531749 2531607 367 | 374015 374162 368 | 2221603 2221633 369 | 2362761 2362698 370 | 2834988 2835026 371 | 1605350 1605425 372 | 1630585 1630657 373 | 3464314 3464302 374 | 2842562 2842582 375 | 1076861 1077018 376 | 3028143 3028234 377 | 518089 518133 378 | 2336453 2336545 379 | 3061836 3062031 380 | 2738677 2738741 381 | 2046630 2046644 382 | 1919740 1919926 383 | 1721433 1721267 384 | 1269572 1269682 385 | 1771131 1771091 386 | 1757264 1757375 387 | 1984039 1983986 388 | 1609290 1609098 389 | 2728425 2728251 390 | 2020252 2020081 391 | 665419 665612 392 | 2945693 2945847 393 | 2217613 2217659 394 | 2530671 2530542 395 | 2607718 2607708 396 | 1015010 1014963 397 | 1513190 1513246 398 | 969512 969295 399 | 1657632 1657619 400 | 2385348 2385394 401 | 821523 821385 402 | 2577517 2577531 403 | 862804 862715 404 | 977938 978162 405 | 3073773 3073779 406 | 3107118 3107136 407 | 2047034 2046820 408 | 308567 308525 409 | -------------------------------------------------------------------------------- /tasks/glue/process_mrpc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: process_mrpc.py 5 | 6 | import os 7 | import sys 8 | 9 | def process_mrpc_data(data_dir, dev_ids_file): 10 | print("Processing MRPC...") 11 | mrpc_train_file = 
os.path.join(data_dir, "msr_paraphrase_train.txt") 12 | mrpc_test_file = os.path.join(data_dir, "msr_paraphrase_test.txt") 13 | 14 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 15 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 16 | 17 | dev_ids = [] 18 | with open(dev_ids_file, encoding="utf8") as ids_fh: 19 | for row in ids_fh: 20 | dev_ids.append(row.strip().split('\t')) 21 | 22 | with open(mrpc_train_file, encoding="utf8") as data_fh, \ 23 | open(os.path.join(data_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ 24 | open(os.path.join(data_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: 25 | header = data_fh.readline() 26 | train_fh.write(header) 27 | dev_fh.write(header) 28 | for row in data_fh: 29 | label, id1, id2, s1, s2 = row.strip().split('\t') 30 | if [id1, id2] in dev_ids: 31 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 32 | else: 33 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 34 | 35 | with open(mrpc_test_file, encoding="utf8") as data_fh, \ 36 | open(os.path.join(data_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: 37 | header = data_fh.readline() 38 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 39 | for idx, row in enumerate(data_fh): 40 | label, id1, id2, s1, s2 = row.strip().split('\t') 41 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 42 | print("\tCompleted!") 43 | 44 | 45 | if __name__ == "__main__": 46 | data_dir = sys.argv[1] 47 | path_to_dev_ids = sys.argv[2] 48 | process_mrpc_data(data_dir, path_to_dev_ids) -------------------------------------------------------------------------------- /tasks/mrc_ner/data_preprocess/file_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | from tasks.mrc_ner.data_preprocess.label_utils import iob_iobes 6 | 7 | def load_conll03_sentences(data_path): 8 | dataset = [] 9 | with open(data_path, "r") as f: 10 | words, tags = [], [] 11 | # for each line of the file correspond to one word and tag 12 | for line_idx, line in enumerate(f): 13 | if line != "\n" and ("-DOCSTART-" not in line): 14 | line = line.strip() 15 | if " " not in line: 16 | continue 17 | try: 18 | word, pos_cat, pos_label, tag = line.split(" ") 19 | word = word.strip() 20 | tag = tag.strip() 21 | except: 22 | print(line) 23 | continue 24 | 25 | if len(word) > 0 and len(tag) > 0: 26 | word, tag = str(word), str(tag) 27 | words.append(word) 28 | tags.append(tag) 29 | else: 30 | if len(words) > 0 and line_idx != 0: 31 | assert len(words) == len(tags) 32 | dataset.append((words, tags)) 33 | words, tags = [], [] 34 | 35 | tokens = [data[0] for data in dataset] 36 | labels = [iob_iobes(data[1]) for data in dataset] 37 | dataset = [(data_tokens, data_labels) for data_tokens, data_labels in zip(tokens, labels)] 38 | return dataset 39 | 40 | def load_conll03_documents(data_path): 41 | dataset = [] 42 | with open(data_path, "r") as f: 43 | words, tags = [], [] 44 | # for each line of the file correspond to one word and tag 45 | for line_idx, line in enumerate(f): 46 | if "-DOCSTART-" not in line: 47 | line = line.strip() 48 | if " " not in line: 49 | continue 50 | try: 51 | word, pos_cat, pos_label, tag = line.split(" ") 52 | word = word.strip() 53 | tag = tag.strip() 54 | except: 55 | print(line) 56 | continue 57 | 58 | if len(word) > 0 and len(tag) > 0: 59 | word, tag = str(word), str(tag) 60 | 
words.append(word) 61 | tags.append(tag) 62 | else: 63 | if len(words) > 0 and line_idx != 0: 64 | assert len(words) == len(tags) 65 | dataset.append((words, tags)) 66 | words, tags = [], [] 67 | 68 | tokens = [data[0] for data in dataset] 69 | labels = [iob_iobes(data[1]) for data in dataset] 70 | dataset = [(data_tokens, data_labels) for data_tokens, data_labels in zip(tokens, labels)] 71 | return dataset 72 | 73 | def export_conll(sentence, label, export_file_path, dim=2): 74 | """ 75 | Args: 76 | sentence: a list of sentece of chars [["北", "京", "天", "安", "门"], ["真", "相", "警", 告"]] 77 | label: a list of labels [["B", "M", "E", "S", "O"], ["O", "O", "S", "S"]] 78 | Desc: 79 | export tagging data into conll format 80 | """ 81 | with open(export_file_path, "w") as f: 82 | for idx, (sent_item, label_item) in enumerate(zip(sentence, label)): 83 | for char_idx, (tmp_char, tmp_label) in enumerate(zip(sent_item, label_item)): 84 | f.write("{} {}\n".format(tmp_char, tmp_label)) 85 | f.write("\n") 86 | 87 | 88 | def load_conll(data_path): 89 | """ 90 | Desc: 91 | load data in conll format 92 | Returns: 93 | [([word1, word2, word3, word4], [label1, label2, label3, label4]), 94 | ([word5, word6, word7, wordd8], [label5, label6, label7, label8])] 95 | """ 96 | dataset = [] 97 | with open(data_path, "r") as f: 98 | words, tags = [], [] 99 | # for each line of the file correspond to one word and tag 100 | for line in f: 101 | if line != "\n": 102 | # line = line.strip() 103 | word, tag = line.split(" ") 104 | word = word.strip() 105 | tag = tag.strip() 106 | try: 107 | if len(word) > 0 and len(tag) > 0: 108 | word, tag = str(word), str(tag) 109 | words.append(word) 110 | tags.append(tag) 111 | except Exception as e: 112 | print("an exception was raise! skipping a word") 113 | else: 114 | if len(words) > 0: 115 | assert len(words) == len(tags) 116 | dataset.append((words, tags)) 117 | words, tags = [], [] 118 | 119 | return dataset 120 | 121 | def dump_tsv(data_lines, data_path): 122 | """ 123 | Desc: 124 | dump data into tsv format for TAGGING data 125 | Input: 126 | the format of data_lines is: 127 | [([word1, word2, word3, word4], [label1, label2, label3, label4]), 128 | ([word5, word6, word7, word8, word9], [label5, label6, label7, label8, label9]), 129 | ([word10, word11, word12, ], [label10, label11, label12])] 130 | """ 131 | print("dump dataliens into TSV format : ") 132 | with open(data_path, "w") as f: 133 | for data_item in data_lines: 134 | data_word, data_tag = data_item 135 | data_str = " ".join(data_word) 136 | data_tag = " ".join(data_tag) 137 | f.write(data_str + "\t" + data_tag + "\n") 138 | print("dump data set into data path") 139 | print(data_path) 140 | 141 | 142 | if __name__ == "__main__": 143 | import os 144 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-4]) 145 | print(repo_path) 146 | conll_data_file = os.path.join(repo_path, "tests", "data", "enconll03", "test.txt") 147 | conll_dataset = load_conll03_documents(conll_data_file) 148 | doc_len = [len(tmp[0]) for tmp in conll_dataset] 149 | # number of doc is 230 150 | print(f"NUM OF DOC -> {len(doc_len)}") 151 | print(f"AGV -> {sum(doc_len)/ float(len(doc_len))}") 152 | print(f"MAX -> {max(doc_len)}") 153 | print(f"MIN -> {min(doc_len)}") 154 | print(f"512 -> {len([tmp for tmp in doc_len if tmp >= 500])}") -------------------------------------------------------------------------------- /tasks/mrc_ner/data_preprocess/label_utils.py: -------------------------------------------------------------------------------- 
1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # description: 5 | # utilities for sequence tagging for entity-level tasks 6 | # (such as NER) 7 | 8 | 9 | def get_bmes(span_labels, length, encoding): 10 | tags = ["O" for _ in range(length)] 11 | 12 | for start, end in span_labels: 13 | for i in range(start+1, end+1): 14 | tags[i] = "M" 15 | if "E" in encoding: 16 | tags[end] = "E" 17 | if "B" in encoding: 18 | tags[start] = "B" 19 | if "S" in encoding and start == end: 20 | tags[start] = "S" 21 | return tags 22 | 23 | 24 | def get_span_labels(sentence_tags, inv_label_mapping=None): 25 | """ 26 | Desc: 27 | convert token-level labels to a list of entities; 28 | it does not matter whether the tagging scheme is BMES, BIO or BIOES 29 | Returns: 30 | a list of entities 31 | [(start, end, label), (start, end, label)] 32 | """ 33 | 34 | if inv_label_mapping: 35 | sentence_tags = [inv_label_mapping[i] for i in sentence_tags] 36 | 37 | span_labels = [] 38 | last = "O" 39 | start = -1 40 | for i, tag in enumerate(sentence_tags): 41 | pos, _ = (None, "O") if tag == "O" else tag.split("-") 42 | if (pos == "S" or pos == "B" or tag == "O") and last != "O": 43 | span_labels.append((start, i - 1, last.split("-")[-1])) 44 | if pos == "B" or pos == "S" or last == "O": 45 | start = i 46 | last = tag 47 | 48 | if sentence_tags[-1] != "O": 49 | span_labels.append((start, len(sentence_tags) - 1, sentence_tags[-1].split("-")[-1])) 50 | 51 | return span_labels 52 | 53 | 54 | 55 | def get_tags(span_labels, length, encoding): 56 | """ 57 | Desc: 58 | convert a list of entities to token-level labels based on the provided encoding (e.g., BMOES) 59 | Please notice that both the left and right boundaries are inclusive. 60 | """ 61 | tags = ["O" for _ in range(length)] 62 | 63 | for start, end, tag in span_labels: 64 | for i in range(start, end + 1): 65 | tags[i] = "M-" + tag 66 | 67 | if "E" in encoding: 68 | tags[end] = "E-" + tag 69 | if "B" in encoding: 70 | tags[start] = "B-" + tag 71 | if "S" in encoding and start == end: 72 | tags[start] = "S-" + tag 73 | return tags 74 | 75 | 76 | 77 | def iob_iobes(tags): 78 | """ 79 | Desc: 80 | IOB -> IOBES 81 | """ 82 | new_tags = [] 83 | for i, tag in enumerate(tags): 84 | if tag == "O": 85 | new_tags.append(tag) 86 | elif tag.split("-")[0] == "B": 87 | if i + 1 != len(tags) and tags[i+1].split("-")[0] == "I": 88 | new_tags.append(tag) 89 | else: 90 | new_tags.append(tag.replace("B-", "S-")) 91 | elif tag.split("-")[0] == "I": 92 | if i + 1 < len(tags) and tags[i + 1].split("-")[0] == "I": 93 | new_tags.append(tag) 94 | else: 95 | new_tags.append(tag.replace("I-", "E-")) 96 | else: 97 | raise Exception("invalid IOB format !!") 98 | return new_tags 99 | 100 | 101 | 102 | if __name__ == "__main__": 103 | label_tags = ["O", "B-ORG", "M-ORG", "E-ORG", "B-PER", "M-PER", "E-PER"] 104 | span_labels = get_span_labels(label_tags) 105 | print("check the content of span_labels") 106 | print(span_labels) 107 | # [(1, 3, "ORG"), (4, 6, "PER")] 108 | 109 | # ------------------------- 110 | # test the functionality of get_tags 111 | # ------------------------- 112 | print("-*-"*10) 113 | print("check the tags produced by get_tags") 114 | span_label = get_tags([(1, 3, "ORG"), (8, 10, "PER")], 11, "BIOES") # length must cover the last span 115 | print(span_label) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /tasks/mrc_ner/data_preprocess/query_map.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #
-*- coding: utf-8 -*- 3 | 4 | # file: query_map.py 5 | # --------------------------------------------- 6 | # query collections for different dataset 7 | # --------------------------------------------- 8 | 9 | en_conll03_query = { 10 | "default": { 11 | "ORG": "organization entities are limited to named corporate, governmental, or other organizational entities.", 12 | "PER": "person entities are named persons or family.", 13 | "LOC": "location entities are the name of politically or geographically defined locations such as cities, provinces, countries, international regions, bodies of water, mountains, etc.", 14 | "MISC": "examples of miscellaneous entities include events, nationalities, products and works of art." 15 | }, 16 | "labels": ["ORG", "PER", "LOC", "MISC"] 17 | } 18 | 19 | queries_for_dataset = { 20 | "en_conll03": en_conll03_query 21 | } 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /tasks/mrc_ner/eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | PRECISION=32 6 | FILE_NAME=eval_mrc_ner 7 | REPO_PATH=/data/xiaoya/workspace/mrc-with-dice-loss 8 | CKPT_PATH=$1 9 | 10 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 11 | 12 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/tasks/mrc_ner/evaluate.py \ 13 | --gpus="1" \ 14 | --precision=${PRECISION} \ 15 | --path_to_model_checkpoint ${CKPT_PATH} 16 | -------------------------------------------------------------------------------- /tasks/mrc_ner/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: tasks/mrc_ner/evaluate.py 5 | # 6 | 7 | import os 8 | import argparse 9 | from utils.random_seed import set_random_seed 10 | set_random_seed(2333) 11 | 12 | from pytorch_lightning import Trainer 13 | from tasks.mrc_ner.train import BertForNERTask 14 | from utils.get_parser import get_parser 15 | 16 | 17 | def init_evaluate_parser(parser) -> argparse.ArgumentParser: 18 | parser.add_argument("--path_to_model_checkpoint", type=str, help="") 19 | parser.add_argument("--path_to_model_hparams_file", type=str, default="") 20 | return parser 21 | 22 | 23 | def evaluate(args): 24 | trainer = Trainer(gpus=args.gpus, 25 | distributed_backend=args.distributed_backend, 26 | deterministic=True) 27 | model = BertForNERTask.load_from_checkpoint( 28 | checkpoint_path=args.path_to_model_checkpoint, 29 | hparams_file=args.path_to_model_hparams_file, 30 | map_location=None, 31 | batch_size=args.eval_batch_size 32 | ) 33 | trainer.test(model=model) 34 | 35 | 36 | def main(): 37 | eval_parser = get_parser() 38 | eval_parser = init_evaluate_parser(eval_parser) 39 | eval_parser = BertForNERTask.add_model_specific_args(eval_parser) 40 | eval_parser = Trainer.add_argparse_args(eval_parser) 41 | args = eval_parser.parse_args() 42 | 43 | if len(args.path_to_model_hparams_file) == 0: 44 | eval_output_dir = "/".join(args.path_to_model_checkpoint.split("/")[:-1]) 45 | args.path_to_model_hparams_file = os.path.join(eval_output_dir, "lightning_logs", "version_0", "hparams.yaml") 46 | 47 | evaluate(args) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /tasks/mrc_ner/generate_mrc_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 
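# NOTE: an illustrative example (the context sentence and span offsets are made-up values) of one
# MRC-style record produced by transform_examples_to_qa_features below; the query string is the
# en_conll03 "ORG" query defined in query_map.py above, and the key set matches the dicts built here:
# {
#   "qas_id": "0.1",
#   "query": "organization entities are limited to named corporate, governmental, or other organizational entities.",
#   "context": "EU rejects German call to boycott British lamb .",
#   "entity_label": "ORG",
#   "start_position": [0],
#   "end_position": [0],
#   "span_position": ["0;0"],
#   "impossible": false
# }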
4 | # transform data from sequence labeling to mrc formulation 5 | # ---------------------------------------------------------- 6 | # input data structure 7 | # --------------------------------------------------------- 8 | # this module is to generate mrc-style ner task. 9 | # 1. for flat ner, the input file follows conll and tagged in BMES schema 10 | # 2. for nested ner, the input file in json 11 | 12 | import os 13 | import sys 14 | import json 15 | 16 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-3]) 17 | print(repo_path) 18 | if repo_path not in sys.path: 19 | sys.path.insert(0, repo_path) 20 | 21 | from data_preprocess.file_utils import load_conll, load_conll03_documents, load_conll03_sentences 22 | from data_preprocess.label_utils import get_span_labels 23 | from data_preprocess.query_map import queries_for_dataset 24 | 25 | 26 | def generate_query_ner_dataset(source_file_path, dump_file_path, entity_sign="nested", 27 | dataset_name=None, query_sign="default"): 28 | """ 29 | Args: 30 | source_data_file: /data/genia/train.word.json | /data/msra/train.char.bmes 31 | dump_data_file: /data/genia-mrc/train.mrc.json | /data/msra-mrc/train.mrc.json 32 | dataset_name: one in ["en_ontonotes5", "en_conll03", ] 33 | entity_sign: one of ["nested", "flat"] 34 | query_sign: defualt is "default" 35 | """ 36 | entity_queries = queries_for_dataset[dataset_name][query_sign] 37 | label_lst = queries_for_dataset[dataset_name]["labels"] 38 | 39 | if entity_sign == "nested": 40 | with open(source_file_path, "r") as f: 41 | source_data = json.load(f) 42 | elif entity_sign == "flat": 43 | source_data = load_conll(source_file_path) 44 | elif entity_sign == "flat_enconll03_docs": 45 | source_data = load_conll03_documents(source_file_path) 46 | elif entity_sign == "flat_enconll03_sents": 47 | source_data = load_conll03_sentences(source_file_path) 48 | else: 49 | raise ValueError("ENTITY_SIGN can only be NESTED or FLAT.") 50 | 51 | target_data = transform_examples_to_qa_features(entity_queries, label_lst, source_data, entity_sign=entity_sign) 52 | 53 | with open(dump_file_path, "w") as f: 54 | json.dump(target_data, f, sort_keys=True, ensure_ascii=False, indent=2) 55 | 56 | 57 | def transform_examples_to_qa_features(query_map, entity_labels, data_instances, entity_sign="nested"): 58 | """ 59 | Desc: 60 | convert_examples to qa features 61 | Args: 62 | query_map: {entity label: entity query}; 63 | data_instance 64 | """ 65 | mrc_ner_dataset = [] 66 | 67 | if "flat" in entity_sign.lower(): 68 | tmp_qas_id = 0 69 | for idx, (word_lst, label_lst) in enumerate(data_instances): 70 | candidate_span_label = get_span_labels(label_lst) 71 | tmp_query_id = 0 72 | for label_idx, tmp_label in enumerate(entity_labels): 73 | tmp_query_id += 1 74 | tmp_query = query_map[tmp_label] 75 | tmp_context = " ".join(word_lst) 76 | 77 | tmp_start_pos = [] 78 | tmp_end_pos = [] 79 | tmp_entity_pos = [] 80 | 81 | start_end_label = [(start, end) for start, end, label_content in candidate_span_label if 82 | label_content == tmp_label] 83 | 84 | if len(start_end_label) != 0: 85 | for span_item in start_end_label: 86 | start_idx, end_idx = span_item 87 | tmp_start_pos.append(start_idx) 88 | tmp_end_pos.append(end_idx) 89 | tmp_entity_pos.append("{};{}".format(str(start_idx), str(end_idx))) 90 | tmp_impossible = False 91 | else: 92 | tmp_impossible = True 93 | 94 | mrc_ner_dataset.append({ 95 | "qas_id": "{}.{}".format(str(tmp_qas_id), str(tmp_query_id)), 96 | "query": tmp_query, 97 | "context": tmp_context, 98 | 
"entity_label": tmp_label, 99 | "start_position": tmp_start_pos, 100 | "end_position": tmp_end_pos, 101 | "span_position": tmp_entity_pos, 102 | "impossible": tmp_impossible 103 | }) 104 | tmp_qas_id += 1 105 | 106 | elif "nested" in entity_sign.lower(): 107 | tmp_qas_id = 0 108 | for idx, data_item in enumerate(data_instances): 109 | tmp_query_id = 0 110 | for label_idx, tmp_label in enumerate(entity_labels): 111 | tmp_query_id += 1 112 | tmp_query = query_map[tmp_label] 113 | tmp_context = data_item["context"] 114 | 115 | tmp_start_pos = [] 116 | tmp_end_pos = [] 117 | tmp_entity_pos = [] 118 | 119 | start_end_label = data_item["label"][tmp_label] if tmp_label in data_item["label"].keys() else -1 120 | 121 | if start_end_label == -1: 122 | tmp_impossible = True 123 | else: 124 | for start_end_item in data_item["label"][tmp_label]: 125 | start_idx, end_idx = [int(ix) for ix in start_end_item.split(",")] 126 | tmp_start_pos.append(start_idx) 127 | tmp_end_pos.append(end_idx) 128 | tmp_entity_pos.append(start_end_item) 129 | tmp_impossible = False 130 | 131 | mrc_ner_dataset.append({ 132 | "qas_id": "{}.{}".format(str(tmp_qas_id), str(tmp_query_id)), 133 | "query": tmp_query, 134 | "context": tmp_context, 135 | "entity_label": tmp_label, 136 | "start_position": tmp_start_pos, 137 | "end_position": tmp_end_pos, 138 | "span_position": tmp_entity_pos, 139 | "impossible": tmp_impossible 140 | }) 141 | tmp_qas_id += 1 142 | 143 | return mrc_ner_dataset 144 | 145 | 146 | if __name__ == "__main__": 147 | source_file_path = "/data/xiaoya/datasets/ner/conll2003_truecase/valid.txt" 148 | target_file_path = "/data/xiaoya/datasets/mrc_ner/en_conll03_truecase_doc/mrc-ner.dev" 149 | # generate_query_ner_dataset(source_file_path, target_file_path, entity_sign="flat_enconll03_document", dataset_name="en_conll03", query_sign="default") 150 | 151 | source_file_path = "/data/xiaoya/datasets/ner/conll2003_truecase/train.txt" 152 | target_file_path = "/data/xiaoya/datasets/mrc_ner/en_conll03_truecase_sent/mrc-ner.train" 153 | # generate_query_ner_dataset(source_file_path, target_file_path, entity_sign="flat_enconll03_sentences", dataset_name="en_conll03", query_sign="default") 154 | 155 | source_file_path = "/data/xiaoya/datasets/ner/conll2003/eng.train" 156 | target_file_path = "/data/xiaoya/datasets/mrc_ner/en_conll03_doc/mrc-ner.train" 157 | generate_query_ner_dataset(source_file_path, target_file_path, entity_sign="flat_enconll03_docs", 158 | dataset_name="en_conll03", query_sign="default") 159 | -------------------------------------------------------------------------------- /tasks/mrc_ner/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | TIME=2021.02.02 5 | FILE_NAME=debug_onto4 6 | REPO_PATH=/data/xiaoya/workspace/mrc-with-dice-loss 7 | MODEL_SCALE=base 8 | DATA_DIR=/data/nfsdata2/xiaoya/mrc_ner/zh_onto4 9 | BERT_DIR=/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12 10 | 11 | TRAIN_BATCH_SIZE=5 12 | EVAL_BATCH_SIZE=12 13 | MAX_LENGTH=128 14 | 15 | OPTIMIZER=adamw 16 | LR_SCHEDULE=linear 17 | LR=3e-5 18 | 19 | BERT_DROPOUT=0.1 20 | ACC_GRAD=1 21 | MAX_EPOCH=5 22 | GRAD_CLIP=1.0 23 | WEIGHT_DECAY=0.002 24 | WARMUP_PROPORTION=0.002 25 | 26 | LOSS_TYPE=bce 27 | W_START=1 28 | W_END=1 29 | W_SPAN=1 30 | DICE_SMOOTH=1 31 | DICE_OHEM=0.8 32 | DICE_ALPHA=0.01 33 | FOCAL_GAMMA=2 34 | 35 | PRECISION=32 36 | PROGRESS_BAR=1 37 | VAL_CHECK_INTERVAL=0.25 38 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 39 | 40 | if [[ 
${LOSS_TYPE} == "bce" ]]; then 41 | LOSS_SIGN=${LOSS_TYPE} 42 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 43 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 44 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 45 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 46 | fi 47 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 48 | 49 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/mrc_ner/${TIME} 50 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${W_START}_${W_END}_${W_SPAN}_${LOSS_SIGN} 51 | 52 | mkdir -p ${OUTPUT_DIR} 53 | 54 | CUDA_VISIBLE_DEVICES=3 python ${REPO_PATH}/tasks/mrc_ner/train.py \ 55 | --gpus="1" \ 56 | --precision=${PRECISION} \ 57 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 58 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 59 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 60 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 61 | --max_length ${MAX_LENGTH} \ 62 | --optimizer ${OPTIMIZER} \ 63 | --data_dir ${DATA_DIR} \ 64 | --bert_hidden_dropout ${BERT_DROPOUT} \ 65 | --bert_config_dir ${BERT_DIR} \ 66 | --lr ${LR} \ 67 | --lr_scheduler ${LR_SCHEDULE} \ 68 | --accumulate_grad_batches ${ACC_GRAD} \ 69 | --default_root_dir ${OUTPUT_DIR} \ 70 | --output_dir ${OUTPUT_DIR} \ 71 | --max_epochs ${MAX_EPOCH} \ 72 | --gradient_clip_val ${GRAD_CLIP} \ 73 | --weight_decay ${WEIGHT_DECAY} \ 74 | --loss_type ${LOSS_TYPE} \ 75 | --weight_start ${W_START} \ 76 | --weight_end ${W_END} \ 77 | --weight_span ${W_SPAN} \ 78 | --dice_smooth ${DICE_SMOOTH} \ 79 | --dice_ohem ${DICE_OHEM} \ 80 | --dice_alpha ${DICE_ALPHA} \ 81 | --dice_square \ 82 | --focal_gamma ${FOCAL_GAMMA} \ 83 | --warmup_proportion ${WARMUP_PROPORTION} \ 84 | --span_loss_candidates all \ 85 | --do_lower_case \ 86 | --is_chinese \ 87 | --flat_ner -------------------------------------------------------------------------------- /tasks/squad/evaluate_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: squad/evaluate_models.py 5 | # description: 6 | # evaluate the saved model checkpoints. 7 | 8 | import os 9 | from tasks.squad.train import BertForQA 10 | from utils.get_parser import get_parser 11 | from utils.random_seed import set_random_seed 12 | set_random_seed(0) 13 | from pytorch_lightning import Trainer 14 | 15 | 16 | 17 | def init_evaluate_parser(parser): 18 | parser.add_argument("--path_to_model_checkpoint", type=str, ) 19 | parser.add_argument("--path_to_model_hparams_file", type=str, default="") 20 | return parser 21 | 22 | 23 | def evaluate(args): 24 | """ 25 | Args: 26 | ckpt: model checkpoints. 27 | hparams_file: the string should end with "hparams.yaml" 28 | """ 29 | trainer = Trainer(gpus=args.gpus, 30 | distributed_backend=args.distributed_backend, 31 | deterministic=True) 32 | 33 | model = BertForQA.load_from_checkpoint( 34 | checkpoint_path=args.path_to_model_checkpoint, 35 | hparams_file=args.path_to_model_hparams_file, 36 | map_location=None, 37 | batch_size=args.eval_batch_size,) 38 | 39 | trainer.test(model=model) 40 | 41 | 42 | def main(): 43 | """evaluate model checkpoints on the dev set. 
""" 44 | eval_parser = init_evaluate_parser(get_parser()) 45 | eval_parser = BertForQA.add_model_specific_args(eval_parser) 46 | eval_parser = Trainer.add_argparse_args(eval_parser) 47 | args = eval_parser.parse_args() 48 | if len(args.path_to_model_hparams_file) == 0: 49 | args.path_to_model_hparams_file = os.path.join("/".join(args.path_to_model_checkpoint.split("/")[:-1]), "lightning_logs", "version_0", "hparams.yaml") 50 | 51 | evaluate(args) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() -------------------------------------------------------------------------------- /tasks/squad/evaluate_predictions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # files: squad/evaluate_predictions.py 5 | # description: 6 | # evaluate the output prediction file 7 | # Example script: 8 | # python3 evaluate_predictions.py /data/dev-v1.1.json /outputs/dice_loss/squad 9 | 10 | import os 11 | import sys 12 | import json 13 | import subprocess 14 | from glob import glob 15 | from utils.random_seed import set_random_seed 16 | 17 | REPO_PATH = "/".join(os.path.realpath(__file__).split("/")[:-3]) 18 | 19 | 20 | def main(golden_data_file: str = "dev-v1.1.json", 21 | saved_ckpt_directory: str = None, 22 | version_2_with_negative: bool = False, 23 | eval_result_output: str = "result.json"): 24 | """evaluate model prediction files.""" 25 | 26 | if not os.path.exists(golden_data_file): 27 | raise ValueError("Please run 'python3 evaluate_predictions.py ' ") 28 | if saved_ckpt_directory is not None: 29 | prediction_files = glob(os.path.join(saved_ckpt_directory, "predictions_*_*.json")) 30 | else: 31 | raise ValueError("Please run 'python3 evaluate_predictions.py ' ") 32 | 33 | if version_2_with_negative: 34 | evaluate_script_file = os.path.join(REPO_PATH, "metrics", "functional", "squad", "evaluate_v2.py") 35 | else: 36 | evaluate_script_file = os.path.join(REPO_PATH, "metrics", "functional", "squad", "evaluate_v1.py") 37 | evaluate_sh_file = os.path.join(REPO_PATH, "metrics", "functional", "squad", "eval.sh") 38 | 39 | chmod_result = os.system(f"chmod 777 {evaluate_sh_file}") 40 | if chmod_result != 0: 41 | raise ValueError 42 | 43 | evaluate_results = {} 44 | best_em = 0 45 | best_f1 = 0 46 | best_ckpt_path = "" 47 | for prediction_file in prediction_files: 48 | evaluate_key = prediction_file.replace(f"{saved_ckpt_directory}/", "").replace("predictions_", "").replace(".json", "") 49 | cmd = [evaluate_sh_file, evaluate_script_file, golden_data_file, prediction_file] 50 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE) 51 | stdout, stderr = process.communicate() 52 | process.wait() 53 | 54 | stdout = stdout.strip() 55 | if stderr is not None: 56 | print(stderr) 57 | 58 | evaluate_value = json.loads(stdout) 59 | if best_f1 <= evaluate_value['f1']: 60 | best_f1 = evaluate_value['f1'] 61 | best_em = evaluate_value['exact_match'] 62 | best_ckpt_path = prediction_file 63 | evaluate_results[evaluate_key] = evaluate_value 64 | 65 | eval_log_file = os.path.join(saved_ckpt_directory, eval_result_output) 66 | with open(eval_log_file, "w") as f: 67 | json.dump(evaluate_results, f, sort_keys=True, indent=2, ensure_ascii=False) 68 | 69 | print(f"BEST CKPT is -> {best_ckpt_path}") 70 | print(f"BEST Exact Match is : -> {best_em}") 71 | print(f"BEST span-f1 is : {best_f1}") 72 | 73 | 74 | if __name__ == "__main__": 75 | set_random_seed(0) 76 | golden_data_file = sys.argv[1] 77 | saved_ckpt_directory = sys.argv[2] 78 | 
main(golden_data_file, saved_ckpt_directory, ) -------------------------------------------------------------------------------- /tasks/tnews/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | TIME=2021.02.08 5 | FILE_NAME=re_debug_tnews 6 | REPO_PATH=/data/xiaoya/workspace/mrc-with-dice-loss 7 | MODEL_SCALE=base 8 | DATA_DIR=/data/xiaoya/datasets/tnews_public_data 9 | BERT_DIR=/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12 10 | 11 | TRAIN_BATCH_SIZE=24 12 | EVAL_BATCH_SIZE=12 13 | MAX_LENGTH=128 14 | 15 | OPTIMIZER=adamw 16 | LR_SCHEDULE=linear 17 | LR=3e-5 18 | 19 | BERT_DROPOUT=0.1 20 | ACC_GRAD=1 21 | MAX_EPOCH=5 22 | GRAD_CLIP=1.0 23 | WEIGHT_DECAY=0.002 24 | WARMUP_PROPORTION=0.002 25 | 26 | LOSS_TYPE=dice 27 | # ce, focal, dice 28 | DICE_SMOOTH=1 29 | DICE_OHEM=0.8 30 | DICE_ALPHA=0.01 31 | FOCAL_GAMMA=2 32 | 33 | PRECISION=32 34 | PROGRESS_BAR=1 35 | VAL_CHECK_INTERVAL=0.25 36 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 37 | 38 | if [[ ${LOSS_TYPE} == "bce" ]]; then 39 | LOSS_SIGN=${LOSS_TYPE} 40 | elif [[ ${LOSS_TYPE} == "focal" ]]; then 41 | LOSS_SIGN=${LOSS_TYPE}_${FOCAL_GAMMA} 42 | elif [[ ${LOSS_TYPE} == "dice" ]]; then 43 | LOSS_SIGN=${LOSS_TYPE}_${DICE_SMOOTH}_${DICE_OHEM}_${DICE_ALPHA} 44 | fi 45 | echo "DEBUG INFO -> loss sign is ${LOSS_SIGN}" 46 | 47 | OUTPUT_BASE_DIR=/data/xiaoya/outputs/dice_loss/tnews/${TIME} 48 | OUTPUT_DIR=${OUTPUT_BASE_DIR}/${FILE_NAME}_${MODEL_SCALE}_${TRAIN_BATCH_SIZE}_${MAX_LENGTH}_${LR}_${LR_SCHEDULE}_${BERT_DROPOUT}_${ACC_GRAD}_${MAX_EPOCH}_${GRAD_CLIP}_${WEIGHT_DECAY}_${WARMUP_PROPORTION}_${LOSS_SIGN} 49 | 50 | mkdir -p ${OUTPUT_DIR} 51 | 52 | CUDA_VISIBLE_DEVICES=3 python ${REPO_PATH}/tasks/tnews/train.py \ 53 | --gpus="1" \ 54 | --precision=${PRECISION} \ 55 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 56 | --eval_batch_size ${EVAL_BATCH_SIZE} \ 57 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 58 | --val_check_interval ${VAL_CHECK_INTERVAL} \ 59 | --max_length ${MAX_LENGTH} \ 60 | --optimizer ${OPTIMIZER} \ 61 | --data_dir ${DATA_DIR} \ 62 | --bert_hidden_dropout ${BERT_DROPOUT} \ 63 | --bert_config_dir ${BERT_DIR} \ 64 | --lr ${LR} \ 65 | --lr_scheduler ${LR_SCHEDULE} \ 66 | --accumulate_grad_batches ${ACC_GRAD} \ 67 | --default_root_dir ${OUTPUT_DIR} \ 68 | --output_dir ${OUTPUT_DIR} \ 69 | --max_epochs ${MAX_EPOCH} \ 70 | --gradient_clip_val ${GRAD_CLIP} \ 71 | --weight_decay ${WEIGHT_DECAY} \ 72 | --loss_type ${LOSS_TYPE} \ 73 | --dice_smooth ${DICE_SMOOTH} \ 74 | --dice_ohem ${DICE_OHEM} \ 75 | --dice_alpha ${DICE_ALPHA} \ 76 | --dice_square \ 77 | --focal_gamma ${FOCAL_GAMMA} \ 78 | --warmup_proportion ${WARMUP_PROPORTION} -------------------------------------------------------------------------------- /tests/count_length_autotokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # author: xiaoya li 5 | # file: count_length_autotokenizer.py 6 | 7 | import os 8 | import sys 9 | 10 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-2]) 11 | print(repo_path) 12 | if repo_path not in sys.path: 13 | sys.path.insert(0, repo_path) 14 | 15 | 16 | from transformers import AutoTokenizer 17 | from datasets.mrc_ner_dataset import MRCNERDataset 18 | 19 | 20 | 21 | class OntoNotesDataConfig: 22 | def __init__(self): 23 | self.data_dir = "/data/nfsdata2/xiaoya/mrc_ner/zh_onto4" 24 | self.model_path = 
"/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12" 25 | self.do_lower_case = False 26 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, tokenize_chinese_chars=True) 27 | # BertWordPieceTokenizer(os.path.join(self.model_path, "vocab.txt"), lowercase=self.do_lower_case) 28 | self.max_length = 512 29 | self.is_chinese = True 30 | self.threshold = 275 31 | self.data_sign = "zh_onto" 32 | 33 | class ChineseMSRADataConfig: 34 | def __init__(self): 35 | self.data_dir = "/data/xiaoya/datasets/mrc_ner/zh_msra" 36 | self.model_path = "/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12" 37 | self.do_lower_case = False 38 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, tokenize_chinese_chars=True) 39 | self.max_length = 512 40 | self.is_chinese = True 41 | self.threshold = 275 42 | self.data_sign = "zh_msra" 43 | 44 | 45 | class EnglishOntoDataConfig: 46 | def __init__(self): 47 | self.data_dir = "/data/xiaoya/datasets/mrc_ner/en_onto5" 48 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12" 49 | self.do_lower_case = False 50 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False) 51 | self.max_length = 512 52 | self.is_chinese = False 53 | self.threshold = 275 54 | self.data_sign = "en_onto" 55 | 56 | 57 | class EnglishCoNLLDataConfig: 58 | def __init__(self): 59 | self.data_dir = "/data/xiaoya/datasets/mrc_ner/en_conll03" 60 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12" 61 | if "uncased" in self.model_path: 62 | self.do_lower_case = True 63 | else: 64 | self.do_lower_case = False 65 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, do_lower_case=self.do_lower_case) 66 | self.max_length = 512 67 | self.is_chinese = False 68 | self.threshold = 275 69 | self.data_sign = "en_conll03" 70 | 71 | class EnglishCoNLL03DocDataConfig: 72 | def __init__(self): 73 | self.data_dir = "/data/xiaoya/datasets/mrc_ner/en_conll03_doc" 74 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12" 75 | self.do_lower_case = False 76 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False) 77 | self.max_length = 512 78 | self.is_chinese = False 79 | self.threshold = 384 80 | self.data_sign = "en_conll03" 81 | 82 | def count_max_length(data_sign): 83 | if data_sign == "zh_onto": 84 | data_config = OntoNotesDataConfig() 85 | elif data_sign == "zh_msra": 86 | data_config = ChineseMSRADataConfig() 87 | elif data_sign == "en_onto": 88 | data_config = EnglishOntoDataConfig() 89 | elif data_sign == "en_conll03": 90 | data_config = EnglishCoNLLDataConfig() 91 | elif data_sign == "en_conll03_doc": 92 | data_config = EnglishCoNLL03DocDataConfig() 93 | else: 94 | raise ValueError 95 | for prefix in ["test", "train", "dev"]: 96 | print("=*"*15) 97 | print(f"INFO -> loading {prefix} data. 
") 98 | data_file_path = os.path.join(data_config.data_dir, f"mrc-ner.{prefix}") 99 | dataset = MRCNERDataset(json_path=data_file_path, 100 | tokenizer=data_config.tokenizer, 101 | max_length=data_config.max_length, 102 | is_chinese=data_config.is_chinese, 103 | pad_to_maxlen=False, 104 | data_sign=data_config.data_sign) 105 | max_len = 0 106 | counter = 0 107 | for idx, data_item in enumerate(dataset): 108 | tokens = data_item[0] 109 | num_tokens = tokens.shape[0] 110 | if num_tokens >= max_len: 111 | max_len = num_tokens 112 | if num_tokens > data_config.threshold: 113 | counter += 1 114 | 115 | print(f"INFO -> Max LEN for {prefix} set is : {max_len}") 116 | print(f"INFO -> large than {data_config.threshold} is {counter}") 117 | 118 | 119 | 120 | if __name__ == '__main__': 121 | # for english 122 | data_sign = "en_onto" 123 | # data_sign = "zh_onto" 124 | count_max_length(data_sign) 125 | -------------------------------------------------------------------------------- /tests/count_length_glue.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: count_length_glue.py 5 | 6 | import os 7 | import sys 8 | 9 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-2]) 10 | print(repo_path) 11 | if repo_path not in sys.path: 12 | sys.path.insert(0, repo_path) 13 | 14 | from transformers import AutoTokenizer 15 | from datasets.qqp_dataset import QQPDataset 16 | 17 | 18 | class QQPDataConfig: 19 | def __init__(self): 20 | self.data_dir = "/data/xiaoya/datasets/glue/qqp" 21 | self.model_path = "/data/xiaoya/models/bert_cased_large" 22 | self.do_lower_case = False 23 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, 24 | do_lower_case=False, 25 | tokenize_chinese_chars=False) 26 | self.max_seq_length = 1024 27 | self.is_chinese = False 28 | self.threshold = 275 29 | self.pad_to_maxlen = False 30 | 31 | 32 | def main(): 33 | data_config = QQPDataConfig() 34 | 35 | for mode in ["train", "dev", "test"]: 36 | print("=*"*20) 37 | print(mode) 38 | print("=*"*20) 39 | data_length_collection = [] 40 | qqp_dataset = QQPDataset(data_config, data_config.tokenizer, mode=mode, ) 41 | for data_item in qqp_dataset: 42 | input_tokens = data_item["input_ids"].shape 43 | print(input_tokens) 44 | exit() 45 | 46 | 47 | if __name__ == "__main__": 48 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/dice_loss_for_NLP/d437bb999185535df46fdb74d1f2f57161331b44/utils/__init__.py -------------------------------------------------------------------------------- /utils/get_parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: get_parser.py 5 | # description: 6 | # argument parser 7 | 8 | import argparse 9 | 10 | def get_parser() -> argparse.ArgumentParser: 11 | """ 12 | return basic arg parser 13 | """ 14 | parser = argparse.ArgumentParser(description="argument parser") 15 | 16 | parser.add_argument("--seed", type=int, default=2333) 17 | parser.add_argument("--data_dir", type=str, help="data dir") 18 | parser.add_argument("--bert_config_dir", type=str, help="bert config dir") 19 | parser.add_argument("--pretrained_checkpoint", default="", type=str, help="pretrained checkpoint path") 20 | 
parser.add_argument("--train_batch_size", type=int, default=32, help="batch size for train dataloader") 21 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for eval dataloader") 22 | parser.add_argument("--lr", type=float, default=2e-5, help="learning rate") 23 | parser.add_argument("--lr_scheduler", type=str, default="onecycle", help="type of lr scheduler") 24 | parser.add_argument("--workers", type=int, default=0, help="num workers for dataloader") 25 | # number of data-loader workers should equal to 0. 26 | # https://blog.csdn.net/breeze210/article/details/99679048 27 | parser.add_argument("--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.") 28 | parser.add_argument("--weight_decay", default=0.01, type=float, 29 | help="Weight decay if we apply some.") 30 | # in case of not error, define a new argument 31 | parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for.") 32 | parser.add_argument("--adam_epsilon", default=1e-6, type=float, 33 | help="Epsilon for Adam optimizer.") 34 | parser.add_argument("--max_keep_ckpt", default=3, type=int, 35 | help="the number of keeping ckpt max.") 36 | parser.add_argument("--output_dir", default="/data", type=str, help="the directory to save model outputs") 37 | parser.add_argument("--debug", action="store_true", help="train with 10 data examples in the debug mode.") 38 | 39 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 40 | parser.add_argument("--cache_dir", default="", type=str, help="Where do you want to store the pre-trained models downloaded from s3", ) 41 | 42 | # optimizer and loss func 43 | parser.add_argument("--bert_hidden_dropout", type=float, default=0.1, ) 44 | parser.add_argument("--final_div_factor", type=float, default=1e4, 45 | help="final div factor of linear decay scheduler") 46 | # TODO: choices=["adamw", "sgd", "debias"] 47 | parser.add_argument("--optimizer", default="adamw", help="loss type") 48 | # TODO: change chocies 49 | # choices=["ce", "bce", "dice", "focal", "adaptive_dice"], 50 | parser.add_argument("--loss_type", default="bce", help="loss type") 51 | ## dice loss 52 | parser.add_argument("--dice_smooth", type=float, default=1e-4, help="smooth value of dice loss") 53 | parser.add_argument("--dice_ohem", type=float, default=0.0, help="ohem ratio of dice loss") 54 | parser.add_argument("--dice_alpha", type=float, default=0.01, help="alpha value of adaptive dice loss") 55 | parser.add_argument("--dice_square", action="store_true", help="use square for dice loss") 56 | ## focal loss 57 | parser.add_argument("--focal_gamma", type=float, default=2, help="gamma for focal loss.") 58 | parser.add_argument("--focal_alpha", type=float, help="alpha for focal loss.") 59 | 60 | # only keep the best checkpoint after training. 61 | parser.add_argument("--only_keep_the_best_ckpt_after_training", action="store_true", help="only the best model checkpoint after training. 
") 62 | 63 | return parser 64 | -------------------------------------------------------------------------------- /utils/random_seed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: random_seed.py 5 | # refer to : 6 | # issue: https://github.com/PyTorchLightning/pytorch-lightning/issues/1868 7 | # Please Notice: 8 | # set for trainer: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html 9 | # from pytorch_lightning import Trainer, seed_everything 10 | # seed_everything(42) 11 | # sets seeds for numpy, torch, python.random and PYTHONHASHSEED. 12 | # model = Model() 13 | # trainer = Trainer(deterministic=True) 14 | 15 | import random 16 | import torch 17 | import numpy as np 18 | from pytorch_lightning import seed_everything 19 | 20 | 21 | def set_random_seed(seed: int): 22 | """set seeds for reproducibility""" 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | seed_everything(seed=seed) 28 | torch.backends.cudnn.deterministic = True 29 | torch.backends.cudnn.benchmark = False 30 | 31 | 32 | if __name__ == '__main__': 33 | # without this line, x would be different in every execution. 34 | set_random_seed(0) 35 | 36 | x = np.random.random() 37 | print(x) 38 | --------------------------------------------------------------------------------