├── data
│   └── .gitignore
├── checkpoint
│   └── .gitignore
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── dataset.py
│   ├── utils.py
│   ├── data_processor.py
│   └── models.py
├── LICENSE
├── .gitignore
├── readme.md
├── README_eng.md
└── main.py

/data/.gitignore:
--------------------------------------------------------------------------------
*
--------------------------------------------------------------------------------
/checkpoint/.gitignore:
--------------------------------------------------------------------------------
*
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
lxml
torch
pytorch_lightning
transformers
opencc
tqdm
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
"""
@Time   : 2021-01-12 17:52:09
@File   : __init__.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 abtion

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
.DS_Store
__pycache__
lightning_logs

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
"""
@Time   : 2021-01-12 16:04:23
@File   : dataset.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""
import torch
from torch.utils.data import Dataset, DataLoader

from .utils import load_json


class CorrectorDataset(Dataset):
    def __init__(self, fp):
        self.data = load_json(fp)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]['original_text'], self.data[index]['correct_text'], self.data[index]['wrong_ids']


def get_corrector_loader(fp, tokenizer, **kwargs):
    def _collate_fn(data):
        ori_texts, cor_texts, wrong_idss = zip(*data)
        encoded_texts = [tokenizer.tokenize(t) for t in ori_texts]
        max_len = max([len(t) for t in encoded_texts]) + 2
        det_labels = torch.zeros(len(ori_texts), max_len).long()
        for i, (encoded_text, wrong_ids) in enumerate(zip(encoded_texts, wrong_idss)):
            for idx in wrong_ids:
                margins = []
                for word in encoded_text[:idx]:
                    if word == '[UNK]':
                        break
                    if word.startswith('##'):
                        margins.append(len(word) - 3)
                    else:
                        margins.append(len(word) - 1)
                margin = sum(margins)
                move = 0
                while (abs(move) < margin) or (idx + move >= len(encoded_text)) \
                        or encoded_text[idx + move].startswith('##'):
                    move -= 1
                det_labels[i, idx + move + 1] = 1
        return ori_texts, cor_texts, det_labels

    dataset = CorrectorDataset(fp)
    loader = DataLoader(dataset, collate_fn=_collate_fn, **kwargs)
    return loader
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# SoftMaskedBert-PyTorch
🙈 基于 huggingface/transformers 的SoftMaskedBert的非官方实现

[English](README_eng.md) | 简体中文

## 环境准备
1. 安装 python 3.6+
2. 运行以下命令以安装必要的包.
```shell
pip install -r requirements.txt
```

## 数据准备
1. 从 [http://nlp.ee.ncu.edu.tw/resource/csc.html](http://nlp.ee.ncu.edu.tw/resource/csc.html) 下载SIGHAN数据集
2. 解压上述数据集并将文件夹中所有 ''.sgml'' 文件复制至 data/ 目录
3. 复制 ''SIGHAN15_CSC_TestInput.txt'' 和 ''SIGHAN15_CSC_TestTruth.txt'' 至 data/ 目录
4. 下载 [https://github.com/wdimmy/Automatic-Corpus-Generation/blob/master/corpus/train.sgml](https://github.com/wdimmy/Automatic-Corpus-Generation/blob/master/corpus/train.sgml) 至 data/ 目录
5. 请确保以下文件在 data/ 中
```
train.sgml
B1_training.sgml
C1_training.sgml
SIGHAN15_CSC_A2_Training.sgml
SIGHAN15_CSC_B2_Training.sgml
SIGHAN15_CSC_TestInput.txt
SIGHAN15_CSC_TestTruth.txt
```
6. 运行以下命令进行数据预处理
```shell
python main.py --mode preproc
```

## 下载预训练权重
1. 从 [https://huggingface.co/bert-base-chinese/tree/main](https://huggingface.co/bert-base-chinese/tree/main) 下载BERT的预训练权重(pytorch_model.bin) 至 checkpoint/ 目录

## 训练及测试
1. 运行以下命令以训练模型。
```shell
python main.py --mode train
```
2. 运行以下命令以测试模型。
```shell
python main.py --mode test
```
3. 更多模型运行及训练参数请使用以下命令查看。
```shell
python main.py --help
```
```
--hard_device HARD_DEVICE
                        硬件,cpu or cuda
--gpu_index GPU_INDEX
                        gpu索引, one of [0,1,2,3]
--load_checkpoint [LOAD_CHECKPOINT]
                        是否加载训练保存的权重, one of [t,f]
--bert_checkpoint BERT_CHECKPOINT
--model_save_path MODEL_SAVE_PATH
--epochs EPOCHS         训练轮数
--batch_size BATCH_SIZE
                        批大小
--warmup_epochs WARMUP_EPOCHS
                        warmup轮数, 需小于训练轮数
--lr LR                 学习率
--accumulate_grad_batches ACCUMULATE_GRAD_BATCHES
                        梯度累加的batch数
--mode MODE             代码运行模式,以此来控制训练测试或数据预处理,one of [train, test, preproc]
--loss_weight LOSS_WEIGHT
                        论文中的lambda,即correction loss的权重
```

## 实验结果
### 字级
|component|p|r|f|
|:-:|:-:|:-:|:-:|
|Detection|0.8417|0.8274|0.8345|
|Correction|0.9487|0.8739|0.9106|
### 句级
|acc|p|r|f|
|:-:|:-:|:-:|:-:|
|0.8145|0.8674|0.7361|0.7964|

detection的表现差是因为欠拟合,该实验结果仅是在处理后的数据集上跑了10个epochs的结果,并没有像paper一样做大量的预训练。


## References
1. [Spelling Error Correction with Soft-Masked BERT](https://arxiv.org/abs/2005.07421)
2. [http://ir.itc.ntnu.edu.tw/lre/sighan8csc.html](http://ir.itc.ntnu.edu.tw/lre/sighan8csc.html)
3. [https://github.com/wdimmy/Automatic-Corpus-Generation](https://github.com/wdimmy/Automatic-Corpus-Generation)
4. [transformers](https://huggingface.co/)
5. [https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check](https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check)
--------------------------------------------------------------------------------
/README_eng.md:
--------------------------------------------------------------------------------
# SoftMaskedBert-PyTorch
🙈 An unofficial implementation of SoftMaskedBert based on huggingface/transformers.

English | [简体中文](readme.md)

## prepare env
1. install python 3.6+
2. run the following command in a terminal.
```shell
pip install -r requirements.txt
```

## prepare data
1. download the sighan data from [http://nlp.ee.ncu.edu.tw/resource/csc.html](http://nlp.ee.ncu.edu.tw/resource/csc.html)
2. unzip the file and copy all the ''.sgml'' files to data/
3. copy ''SIGHAN15_CSC_TestInput.txt'' and ''SIGHAN15_CSC_TestTruth.txt'' to data/
4. download [https://github.com/wdimmy/Automatic-Corpus-Generation/blob/master/corpus/train.sgml](https://github.com/wdimmy/Automatic-Corpus-Generation/blob/master/corpus/train.sgml) to data/
5. check that the following files are in data/
```
train.sgml
B1_training.sgml
C1_training.sgml
SIGHAN15_CSC_A2_Training.sgml
SIGHAN15_CSC_B2_Training.sgml
SIGHAN15_CSC_TestInput.txt
SIGHAN15_CSC_TestTruth.txt
```
6. run the following command to process the data
```shell
python main.py --mode preproc
```

## prepare bert checkpoint
1. download the bert checkpoint (pytorch_model.bin) from [https://huggingface.co/bert-base-chinese/tree/main](https://huggingface.co/bert-base-chinese/tree/main) to checkpoint/

## run
1. run the following command to train the model.
```shell
python main.py --mode train
```
2. run the following command to test the model.
```shell
python main.py --mode test
```
3. you can use the following command to see help for all arguments.
```shell
python main.py --help
```
```
--hard_device HARD_DEVICE
                        硬件,cpu or cuda
--gpu_index GPU_INDEX
                        gpu索引, one of [0,1,2,3]
--load_checkpoint [LOAD_CHECKPOINT]
                        是否加载训练保存的权重, one of [t,f]
--bert_checkpoint BERT_CHECKPOINT
--model_save_path MODEL_SAVE_PATH
--epochs EPOCHS         训练轮数
--batch_size BATCH_SIZE
                        批大小
--warmup_epochs WARMUP_EPOCHS
                        warmup轮数, 需小于训练轮数
--lr LR                 学习率
--accumulate_grad_batches ACCUMULATE_GRAD_BATCHES
                        梯度累加的batch数
--mode MODE             代码运行模式,以此来控制训练测试或数据预处理,one of [train, test, preproc]
--loss_weight LOSS_WEIGHT
                        论文中的lambda,即correction loss的权重
```

## experimental results

### char level

|component|p|r|f|
|:-:|:-:|:-:|:-:|
|Detection|0.8417|0.8274|0.8345|
|Correction|0.9487|0.8739|0.9106|

### sentence level

|acc|p|r|f|
|:-:|:-:|:-:|:-:|
|0.8145|0.8674|0.7361|0.7964|


## references
1. [Spelling Error Correction with Soft-Masked BERT](https://arxiv.org/abs/2005.07421)
2. [http://ir.itc.ntnu.edu.tw/lre/sighan8csc.html](http://ir.itc.ntnu.edu.tw/lre/sighan8csc.html)
3. [https://github.com/wdimmy/Automatic-Corpus-Generation](https://github.com/wdimmy/Automatic-Corpus-Generation)
4. [transformers](https://huggingface.co/)
5. [https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check](https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""
@Time   : 2021-01-12 15:23:56
@File   : main.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""
import argparse
import os
import torch
from transformers import BertTokenizer
import pytorch_lightning as pl
from src.dataset import get_corrector_loader
from src.models import SoftMaskedBertModel
from src.data_processor import preproc
from src.utils import get_abs_path


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--hard_device", default='cpu', type=str, help="硬件,cpu or cuda")
    parser.add_argument("--gpu_index", default=0, type=int, help='gpu索引, one of [0,1,2,3,...]')
    parser.add_argument("--load_checkpoint", nargs='?', const=True, default=False, type=str2bool,
                        help="是否加载训练保存的权重, one of [t,f]")
    parser.add_argument('--bert_checkpoint', default='bert-base-chinese', type=str)
    parser.add_argument('--model_save_path', default='checkpoint', type=str)
    parser.add_argument('--epochs', default=10, type=int, help='训练轮数')
    parser.add_argument('--batch_size', default=16, type=int, help='批大小')
    parser.add_argument('--warmup_epochs', default=8, type=int, help='warmup轮数, 需小于训练轮数')
    parser.add_argument('--lr', default=1e-4, type=float, help='学习率')
    parser.add_argument('--accumulate_grad_batches',
                        default=16,
                        type=int,
                        help='梯度累加的batch数')
    parser.add_argument('--mode', default='train', type=str,
                        help='代码运行模式,以此来控制训练测试或数据预处理,one of [train, test, preproc]')
    parser.add_argument('--loss_weight', default=0.8, type=float, help='论文中的lambda,即correction loss的权重')
    arguments = parser.parse_args()
    if arguments.hard_device == 'cpu':
        arguments.device = torch.device(arguments.hard_device)
    else:
        arguments.device = torch.device(f'cuda:{arguments.gpu_index}')
    if not 0 <= arguments.loss_weight <= 1:
        raise ValueError(f"The loss weight must be in [0, 1], but got {arguments.loss_weight}")
    print(arguments)
    return arguments


def main():
    args = parse_args()
    if args.mode == 'preproc':
        print('preprocessing...')
        preproc()
        return

    tokenizer = BertTokenizer.from_pretrained(args.bert_checkpoint)
    model = SoftMaskedBertModel(args, tokenizer)
    train_loader = get_corrector_loader(get_abs_path('data', 'train.json'),
                                        tokenizer,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=4)
    valid_loader = get_corrector_loader(get_abs_path('data', 'dev.json'),
                                        tokenizer,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=4)
    test_loader = get_corrector_loader(get_abs_path('data', 'test.json'),
                                       tokenizer,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=4)
    trainer = pl.Trainer(max_epochs=args.epochs,
                         gpus=None if args.hard_device == 'cpu' else [args.gpu_index],
                         accumulate_grad_batches=args.accumulate_grad_batches)
    model.load_from_transformers_state_dict(get_abs_path('checkpoint', 'pytorch_model.bin'))
    if args.load_checkpoint:
        model.load_state_dict(torch.load(get_abs_path('checkpoint', f'{model.__class__.__name__}_model.bin'),
                                         map_location=args.hard_device))
    if args.mode == 'train':
        trainer.fit(model, train_loader, valid_loader)

    model.load_state_dict(
        torch.load(get_abs_path('checkpoint', f'{model.__class__.__name__}_model.bin'), map_location=args.hard_device))
    trainer.test(model, test_loader)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
"""
@Time   : 2021-01-12 15:10:43
@File   : utils.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""
import json
import os
import sys


def compute_corrector_prf(results):
    """
    copy from https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check/blob/master/utils/evaluation_metrics.py
    """
    TP = 0
    FP = 0
    FN = 0
    all_predict_true_index = []
    all_gold_index = []
    for item in results:
        src, tgt, predict = item
        gold_index = []
        each_true_index = []
        for i in range(len(list(src))):
            if src[i] == tgt[i]:
                continue
            else:
                gold_index.append(i)
        all_gold_index.append(gold_index)
        predict_index = []
        for i in range(len(list(src))):
            if src[i] == predict[i]:
                continue
            else:
                predict_index.append(i)

        for i in predict_index:
            if i in gold_index:
                TP += 1
                each_true_index.append(i)
            else:
                FP += 1
        for i in gold_index:
            if i in predict_index:
                continue
            else:
                FN += 1
        all_predict_true_index.append(each_true_index)

    # For the detection Precision, Recall and F1
    detection_precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    detection_recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    if detection_precision + detection_recall == 0:
        detection_f1 = 0
    else:
        detection_f1 = 2 * (detection_precision * detection_recall) / (detection_precision + detection_recall)
    print("The detection result is precision={}, recall={} and F1={}".format(detection_precision, detection_recall,
                                                                             detection_f1))

    TP = 0
    FP = 0
    FN = 0

    for i in range(len(all_predict_true_index)):
        # we only evaluate the correctly detected locations, which is different from the common metrics,
        # since we want to see the precision improvement gained by using the confusion set
        if len(all_predict_true_index[i]) > 0:
            predict_words = []
            for j in all_predict_true_index[i]:
                predict_words.append(results[i][2][j])
                if results[i][1][j] == results[i][2][j]:
                    TP += 1
                else:
                    FP += 1
            for j in all_gold_index[i]:
                if results[i][1][j] in predict_words:
                    continue
                else:
                    FN += 1

    # For the correction Precision, Recall and F1
    correction_precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    correction_recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    if correction_precision + correction_recall == 0:
        correction_f1 = 0
    else:
        correction_f1 = 2 * (correction_precision * correction_recall) / (correction_precision + correction_recall)
print("The correction result is precision={}, recall={} and F1={}".format(correction_precision, 90 | correction_recall, 91 | correction_f1)) 92 | 93 | return detection_f1, correction_f1 94 | 95 | 96 | def load_json(fp): 97 | if not os.path.exists(fp): 98 | return dict() 99 | 100 | with open(fp, 'r', encoding='utf8') as f: 101 | return json.load(f) 102 | 103 | 104 | def dump_json(obj, fp): 105 | try: 106 | fp = os.path.abspath(fp) 107 | if not os.path.exists(os.path.dirname(fp)): 108 | os.makedirs(os.path.dirname(fp)) 109 | with open(fp, 'w', encoding='utf8') as f: 110 | json.dump(obj, f, ensure_ascii=False, indent=4, separators=(',', ':')) 111 | print(f'json文件保存成功,{fp}') 112 | return True 113 | except Exception as e: 114 | print(f'json文件{obj}保存失败, {e}') 115 | return False 116 | 117 | 118 | def get_main_dir(): 119 | # 如果是使用pyinstaller打包后的执行文件,则定位到执行文件所在目录 120 | if hasattr(sys, 'frozen'): 121 | return os.path.join(os.path.dirname(sys.executable)) 122 | # 其他情况则定位至项目根目录 123 | return os.path.join(os.path.dirname(__file__), '..') 124 | 125 | 126 | def get_abs_path(*name): 127 | return os.path.abspath(os.path.join(get_main_dir(), *name)) 128 | 129 | 130 | def compute_sentence_level_prf(results): 131 | """ 132 | 自定义的句级prf,设定需要纠错为正样本,无需纠错为负样本 133 | :param results: 134 | :return: 135 | """ 136 | 137 | TP = 0.0 138 | FP = 0.0 139 | FN = 0.0 140 | TN = 0.0 141 | total_num = len(results) 142 | 143 | for item in results: 144 | src, tgt, predict = item 145 | 146 | # 负样本 147 | if src == tgt: 148 | # 预测也为负 149 | if tgt == predict: 150 | TN += 1 151 | # 预测为正 152 | else: 153 | FP += 1 154 | # 正样本 155 | else: 156 | # 预测也为正 157 | if tgt == predict: 158 | TP += 1 159 | # 预测为负 160 | else: 161 | FN += 1 162 | 163 | acc = (TP + TN) / total_num 164 | precision = TP / (TP + FP) 165 | recall = TP / (TP + FN) 166 | f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0 167 | 168 | print(f'Sentence Level: acc:{acc:.6f}, precision:{precision:.6f}, recall:{recall:.6f}, f1:{f1:.6f}') 169 | return acc, precision, recall, f1 170 | -------------------------------------------------------------------------------- /src/data_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-12 15:23:38 3 | @File : data_processor.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import gc 8 | import random 9 | from lxml import etree 10 | 11 | import opencc 12 | from tqdm import tqdm 13 | import os 14 | from .utils import dump_json, get_abs_path 15 | 16 | 17 | def proc_item(item, convertor): 18 | root = etree.XML(item) 19 | passages = dict() 20 | mistakes = [] 21 | for passage in root.xpath('/ESSAY/TEXT/PASSAGE'): 22 | passages[passage.get('id')] = convertor.convert(passage.text) 23 | for mistake in root.xpath('/ESSAY/MISTAKE'): 24 | mistakes.append({'id': mistake.get('id'), 25 | 'location': int(mistake.get('location')) - 1, 26 | 'wrong': convertor.convert(mistake.xpath('./WRONG/text()')[0].strip()), 27 | 'correction': convertor.convert(mistake.xpath('./CORRECTION/text()')[0].strip())}) 28 | 29 | rst_items = dict() 30 | for mistake in mistakes: 31 | if mistake['id'] not in rst_items.keys(): 32 | rst_items[mistake['id']] = {'original_text': passages[mistake['id']], 33 | 'wrong_ids': [], 34 | 'correct_text': passages[mistake['id']]} 35 | 36 | # todo 繁体转简体字符数量或位置发生改变校验 37 | 38 | ori_text = rst_items[mistake['id']]['original_text'] 39 | cor_text = rst_items[mistake['id']]['correct_text'] 40 | if len(ori_text) == len(cor_text): 41 
            if ori_text[mistake['location']] in mistake['wrong']:
                rst_items[mistake['id']]['wrong_ids'].append(mistake['location'])
                wrong_char_idx = mistake['wrong'].index(ori_text[mistake['location']])
                start = mistake['location'] - wrong_char_idx
                end = start + len(mistake['wrong'])
                rst_items[mistake['id']][
                    'correct_text'] = f'{cor_text[:start]}{mistake["correction"]}{cor_text[end:]}'
        else:
            print(f'{mistake["id"]}\n{ori_text}\n{cor_text}')
    rst = []
    for k in rst_items.keys():
        if len(rst_items[k]['correct_text']) == len(rst_items[k]['original_text']):
            rst.append({'id': k, **rst_items[k]})
        else:
            text = rst_items[k]['correct_text']
            rst.append({'id': k, 'correct_text': text, 'original_text': text, 'wrong_ids': []})
    return rst


def proc_test_set(fp, convertor):
    """
    生成sighan15的测试集
    Args:
        fp:
        convertor:

    Returns:

    """
    inputs = dict()
    with open(os.path.join(fp, 'SIGHAN15_CSC_TestInput.txt'), 'r', encoding='utf8') as f:
        for line in f:
            pid = line[5:14]
            text = line[16:].strip()
            inputs[pid] = text

    rst = []
    with open(os.path.join(fp, 'SIGHAN15_CSC_TestTruth.txt'), 'r', encoding='utf8') as f:
        for line in f:
            pid = line[0:9]
            mistakes = line[11:].strip().split(', ')
            if len(mistakes) <= 1:
                text = convertor.convert(inputs[pid])
                rst.append({'id': pid,
                            'original_text': text,
                            'wrong_ids': [],
                            'correct_text': text})
            else:
                wrong_ids = []
                original_text = inputs[pid]
                cor_text = inputs[pid]
                for i in range(len(mistakes) // 2):
                    idx = int(mistakes[2 * i]) - 1
                    cor_char = mistakes[2 * i + 1]
                    wrong_ids.append(idx)
                    cor_text = f'{cor_text[:idx]}{cor_char}{cor_text[idx + 1:]}'
                original_text = convertor.convert(original_text)
                cor_text = convertor.convert(cor_text)
                if len(original_text) != len(cor_text):
                    print(pid)
                    print(original_text)
                    print(cor_text)
                    continue
                rst.append({'id': pid,
                            'original_text': original_text,
                            'wrong_ids': wrong_ids,
                            'correct_text': cor_text})

    return rst


def read_data(fp):
    for fn in os.listdir(fp):
        if fn.endswith('ing.sgml'):
            with open(os.path.join(fp, fn), 'r') as f:
                item = []
                for line in f:
                    if line.strip().startswith('<ESSAY') and len(item) > 0:
                        yield ''.join(item)
                        item = [line.strip()]
                    elif line.strip().startswith('<'):
                        item.append(line.strip())


def read_confusion_data(fp):
    fn = os.path.join(fp, 'train.sgml')
    with open(fn, 'r') as f:
        item = []
        for line in tqdm(f):
            if line.strip().startswith('<SENTENCE') and len(item) > 0:
                yield ''.join(item)
                item = [line.strip()]
            elif line.strip().startswith('<'):
                item.append(line.strip())


def proc_confusion_item(item):
    """
    处理confusionset数据集
    Args:
        item:

    Returns:

    """
    root = etree.XML(item)
    text = root.xpath('/SENTENCE/TEXT/text()')[0]
    mistakes = []
    for mistake in root.xpath('/SENTENCE/MISTAKE'):
        mistakes.append({'location': int(mistake.xpath('./LOCATION/text()')[0]) - 1,
                         'wrong': mistake.xpath('./WRONG/text()')[0].strip(),
                         'correction': mistake.xpath('./CORRECTION/text()')[0].strip()})

    cor_text = text
    wrong_ids = []

    for mis in mistakes:
        cor_text = f'{cor_text[:mis["location"]]}{mis["correction"]}{cor_text[mis["location"] + 1:]}'
        wrong_ids.append(mis['location'])

    rst = [{
        'id': '-',
        'original_text': text,
        'wrong_ids': wrong_ids,
        'correct_text': cor_text
    }]
    if len(text) != len(cor_text):
        print(text)
        print(cor_text)
        return [{'id': '--',
                 'original_text': cor_text,
                 'wrong_ids': [],
                 'correct_text': cor_text}]
    # 取一定概率保留原文本
    if random.random() < 0.3:
        rst.append({'id': '--',
                    'original_text': cor_text,
                    'wrong_ids': [],
                    'correct_text': cor_text})
    return rst


def preproc():
    rst_items = []
    convertor = opencc.OpenCC('tw2sp.json')
    test_items = proc_test_set('data', convertor)
    for item in read_data(get_abs_path('data')):
        rst_items += proc_item(item, convertor)
    for item in read_confusion_data(get_abs_path('data')):
        rst_items += proc_confusion_item(item)

    # 拆分训练与测试
    dev_set_len = len(rst_items) // 10
    print(len(rst_items))
    random.seed(666)
    random.shuffle(rst_items)
    dump_json(rst_items[:dev_set_len], get_abs_path('data', 'dev.json'))
    dump_json(rst_items[dev_set_len:], get_abs_path('data', 'train.json'))
    dump_json(test_items, get_abs_path('data', 'test.json'))
    gc.collect()
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
"""
@Time   : 2021-01-12 15:08:01
@File   : models.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""
import operator
import os
from collections import OrderedDict

import torch
from torch import nn
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler, BertOnlyMLMHead
from transformers.modeling_utils import ModuleUtilsMixin

from .utils import compute_corrector_prf, compute_sentence_level_prf
import numpy as np


class DetectionNetwork(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gru = nn.GRU(
            self.config.hidden_size,
            self.config.hidden_size // 2,
            num_layers=2,
            batch_first=True,
            dropout=self.config.hidden_dropout_prob,
            bidirectional=True,
        )
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(self.config.hidden_size, 1)

    def forward(self, hidden_states):
        out, _ = self.gru(hidden_states)
        prob = self.linear(out)
        prob = self.sigmoid(prob)
        return prob


class BertCorrectionModel(torch.nn.Module, ModuleUtilsMixin):
    def __init__(self, config, tokenizer, device):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.embeddings = BertEmbeddings(self.config)
        self.corrector = BertEncoder(self.config)
        self.mask_token_id = self.tokenizer.mask_token_id
        self.pooler = BertPooler(self.config)
        self.cls = BertOnlyMLMHead(self.config)
        self._device = device

    def forward(self, texts, prob, embed=None, cor_labels=None, residual_connection=False):
        if cor_labels is not None:
            text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
            text_labels = text_labels.to(self._device)
            # torch的cross entropy loss 会忽略-100的label
            text_labels[text_labels == 0] = -100
        else:
            text_labels = None
        encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_texts.to(self._device)
        if embed is None:
            embed = self.embeddings(input_ids=encoded_texts['input_ids'],
                                    token_type_ids=encoded_texts['token_type_ids'])
        # 此处较原文有一定改动,做此改动意在完整保留type_ids及position_ids的embedding。
        # mask_embed = self.embeddings(torch.ones_like(prob.squeeze(-1)).long() * self.mask_token_id).detach()
        # 此处为原文实现
        mask_embed = self.embeddings(torch.tensor([[self.mask_token_id]], device=self._device)).detach()
        cor_embed = prob * mask_embed + (1 - prob) * embed

        input_shape = encoded_texts['input_ids'].size()
        device = encoded_texts['input_ids'].device

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(encoded_texts['attention_mask'],
                                                                                 input_shape, device)
        head_mask = self.get_head_mask(None, self.config.num_hidden_layers)
        encoder_outputs = self.corrector(
            cor_embed,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            return_dict=False,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        sequence_output = sequence_output + embed if residual_connection else sequence_output
        prediction_scores = self.cls(sequence_output)
        out = (prediction_scores, sequence_output, pooled_output)

        # Masked language modeling softmax layer
        if text_labels is not None:
            loss_fct = nn.CrossEntropyLoss(reduction='sum')  # -100 index = padding token
            cor_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), text_labels.view(-1))
            out = (cor_loss,) + out
        return out

    def load_from_transformers_state_dict(self, gen_fp):
        state_dict = OrderedDict()
        gen_state_dict = torch.load(gen_fp)
        for k, v in gen_state_dict.items():
            name = k
            if name.startswith('bert'):
                name = name[5:]
            if name.startswith('encoder'):
                name = f'corrector.{name[8:]}'
            if 'gamma' in name:
                name = name.replace('gamma', 'weight')
            if 'beta' in name:
                name = name.replace('beta', 'bias')
            state_dict[name] = v
        self.load_state_dict(state_dict, strict=False)


class BaseCorrectorTrainingModel(pl.LightningModule):
    """
    用于CSC的BaseModel, 定义了训练及预测步骤
    """

    def __init__(self, arguments, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = arguments
        self.w = arguments.loss_weight
        self.min_loss = float('inf')

    def training_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        loss = self.w * outputs[1] + (1 - self.w) * outputs[0]
        return loss

    def validation_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        loss = self.w * outputs[1] + (1 - self.w) * outputs[0]
        det_y_hat = (outputs[2] > 0.5).long()
        cor_y_hat = torch.argmax((outputs[3]), dim=-1)
        encoded_x = self.tokenizer(cor_text, padding=True, return_tensors='pt')
        encoded_x.to(self._device)
        cor_y = encoded_x['input_ids']
        cor_y_hat *= encoded_x['attention_mask']

        results = []
        det_acc_labels = []
        cor_acc_labels = []
        for src, tgt, predict, det_predict, det_label in zip(ori_text, cor_y, cor_y_hat, det_y_hat, det_labels):
            _src = self.tokenizer(src, add_special_tokens=False)['input_ids']
            _tgt = tgt[1:len(_src) + 1].cpu().numpy().tolist()
            _predict = predict[1:len(_src) + 1].cpu().numpy().tolist()
            cor_acc_labels.append(1 if operator.eq(_tgt, _predict) else 0)
            det_acc_labels.append(det_predict[1:len(_src) + 1].equal(det_label[1:len(_src) + 1]))
            results.append((_src, _tgt, _predict,))

        return loss.cpu().item(), det_acc_labels, cor_acc_labels, results

    def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None:
        print('Valid.')

    def validation_epoch_end(self, outputs) -> None:
        det_acc_labels = []
        cor_acc_labels = []
        results = []
        for out in outputs:
            det_acc_labels += out[1]
            cor_acc_labels += out[2]
            results += out[3]
        loss = np.mean([out[0] for out in outputs])
        print(f'loss: {loss}')
        print(f'Detection:\n'
              f'acc: {np.mean(det_acc_labels):.4f}')
        print(f'Correction:\n'
              f'acc: {np.mean(cor_acc_labels):.4f}')
        print('Char Level:')
        compute_corrector_prf(results)
        compute_sentence_level_prf(results)
        if (len(outputs) > 5) and (loss < self.min_loss):
            self.min_loss = loss
            torch.save(self.state_dict(),
                       os.path.join(self.args.model_save_path, f'{self.__class__.__name__}_model.bin'))
            print('model saved.')
        torch.save(self.state_dict(),
                   os.path.join(self.args.model_save_path, f'{self.__class__.__name__}_model.bin'))

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs) -> None:
        print('Test.')
        self.validation_epoch_end(outputs)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr)
        scheduler = LambdaLR(optimizer,
                             lr_lambda=lambda step: min((step + 1) ** -0.5,
                                                        (step + 1) * self.args.warmup_epochs ** (-1.5)),
                             last_epoch=-1)
        return [optimizer], [scheduler]


class SoftMaskedBertModel(BaseCorrectorTrainingModel):
    def __init__(self, args, tokenizer):
        super().__init__(args)
        self.args = args
        self.config = BertConfig.from_pretrained(args.bert_checkpoint)
        self.detector = DetectionNetwork(self.config)
        self.tokenizer = tokenizer
        self.corrector = BertCorrectionModel(self.config, tokenizer, args.device)
        self._device = args.device

    def forward(self, texts, cor_labels=None, det_labels=None):
        encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_texts.to(self._device)
        embed = self.corrector.embeddings(input_ids=encoded_texts['input_ids'],
                                          token_type_ids=encoded_texts['token_type_ids'])
        prob = self.detector(embed)
        cor_out = self.corrector(texts, prob, embed, cor_labels, residual_connection=True)

        if det_labels is not None:
            det_loss_fct = nn.BCELoss(reduction='sum')
            # pad部分不计算损失
            active_loss = encoded_texts['attention_mask'].view(-1, prob.shape[1]) == 1
            active_probs = prob.view(-1, prob.shape[1])[active_loss]
            active_labels = det_labels[active_loss]
            det_loss = det_loss_fct(active_probs, active_labels.float())
            outputs = (det_loss, cor_out[0], prob.squeeze(-1)) + cor_out[1:]
        else:
            outputs = (prob.squeeze(-1),) + cor_out

        return outputs

    def load_from_transformers_state_dict(self, gen_fp):
        """
        从transformers加载预训练权重
        :param gen_fp:
        :return:
        """
        self.corrector.load_from_transformers_state_dict(gen_fp)
--------------------------------------------------------------------------------
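
Note on inference: main.py only exposes the train, test and preproc modes, so the snippet below is a minimal single-sentence correction sketch rather than a file from the repository. It assumes a checkpoint saved by BaseCorrectorTrainingModel at checkpoint/SoftMaskedBertModel_model.bin, builds a hypothetical argparse-style Namespace mirroring the defaults in parse_args(), uses a made-up example sentence, and assumes each Chinese character maps to exactly one WordPiece token when decoding.

```python
# Minimal inference sketch (not part of the repository); assumptions as noted above.
from argparse import Namespace

import torch
from transformers import BertTokenizer

from src.models import SoftMaskedBertModel

# Mirror the fields that parse_args() in main.py would normally provide.
args = Namespace(hard_device='cpu', device=torch.device('cpu'),
                 bert_checkpoint='bert-base-chinese', model_save_path='checkpoint',
                 loss_weight=0.8, lr=1e-4, warmup_epochs=8)

tokenizer = BertTokenizer.from_pretrained(args.bert_checkpoint)
model = SoftMaskedBertModel(args, tokenizer)
model.load_state_dict(torch.load('checkpoint/SoftMaskedBertModel_model.bin',
                                 map_location='cpu'))
model.eval()

text = '今天新情很好'  # hypothetical input with one misspelled character
with torch.no_grad():
    # With no labels, forward() returns (detection probs, correction logits, ...).
    det_prob, logits, *_ = model([text])

pred_ids = torch.argmax(logits, dim=-1)[0]
# Strip [CLS]/[SEP]; assumes one WordPiece token per Chinese character.
tokens = tokenizer.convert_ids_to_tokens(pred_ids[1:len(text) + 1].tolist())
print('detection probs:', [round(p, 3) for p in det_prob[0][1:len(text) + 1].tolist()])
print('corrected text :', ''.join(tokens))
```

Because forward() is called without labels here, the first two returned values correspond to the detection probabilities and the correction logits, matching the else branch of SoftMaskedBertModel.forward above.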