├── .gitignore ├── README.md ├── command └── train.sh ├── data ├── .gitkeep ├── example.txt ├── example_input.json └── example_output.json ├── logs └── .gitkeep ├── model └── .gitkeep ├── pretrained_model └── .gitkeep ├── requirements.txt └── src ├── __init__.py ├── baseline ├── __init__.py ├── ctc_vocab │ ├── config.py │ ├── ctc_correct_tags.txt │ └── ctc_detect_tags.txt ├── dataset.py ├── loss.py ├── modeling.py ├── predictor.py ├── tokenizer.py └── trainer.py ├── corrector.py ├── evaluate.py ├── metric.py ├── prepare_for_upload.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | nogit 2 | data/* 3 | __pycache__ 4 | model/* 5 | pretrained_model/* 6 | logs/* 7 | !.gitkeep 8 | !data/example_input.json 9 | !data/example_output.json 10 | !data/example.txt 11 | utils -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 文本智能校对大赛 2 | 3 | 4 | - [文本智能校对大赛](#文本智能校对大赛) 5 | - [最新动态](#最新动态) 6 | - [赛程](#赛程) 7 | - [任务描述](#任务描述) 8 | - [Baseline介绍](#baseline介绍) 9 | - [模型](#模型) 10 | - [代码结构](#代码结构) 11 | - [使用说明](#使用说明) 12 | - [baseline表现](#baseline表现) 13 | - [开始训练](#开始训练) 14 | - [其他公开数据集](#其他公开数据集) 15 | - [相关资源](#相关资源) 16 | 17 | 18 | ## 最新动态 19 | 20 | | 时间 | 事件 | 21 | | ------- | ------- | 22 | | 2022.7.19 | 修复指标计算的Bug, 详见[metry.py](https://github.com/bitallin/MiduCTC-competition/blob/main/src/metric.py) ,感谢[@HillZhang1999](https://github.com/HillZhang1999)的提醒和贡献| 23 | | 2022.7.21 | 更新baseline在a榜数据集上的表现| 24 | 25 | ## 赛程 26 | 27 | | 时间 | 事件 | 28 | | ------- | ------- | 29 | | 2022.7.13 | 比赛启动,开放报名,[赛事网址](https://aistudio.baidu.com/aistudio/competition/detail/404/0/introduction),初赛A榜数据集,初赛A榜提交入口| 30 | | 2022.8.12 | 报名截止,关闭初赛A榜评测入口 | 31 | | 2022.8.13 | 开放初赛B榜数据集、评测入口 | 32 | | 2022.8.17 | 关闭初赛B榜数据集、评测入口 | 33 | | 2022.8.18 | 开放决赛数据集、评测入口 | 34 | | 2022.8.20 | 关闭决赛数据集、评测入口 | 35 | 36 | 37 | ## 获奖队伍 38 | 39 | | 排名 | 参赛队伍 | 得分 | 40 | |-------| ------- | ------- | 41 | |1| 苏州大学-阿里达摩院联队 | 0.7637 | 42 | |2| Grandmaly | 0.7338| 43 | |3| 语言组小分队 | 0.6916| 44 | |4| YanSun的团队 | 0.6779| 45 | |5| TAL-有错必改 | 0.6528| 46 | |6| NLPIR | 0.6425| 47 | 48 | 49 | ## 任务描述 50 | 51 | 本次赛题选择网络文本作为输入,从中检测并纠正错误,实现中文文本校对系统。即给定一段文本,校对系统从中检测出错误字词、错误类型,并进行纠正,最终输出校正后的结果。 52 | 53 | 文本校对又称文本纠错,相关资料可参考自然语言处理方向的**语法纠错(Grammatical Error Correction, GEC)**任务和**中文拼写纠错(Chinese spelling check, CSC)** 54 | 55 | 56 | ## Baseline介绍 57 | 58 | ### 模型 59 | 60 | 提供了**GECToR**作为baseline模型,可参考[GECToR论文](https://aclanthology.org/2020.bea-1.16.pdf)和[GECToR源代码](https://github.com/grammarly/gector) 61 | 62 | 63 | 64 | ### 代码结构 65 | ``` 66 | ├── command 67 | │ └── train.sh # 训练脚本 68 | ├── data 69 | ├── logs 70 | ├── pretrained_model 71 | └── src 72 | ├── __init__.py 73 | ├── baseline # baseline系统 74 | ├── corrector.py # 文本校对入口 75 | ├── evaluate.py # 指标评估 76 | ├── metric.py # 指标计算文件 77 | ├── prepare_for_upload.py # 生成要提交的结果文件 78 | └── train.py # 训练入口 79 | ``` 80 | 81 | ### 使用说明 82 | 83 | - 数据集获取:请于[比赛网站](https://aistudio.baidu.com/aistudio/competition/detail/404/0/introduction)获取数据集 84 | - 提供了基础校对系统的baseline,其中baseline模型训练参数说明参考src/baseline/trainer.py 85 | - baseline中的预训练模型支持使用bert类模型,可从HuggingFace下载bert类预训练模型,如: [chinese-roberta-wwm-ext](https://huggingface.co/hfl/chinese-roberta-wwm-ext)等 86 | - baseline仅作参考,参赛队伍可对baseline进行二次开发,或采取其他解决方案。 87 | 88 | ### baseline表现 89 | 90 | - baseline在a榜训练集(不含preliminary_extend_train.json),使用单机4卡分布式训练的情况下 91 | - 
训练到**第4个epoch**结束在a榜提交得分约为:**0.3587** 92 | 93 | 具体训练参数如下: 94 | 95 | ``` 96 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m src.train \ 97 | --in_model_dir "pretrained_model/chinese-roberta-wwm-ext" \ 98 | --out_model_dir "model/ctc_train" \ 99 | --epochs "50" \ 100 | --batch_size "158" \ 101 | --max_seq_len "128" \ 102 | --learning_rate "5e-5" \ 103 | --train_fp "data/comp_data/preliminary_a_data/preliminary_train.json" \ 104 | --test_fp "data/comp_data/preliminary_a_data/preliminary_val.json" \ 105 | --random_seed_num "42" \ 106 | --check_val_every_n_epoch "0.5" \ 107 | --early_stop_times "20" \ 108 | --warmup_steps "-1" \ 109 | --dev_data_ratio "0.01" \ 110 | --training_mode "ddp" \ 111 | --amp true \ 112 | --freeze_embedding false 113 | ``` 114 | 115 | - 其中所用的预训练为:[chinese-roberta-wwm-ext](https://huggingface.co/hfl/chinese-roberta-wwm-ext) 116 | - 若使用Macbert可能会有进一步的提升 117 | 118 | 119 | ### 开始训练 120 | 121 | ``` 122 | cd command && sh train.sh 123 | ``` 124 | 125 | ## 其他公开数据集 126 | 127 | - CGED历年公开数据集:http://www.cged.tech/ 128 | - NLPCC2018语法纠错数据集:http://tcci.ccf.org.cn/conference/2018/taskdata.php 129 | - SIGHAN及相关训练集:http://ir.itc.ntnu.edu.tw/lre/sighan8csc.html 130 | 131 | ## 相关资源 132 | 133 | - [pycorrector](https://github.com/shibing624/pycorrector) 134 | - [中文文本纠错开源项目整理](https://github.com/li-aolong/li-aolong.github.io/issues/12) 135 | - [PyCorrector文本纠错工具实践和代码详解](https://zhuanlan.zhihu.com/p/138981644) 136 | - [CTC-2021](https://github.com/destwang/CTC2021) 137 | - [Text Correction Papers](https://github.com/nghuyong/text-correction-papers) 138 | - [文本语法纠错不完全调研:学术界 v.s. 工业界最新研究进展](https://zhuanlan.zhihu.com/p/398928434) 139 | - [知物由学 | “找茬”不如交给AI算法,细说文本纠错的多种实现途径 ](https://zhuanlan.zhihu.com/p/434672168) 140 | - [中文文本纠错算法--错别字纠正的二三事 ](https://zhuanlan.zhihu.com/p/40806718) 141 | - [中文文本纠错算法走到多远了?](https://cloud.tencent.com/developer/article/1435917) 142 | - [平安寿险 AI 团队 | 文本纠错技术探索和实践](https://www.6aiq.com/article/1594474039153) 143 | - [中文文本纠错(Chinese Text Correction, CTC)相关资源,本资源由哈工大讯飞联合实验室(HFL)王宝鑫和赵红红整理维护。](https://github.com/destwang/CTCResources) 144 | -------------------------------------------------------------------------------- /command/train.sh: -------------------------------------------------------------------------------- 1 | cd .. && CUDA_VISIBLE_DEVICES=0,1,2,3 python -m src.train \ 2 | --in_model_dir "pretrained_model/chinese-roberta-wwm-ext" \ 3 | --out_model_dir "model/ctc_train" \ 4 | --epochs "50" \ 5 | --batch_size "158" \ 6 | --max_seq_len "128" \ 7 | --learning_rate "5e-5" \ 8 | --train_fp "data/comp_data/preliminary_a_data/preliminary_train.json" \ 9 | --test_fp "data/comp_data/preliminary_a_data/preliminary_val.json" \ 10 | --random_seed_num "42" \ 11 | --check_val_every_n_epoch "0.5" \ 12 | --early_stop_times "20" \ 13 | --warmup_steps "-1" \ 14 | --dev_data_ratio "0.01" \ 15 | --training_mode "ddp" \ 16 | --amp true \ 17 | --freeze_embedding false 18 | 19 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bitallin/MiduCTC-competition/f224bce5806ebc3e2af1c3bfb082e4ba8a357e47/data/.gitkeep -------------------------------------------------------------------------------- /data/example.txt: -------------------------------------------------------------------------------- 1 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 2 | 领导的按排,我坚决服从 领导的安排,我坚决服从 3 | 今天的天气真错! 今天的天气真不错! 
4 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 5 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 6 | 领导的按排,我坚决服从 领导的安排,我坚决服从 7 | 今天的天气真错! 今天的天气真不错! 8 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 9 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 10 | 领导的按排,我坚决服从 领导的安排,我坚决服从 11 | 今天的天气真错! 今天的天气真不错! 12 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 13 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 14 | 领导的按排,我坚决服从 领导的安排,我坚决服从 15 | 今天的天气真错! 今天的天气真不错! 16 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 17 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 18 | 领导的按排,我坚决服从 领导的安排,我坚决服从 19 | 今天的天气真错! 今天的天气真不错! 20 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 21 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 22 | 领导的按排,我坚决服从 领导的安排,我坚决服从 23 | 今天的天气真错! 今天的天气真不错! 24 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 25 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 26 | 领导的按排,我坚决服从 领导的安排,我坚决服从 27 | 今天的天气真错! 今天的天气真不错! 28 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 29 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 30 | 领导的按排,我坚决服从 领导的安排,我坚决服从 31 | 今天的天气真错! 今天的天气真不错! 32 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 33 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 34 | 领导的按排,我坚决服从 领导的安排,我坚决服从 35 | 今天的天气真错! 今天的天气真不错! 36 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 37 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 38 | 领导的按排,我坚决服从 领导的安排,我坚决服从 39 | 今天的天气真错! 今天的天气真不错! 40 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 41 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 42 | 领导的按排,我坚决服从 领导的安排,我坚决服从 43 | 今天的天气真错! 今天的天气真不错! 44 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 45 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 46 | 领导的按排,我坚决服从 领导的安排,我坚决服从 47 | 今天的天气真错! 今天的天气真不错! 48 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 49 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 50 | 领导的按排,我坚决服从 领导的安排,我坚决服从 51 | 今天的天气真错! 今天的天气真不错! 52 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 53 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 54 | 领导的按排,我坚决服从 领导的安排,我坚决服从 55 | 今天的天气真错! 今天的天气真不错! 56 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 57 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 58 | 领导的按排,我坚决服从 领导的安排,我坚决服从 59 | 今天的天气真错! 今天的天气真不错! 60 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 61 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 62 | 领导的按排,我坚决服从 领导的安排,我坚决服从 63 | 今天的天气真错! 今天的天气真不错! 64 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 65 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 66 | 领导的按排,我坚决服从 领导的安排,我坚决服从 67 | 今天的天气真错! 今天的天气真不错! 68 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 69 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 70 | 领导的按排,我坚决服从 领导的安排,我坚决服从 71 | 今天的天气真错! 今天的天气真不错! 72 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 73 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 74 | 领导的按排,我坚决服从 领导的安排,我坚决服从 75 | 今天的天气真错! 今天的天气真不错! 76 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 77 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 78 | 领导的按排,我坚决服从 领导的安排,我坚决服从 79 | 今天的天气真错! 今天的天气真不错! 80 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 81 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 82 | 领导的按排,我坚决服从 领导的安排,我坚决服从 83 | 今天的天气真错! 今天的天气真不错! 84 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 85 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 86 | 领导的按排,我坚决服从 领导的安排,我坚决服从 87 | 今天的天气真错! 今天的天气真不错! 88 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 89 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 90 | 领导的按排,我坚决服从 领导的安排,我坚决服从 91 | 今天的天气真错! 今天的天气真不错! 92 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 93 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 94 | 领导的按排,我坚决服从 领导的安排,我坚决服从 95 | 今天的天气真错! 今天的天气真不错! 96 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 97 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 98 | 领导的按排,我坚决服从 领导的安排,我坚决服从 99 | 今天的天气真错! 今天的天气真不错! 100 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 101 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 102 | 领导的按排,我坚决服从 领导的安排,我坚决服从 103 | 今天的天气真错! 今天的天气真不错! 104 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 105 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 106 | 领导的按排,我坚决服从 领导的安排,我坚决服从 107 | 今天的天气真错! 今天的天气真不错! 108 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 109 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 110 | 领导的按排,我坚决服从 领导的安排,我坚决服从 111 | 今天的天气真错! 今天的天气真不错! 112 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 113 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 114 | 领导的按排,我坚决服从 领导的安排,我坚决服从 115 | 今天的天气真错! 今天的天气真不错! 
116 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 117 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 118 | 领导的按排,我坚决服从 领导的安排,我坚决服从 119 | 今天的天气真错! 今天的天气真不错! 120 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 121 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 122 | 领导的按排,我坚决服从 领导的安排,我坚决服从 123 | 今天的天气真错! 今天的天气真不错! 124 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 125 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 126 | 领导的按排,我坚决服从 领导的安排,我坚决服从 127 | 今天的天气真错! 今天的天气真不错! 128 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 129 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 130 | 领导的按排,我坚决服从 领导的安排,我坚决服从 131 | 今天的天气真错! 今天的天气真不错! 132 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 133 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 134 | 领导的按排,我坚决服从 领导的安排,我坚决服从 135 | 今天的天气真错! 今天的天气真不错! 136 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 137 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 138 | 领导的按排,我坚决服从 领导的安排,我坚决服从 139 | 今天的天气真错! 今天的天气真不错! 140 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 141 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 142 | 领导的按排,我坚决服从 领导的安排,我坚决服从 143 | 今天的天气真错! 今天的天气真不错! 144 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 145 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 146 | 领导的按排,我坚决服从 领导的安排,我坚决服从 147 | 今天的天气真错! 今天的天气真不错! 148 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 149 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 150 | 领导的按排,我坚决服从 领导的安排,我坚决服从 151 | 今天的天气真错! 今天的天气真不错! 152 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 153 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 154 | 领导的按排,我坚决服从 领导的安排,我坚决服从 155 | 今天的天气真错! 今天的天气真不错! 156 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 157 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 158 | 领导的按排,我坚决服从 领导的安排,我坚决服从 159 | 今天的天气真错! 今天的天气真不错! 160 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 161 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 162 | 领导的按排,我坚决服从 领导的安排,我坚决服从 163 | 今天的天气真错! 今天的天气真不错! 164 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 165 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 166 | 领导的按排,我坚决服从 领导的安排,我坚决服从 167 | 今天的天气真错! 今天的天气真不错! 168 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 169 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 170 | 领导的按排,我坚决服从 领导的安排,我坚决服从 171 | 今天的天气真错! 今天的天气真不错! 172 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 173 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 174 | 领导的按排,我坚决服从 领导的安排,我坚决服从 175 | 今天的天气真错! 今天的天气真不错! 176 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 177 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 178 | 领导的按排,我坚决服从 领导的安排,我坚决服从 179 | 今天的天气真错! 今天的天气真不错! 180 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 181 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 182 | 领导的按排,我坚决服从 领导的安排,我坚决服从 183 | 今天的天气真错! 今天的天气真不错! 184 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 185 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 186 | 领导的按排,我坚决服从 领导的安排,我坚决服从 187 | 今天的天气真错! 今天的天气真不错! 188 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 189 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 190 | 领导的按排,我坚决服从 领导的安排,我坚决服从 191 | 今天的天气真错! 今天的天气真不错! 192 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 193 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 194 | 领导的按排,我坚决服从 领导的安排,我坚决服从 195 | 今天的天气真错! 今天的天气真不错! 196 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 197 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 198 | 领导的按排,我坚决服从 领导的安排,我坚决服从 199 | 今天的天气真错! 今天的天气真不错! 200 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 201 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 202 | 领导的按排,我坚决服从 领导的安排,我坚决服从 203 | 今天的天气真错! 今天的天气真不错! 204 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 205 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 206 | 领导的按排,我坚决服从 领导的安排,我坚决服从 207 | 今天的天气真错! 今天的天气真不错! 208 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 209 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 210 | 领导的按排,我坚决服从 领导的安排,我坚决服从 211 | 今天的天气真错! 今天的天气真不错! 212 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 213 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 214 | 领导的按排,我坚决服从 领导的安排,我坚决服从 215 | 今天的天气真错! 今天的天气真不错! 216 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 217 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 218 | 领导的按排,我坚决服从 领导的安排,我坚决服从 219 | 今天的天气真错! 今天的天气真不错! 220 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 221 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 222 | 领导的按排,我坚决服从 领导的安排,我坚决服从 223 | 今天的天气真错! 今天的天气真不错! 224 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 225 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 226 | 领导的按排,我坚决服从 领导的安排,我坚决服从 227 | 今天的天气真错! 
今天的天气真不错! 228 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 229 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 230 | 领导的按排,我坚决服从 领导的安排,我坚决服从 231 | 今天的天气真错! 今天的天气真不错! 232 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 233 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 234 | 领导的按排,我坚决服从 领导的安排,我坚决服从 235 | 今天的天气真错! 今天的天气真不错! 236 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 237 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 238 | 领导的按排,我坚决服从 领导的安排,我坚决服从 239 | 今天的天气真错! 今天的天气真不错! 240 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 241 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 242 | 领导的按排,我坚决服从 领导的安排,我坚决服从 243 | 今天的天气真错! 今天的天气真不错! 244 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 245 | 在公共场所要自觉遵守次序。 在公共场所要自觉遵守秩序。 246 | 领导的按排,我坚决服从 领导的安排,我坚决服从 247 | 今天的天气真错! 今天的天气真不错! 248 | 张明拾金不昧的事迹传遍全校全校。 张明拾金不昧的事迹传遍全校。 -------------------------------------------------------------------------------- /data/example_input.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "source": "领导的按排,我坚决服从", 4 | "id": 1 5 | }, 6 | { 7 | "source": "今天的天气真错!", 8 | "id": 2 9 | } 10 | ] -------------------------------------------------------------------------------- /data/example_output.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "inference": "领导的安排,我坚决服从", 4 | "id": 1 5 | }, 6 | { 7 | "inference": "今天的天气真不错!", 8 | "id": 2 9 | } 10 | ] -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bitallin/MiduCTC-competition/f224bce5806ebc3e2af1c3bfb082e4ba8a357e47/logs/.gitkeep -------------------------------------------------------------------------------- /model/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bitallin/MiduCTC-competition/f224bce5806ebc3e2af1c3bfb082e4ba8a357e47/model/.gitkeep -------------------------------------------------------------------------------- /pretrained_model/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bitallin/MiduCTC-competition/f224bce5806ebc3e2af1c3bfb082e4ba8a357e47/pretrained_model/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | auto_argparse==0.0.7 2 | numpy==1.19.5 3 | rich==12.3.0 4 | torch==1.9.0+cu111 5 | transformers==4.6.0 6 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | from logging.handlers import TimedRotatingFileHandler 5 | 6 | 7 | def setup_log(log_name): 8 | logger = logging.getLogger(log_name) 9 | log_path = os.path.join("logs", log_name) 10 | logger.setLevel(logging.DEBUG) 11 | file_handler = TimedRotatingFileHandler( 12 | filename=log_path, when="MIDNIGHT", interval=1, backupCount=30 13 | ) 14 | file_handler.suffix = "%Y-%m-%d.log" 15 | file_handler.extMatch = re.compile(r"^\d{4}-\d{2}-\d{2}.log$") 16 | stream_handler = logging.StreamHandler() 17 | formatter = logging.Formatter( 18 | "[%(asctime)s] [%(process)d] [%(levelname)s] - %(module)s.%(funcName)s (%(filename)s:%(lineno)d) - %(message)s" 19 | ) 20 | 21 | stream_handler.setFormatter(file_handler) 22 | file_handler.setFormatter( 23 | formatter 24 | ) 25 | logger.addHandler(stream_handler) 26 | 
logger.addHandler(file_handler) 27 | return logger 28 | 29 | 30 | logger = setup_log("ctc.log") 31 | -------------------------------------------------------------------------------- /src/baseline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bitallin/MiduCTC-competition/f224bce5806ebc3e2af1c3bfb082e4ba8a357e47/src/baseline/__init__.py -------------------------------------------------------------------------------- /src/baseline/ctc_vocab/config.py: -------------------------------------------------------------------------------- 1 | class VocabConf: 2 | detect_vocab_size = 2 3 | correct_vocab_size = 20675 -------------------------------------------------------------------------------- /src/baseline/ctc_vocab/ctc_detect_tags.txt: -------------------------------------------------------------------------------- 1 | $KEEP 2 | $ERROR -------------------------------------------------------------------------------- /src/baseline/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from difflib import SequenceMatcher 3 | from typing import Dict, List 4 | 5 | import torch 6 | from src import logger 7 | from src.baseline.tokenizer import CtcTokenizer 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class DatasetCTC(Dataset): 12 | 13 | def __init__(self, 14 | in_model_dir: str, 15 | src_texts: List[str], 16 | trg_texts: List[str], 17 | max_seq_len: int = 128, 18 | ctc_label_vocab_dir: str = 'src/baseline/ctc_vocab', 19 | _loss_ignore_id: int = -100 20 | ): 21 | """ 22 | :param data_dir: 数据集txt文件目录: 例如 data/train or data/dev 23 | :param tokenizer: 24 | :param ctc_label_vocab_dir: ctc任务的文件夹路径, 包含d_tags.txt和labels.txt 25 | :param keep_one_append:多个操作型label保留一个 26 | """ 27 | super(DatasetCTC, self).__init__() 28 | 29 | assert len(src_texts) == len( 30 | trg_texts), 'keep equal length between srcs and trgs' 31 | self.src_texts = src_texts 32 | self.trg_texts = trg_texts 33 | self.tokenizer = CtcTokenizer.from_pretrained(in_model_dir) 34 | self.max_seq_len = max_seq_len 35 | self.id2dtag, self.dtag2id, self.id2ctag, self.ctag2id = self.load_label_dict( 36 | ctc_label_vocab_dir) 37 | 38 | self.dtag_num = len(self.dtag2id) 39 | 40 | # 检测标签 41 | self._keep_d_tag_id, self._error_d_tag_id = self.dtag2id['$KEEP'], self.dtag2id['$ERROR'] 42 | # 纠错标签 43 | self._keep_c_tag_id = self.ctag2id['$KEEP'] 44 | self._delete_c_tag_id = self.ctag2id['$DELETE'] 45 | self.replace_unk_c_tag_id = self.ctag2id['[REPLACE_UNK]'] 46 | self.append_unk_c_tag_id = self.ctag2id['[APPEND_UNK]'] 47 | 48 | # voab id 49 | try: 50 | self._start_vocab_id = self.tokenizer.vocab['[START]'] 51 | except KeyError: 52 | self._start_vocab_id = self.tokenizer.vocab['[unused1]'] 53 | # loss ignore id 54 | self._loss_ignore_id = _loss_ignore_id 55 | 56 | def load_label_dict(self, ctc_label_vocab_dir: str): 57 | dtag_fp = os.path.join(ctc_label_vocab_dir, 'ctc_detect_tags.txt') 58 | ctag_fp = os.path.join(ctc_label_vocab_dir, 'ctc_correct_tags.txt') 59 | 60 | id2dtag = [line.strip() for line in open(dtag_fp, encoding='utf8')] 61 | d_tag2id = {v: i for i, v in enumerate(id2dtag)} 62 | 63 | id2ctag = [line.strip() for line in open(ctag_fp, encoding='utf8')] 64 | c_tag2id = {v: i for i, v in enumerate(id2ctag)} 65 | logger.info('d_tag num: {}, d_tags:{}'.format(len(id2dtag), d_tag2id)) 66 | return id2dtag, d_tag2id, id2ctag, c_tag2id 67 | 68 | @staticmethod 69 | def match_ctc_idx(src_text: str, trg_text: str): 70 
| replace_idx_list, delete_idx_list, missing_idx_list = [], [], [] 71 | r = SequenceMatcher(None, src_text, trg_text) 72 | diffs = r.get_opcodes() 73 | 74 | for diff in diffs: 75 | tag, i1, i2, j1, j2 = diff 76 | if tag == 'replace' and i2-i1 == j2-j1: 77 | replace_idx_list += [(i, '$REPLACE_'+trg_text[j]) 78 | for i, j in zip(range(i1, i2), range(j1, j2))] 79 | elif tag == 'insert' and j2-j1 == 1: 80 | missing_idx_list.append((i1-1, '$APPEND_'+trg_text[j1])) 81 | elif tag == 'delete': 82 | # 叠字叠词删除后面的 83 | redundant_length = i2-i1 84 | post_i1, post_i2 = i1+redundant_length, i2+redundant_length 85 | if src_text[i1:i2] == src_text[post_i1:post_i2]: 86 | i1, i2 = post_i1, post_i2 87 | for i in range(i1, i2): 88 | delete_idx_list.append(i) 89 | 90 | return replace_idx_list, delete_idx_list, missing_idx_list 91 | 92 | def __getitem__(self, item: int) -> Dict[str, torch.Tensor]: 93 | src, trg = self.src_texts[item], self.trg_texts[item] 94 | inputs = self.parse_item(src, trg) 95 | return_dict = { 96 | 'input_ids': torch.LongTensor(inputs['input_ids']), 97 | 'attention_mask': torch.LongTensor(inputs['attention_mask']), 98 | 'token_type_ids': torch.LongTensor(inputs['token_type_ids']), 99 | 'd_tags': torch.LongTensor(inputs['d_tags']), 100 | 'c_tags': torch.LongTensor(inputs['c_tags']) 101 | } 102 | return return_dict 103 | 104 | def __len__(self) -> int: 105 | return len(self.src_texts) 106 | 107 | def convert_ids_to_ctags(self, ctag_id_list: list) -> list: 108 | "id to correct tag" 109 | return [self.id2ctag[i] if i != self._loss_ignore_id else self._loss_ignore_id for i in ctag_id_list] 110 | 111 | def convert_ids_to_dtags(self, dtag_id_list: list) -> list: 112 | "id to detect tag" 113 | return [self.id2dtag[i] if i != self._loss_ignore_id else self._loss_ignore_id for i in dtag_id_list] 114 | 115 | def parse_item(self, src, trg): 116 | """[summary] 117 | 118 | Args: 119 | src ([type]): text 120 | redundant_marks ([type]): [(1,2), (5,6)] 121 | 122 | Returns: 123 | [type]: [description] 124 | """ 125 | if src and len(src) < 3: 126 | trg = src 127 | 128 | src, trg = '始' + src, '始'+trg 129 | src, trg = src[:self.max_seq_len - 2], trg[:self.max_seq_len - 2] 130 | inputs = self.tokenizer(src, 131 | max_len=self.max_seq_len, 132 | return_batch=False) 133 | inputs['input_ids'][1] = self._start_vocab_id # 把 始 换成 [START] 134 | replace_idx_list, delete_idx_list, missing_idx_list = self.match_ctc_idx( 135 | src, trg) 136 | 137 | # --- 对所有 token 计算loss --- 138 | src_len = len(src) 139 | ignore_loss_seq_len = self.max_seq_len-(src_len+1) # ex sep and pad 140 | # 先默认给keep,会面对有错误标签的进行修改 141 | d_tags = [self._loss_ignore_id] + [self._keep_d_tag_id] * \ 142 | src_len + [self._loss_ignore_id] * ignore_loss_seq_len 143 | c_tags = [self._loss_ignore_id] + [self._keep_c_tag_id] * \ 144 | src_len + [self._loss_ignore_id] * ignore_loss_seq_len 145 | 146 | for (replace_idx, replace_char) in replace_idx_list: 147 | # +1 是因为input id的第一个token是cls 148 | d_tags[replace_idx+1] = self._error_d_tag_id 149 | c_tags[replace_idx + 150 | 1] = self.ctag2id.get(replace_char, self.replace_unk_c_tag_id) 151 | 152 | for delete_idx in delete_idx_list: 153 | d_tags[delete_idx+1] = self._error_d_tag_id 154 | c_tags[delete_idx+1] = self._delete_c_tag_id 155 | 156 | for (miss_idx, miss_char) in missing_idx_list: 157 | d_tags[miss_idx + 1] = self._error_d_tag_id 158 | c_tags[miss_idx + 159 | 1] = self.ctag2id.get(miss_char, self.append_unk_c_tag_id) 160 | 161 | inputs['d_tags'] = d_tags 162 | inputs['c_tags'] = c_tags 163 | return inputs 
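
# A minimal usage sketch of match_ctc_idx, assuming the dependencies from
# requirements.txt are installed and the module is run from the repo root
# (e.g. `python -m src.baseline.dataset`): it shows how a (source, target)
# pair is mapped onto GECToR-style edit labels with 0-based source indices.
if __name__ == '__main__':
    src_demo = '领导的按排,我坚决服从'   # "按排" should be "安排"
    trg_demo = '领导的安排,我坚决服从'
    replace_idx, delete_idx, append_idx = DatasetCTC.match_ctc_idx(src_demo, trg_demo)
    print(replace_idx)  # expected: [(3, '$REPLACE_安')] -> replace the character at index 3
    print(delete_idx)   # expected: [] -> no redundant characters to delete
    print(append_idx)   # expected: [] -> no missing characters to append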
-------------------------------------------------------------------------------- /src/baseline/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LabelSmoothingLoss(torch.nn.Module): 5 | """formula 6 | loss= { 7 | (1-smoothing) * logP(x), if (x==y) 8 | (smoothing) / (num_classes-1) * logP(x), if (x!=y) 9 | } 10 | Args: 11 | torch (_type_): _description_ 12 | """ 13 | def __init__(self, smoothing:float=0.1, reduction:str='mean', ignore_index:int=-100): 14 | assert reduction in ('mean', 'sum', 'none') 15 | super(LabelSmoothingLoss, self).__init__() 16 | self.confidence = 1.0 - smoothing 17 | self.smoothing = smoothing 18 | self._reduction = reduction 19 | self._ignore_index = ignore_index 20 | 21 | def forward(self, pred:torch.Tensor, target:torch.Tensor): 22 | num_classes = pred.size()[-1] 23 | pred = pred.log_softmax(dim=-1) 24 | 25 | pred = pred[target != self._ignore_index] 26 | target = target[target != self._ignore_index] 27 | 28 | new_target = torch.zeros_like(pred) 29 | new_target.fill_(value=self.smoothing / (num_classes - 1)) 30 | new_target.scatter_(dim=1, index=target.data.unsqueeze(1), value=self.confidence) 31 | loss = -new_target * pred 32 | if self._reduction == 'mean': 33 | return torch.mean(torch.sum(loss, -1)) 34 | elif self._reduction == 'sum': 35 | return torch.sum(loss, -1) 36 | elif self._reduction == 'none': 37 | return loss 38 | 39 | -------------------------------------------------------------------------------- /src/baseline/modeling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | from src.baseline.ctc_vocab.config import VocabConf 5 | from src.baseline.loss import LabelSmoothingLoss 6 | from transformers.models.bert import BertModel, BertPreTrainedModel 7 | from torch.nn import CrossEntropyLoss 8 | 9 | class ModelingCtcBert(BertPreTrainedModel): 10 | 11 | def __init__(self, config): 12 | super().__init__(config) 13 | self.config = config 14 | self.bert = BertModel(config) 15 | self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) 16 | self.tag_detect_projection_layer = torch.nn.Linear( 17 | config.hidden_size, VocabConf.detect_vocab_size) 18 | self.tag_label_projection_layer = torch.nn.Linear( 19 | config.hidden_size, VocabConf.correct_vocab_size) 20 | self.init_weights() 21 | self._detect_criterion = CrossEntropyLoss(ignore_index=-100) 22 | self._correct_criterion = LabelSmoothingLoss(smoothing=0.1, ignore_index=-100) 23 | 24 | @staticmethod 25 | def build_dummpy_inputs(): 26 | inputs = {} 27 | inputs['input_ids'] = torch.LongTensor( 28 | torch.randint(low=1, high=10, size=(8, 56))) 29 | inputs['attention_mask'] = torch.ones(size=(8, 56)).long() 30 | inputs['token_type_ids'] = torch.zeros(size=(8, 56)).long() 31 | inputs['detect_labels'] = torch.zeros(size=(8, 56)).long() 32 | inputs['correct_labels'] = torch.zeros(size=(8, 56)).long() 33 | return inputs 34 | 35 | def forward( 36 | self, 37 | input_ids=None, 38 | attention_mask=None, 39 | token_type_ids=None, 40 | detect_labels=None, 41 | correct_labels=None 42 | ): 43 | 44 | hidden_states = self.bert( 45 | input_ids=input_ids, 46 | attention_mask=attention_mask, 47 | token_type_ids=token_type_ids)[0] 48 | detect_outputs = self.tag_detect_projection_layer(hidden_states) 49 | correct_output = self.tag_label_projection_layer(hidden_states) 50 | 51 | loss = None 52 | if detect_labels is not None and correct_labels is not None: 
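            # Joint objective: the detection head is trained with plain cross-entropy
            # and the correction head with the label-smoothed loss; the two are summed.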
53 | 54 | loss = self._detect_criterion( 55 | detect_outputs.view(-1, VocabConf.detect_vocab_size), detect_labels.view(-1)) + self._correct_criterion( 56 | correct_output.view(-1, VocabConf.correct_vocab_size), correct_labels.view(-1)) 57 | elif detect_labels is not None: 58 | loss = self._detect_criterion( 59 | detect_outputs.view(-1, VocabConf.detect_vocab_size), detect_labels.view(-1)) 60 | elif correct_labels is not None: 61 | loss = self._correct_criterion( 62 | correct_output.view(-1, VocabConf.correct_vocab_size), correct_labels.view(-1)) 63 | 64 | return detect_outputs, correct_output, loss -------------------------------------------------------------------------------- /src/baseline/predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import os 6 | 7 | import torch 8 | from src import logger 9 | from src.baseline.modeling import ModelingCtcBert 10 | from src.baseline.tokenizer import CtcTokenizer 11 | 12 | 13 | class PredictorCtc: 14 | def __init__( 15 | self, 16 | in_model_dir, 17 | ctc_label_vocab_dir='src/baseline/ctc_vocab', 18 | use_cuda=True, 19 | cuda_id=None, 20 | ): 21 | 22 | self.in_model_dir = in_model_dir 23 | self.model = ModelingCtcBert.from_pretrained( 24 | in_model_dir) 25 | self._id2dtag, self._dtag2id, self._id2ctag, self._ctag2id = self.load_label_dict( 26 | ctc_label_vocab_dir) 27 | logger.info('model loaded from dir {}'.format( 28 | self.in_model_dir)) 29 | self.use_cuda = use_cuda 30 | if self.use_cuda and torch.cuda.is_available(): 31 | if cuda_id is not None: 32 | torch.cuda.set_device(cuda_id) 33 | self.model.cuda() 34 | self.model.half() 35 | self.model.eval() 36 | self.tokenizer = CtcTokenizer.from_pretrained(in_model_dir) 37 | 38 | try: 39 | self._start_vocab_id = self.tokenizer.vocab['[START]'] 40 | except KeyError: 41 | self._start_vocab_id = self.tokenizer.vocab['[unused1]'] 42 | 43 | def load_label_dict(self, ctc_label_vocab_dir): 44 | dtag_fp = os.path.join(ctc_label_vocab_dir, 'ctc_detect_tags.txt') 45 | ctag_fp = os.path.join(ctc_label_vocab_dir, 'ctc_correct_tags.txt') 46 | 47 | id2dtag = [line.strip() for line in open(dtag_fp, encoding='utf8')] 48 | d_tag2id = {v: i for i, v in enumerate(id2dtag)} 49 | 50 | id2ctag = [line.strip() for line in open(ctag_fp, encoding='utf8')] 51 | c_tag2id = {v: i for i, v in enumerate(id2ctag)} 52 | logger.info('d_tag num: {}, d_tags:{}'.format(len(id2dtag), d_tag2id)) 53 | return id2dtag, d_tag2id, id2ctag, c_tag2id 54 | 55 | def id_list2ctag_list(self, id_list) -> list: 56 | 57 | return [self._id2ctag[i] for i in id_list] 58 | 59 | @torch.no_grad() 60 | def predict(self, texts, return_topk=1, batch_size=32): 61 | if isinstance(texts, str): 62 | texts = [texts] 63 | else: 64 | texts = texts 65 | outputs = [] 66 | for start_idx in range(0, len(texts), batch_size): 67 | batch_texts = texts[start_idx:start_idx+batch_size] 68 | 69 | batch_texts = [' ' + t for t in batch_texts] # 开头加一个占位符 70 | inputs = self.tokenizer(batch_texts, 71 | return_tensors='pt') 72 | # 把 ' ' 换成 _start_vocab_id 73 | inputs['input_ids'][..., 1] = self._start_vocab_id 74 | if self.use_cuda and torch.cuda.is_available(): 75 | inputs['input_ids'] = inputs['input_ids'].cuda() 76 | inputs['attention_mask'] = inputs['attention_mask'].cuda() 77 | inputs['token_type_ids'] = inputs['token_type_ids'].cuda() 78 | 79 | d_preds, preds, loss = self.model( 80 | input_ids=inputs['input_ids'], 81 | attention_mask=inputs['attention_mask'], 82 | 
token_type_ids=inputs['token_type_ids'], 83 | ) 84 | 85 | preds = torch.softmax(preds[:, 1:, :], dim=-1) # 从cls后面开始 86 | recall_top_k_probs, recall_top_k_ids = preds.topk( 87 | k=return_topk, dim=-1, largest=True, sorted=True) 88 | recall_top_k_probs = recall_top_k_probs.tolist() 89 | recall_top_k_ids = recall_top_k_ids.tolist() 90 | recall_top_k_chars = [[self.id_list2ctag_list( 91 | char_level) for char_level in sent_level] for sent_level in recall_top_k_ids] 92 | batch_texts = [['']+list(t)[1:] for t in batch_texts] # 占位符 93 | batch_outputs = [list(zip(text, top_k_char, top_k_prob)) for text, top_k_char, top_k_prob in zip( 94 | batch_texts, recall_top_k_chars, recall_top_k_probs)] 95 | outputs.extend(batch_outputs) 96 | return outputs 97 | 98 | @staticmethod 99 | def output2text(output): 100 | 101 | pred_text = '' 102 | for src_token, pred_token_list, pred_prob_list in output: 103 | pred_token = pred_token_list[0] 104 | if '$KEEP' in pred_token: 105 | pred_text += src_token 106 | elif '$DELETE' in pred_token: 107 | continue 108 | elif '$REPLACE' in pred_token: 109 | pred_text += pred_token.split('_')[-1] 110 | elif '$APPEND' in pred_token: 111 | pred_text += src_token+pred_token.split('_')[-1] 112 | 113 | return pred_text -------------------------------------------------------------------------------- /src/baseline/tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
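# Adapted from the HuggingFace BERT tokenizer sources; CtcTokenizer below adds a
# character-level __call__ that looks each character up in the vocab directly and
# pads every batch to its longest text plus the [CLS]/[SEP] positions.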
15 | """Tokenization classes for Bert.""" 16 | 17 | 18 | import collections 19 | import os 20 | import unicodedata 21 | from typing import List, Optional, Tuple 22 | 23 | import torch 24 | from src import logger 25 | from transformers.tokenization_utils import (PreTrainedTokenizer, _is_control, 26 | _is_punctuation, _is_whitespace) 27 | 28 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 29 | 30 | PRETRAINED_VOCAB_FILES_MAP = { 31 | "vocab_file": { 32 | "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", 33 | "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt", 34 | "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt", 35 | "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt", 36 | "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt", 37 | "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt", 38 | "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt", 39 | "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt", 40 | "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt", 41 | "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt", 42 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", 43 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", 44 | "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt", 45 | "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt", 46 | "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt", 47 | "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt", 48 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt", 49 | "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt", 50 | } 51 | } 52 | 53 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 54 | "bert-base-uncased": 512, 55 | "bert-large-uncased": 512, 56 | "bert-base-cased": 512, 57 | "bert-large-cased": 512, 58 | "bert-base-multilingual-uncased": 512, 59 | "bert-base-multilingual-cased": 512, 60 | "bert-base-chinese": 512, 61 | "bert-base-german-cased": 512, 62 | "bert-large-uncased-whole-word-masking": 512, 63 | "bert-large-cased-whole-word-masking": 512, 64 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512, 65 | "bert-large-cased-whole-word-masking-finetuned-squad": 512, 66 | "bert-base-cased-finetuned-mrpc": 512, 67 | "bert-base-german-dbmdz-cased": 512, 68 | "bert-base-german-dbmdz-uncased": 512, 69 | "TurkuNLP/bert-base-finnish-cased-v1": 512, 70 | "TurkuNLP/bert-base-finnish-uncased-v1": 512, 71 | "wietsedv/bert-base-dutch-cased": 512, 72 | } 73 | 74 | PRETRAINED_INIT_CONFIGURATION = { 75 | "bert-base-uncased": {"do_lower_case": True}, 76 | "bert-large-uncased": 
{"do_lower_case": True}, 77 | "bert-base-cased": {"do_lower_case": False}, 78 | "bert-large-cased": {"do_lower_case": False}, 79 | "bert-base-multilingual-uncased": {"do_lower_case": True}, 80 | "bert-base-multilingual-cased": {"do_lower_case": False}, 81 | "bert-base-chinese": {"do_lower_case": False}, 82 | "bert-base-german-cased": {"do_lower_case": False}, 83 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, 84 | "bert-large-cased-whole-word-masking": {"do_lower_case": False}, 85 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, 86 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, 87 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, 88 | "bert-base-german-dbmdz-cased": {"do_lower_case": False}, 89 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, 90 | "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, 91 | "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, 92 | "wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, 93 | } 94 | 95 | 96 | def load_vocab(vocab_file): 97 | """Loads a vocabulary file into a dictionary.""" 98 | vocab = collections.OrderedDict() 99 | with open(vocab_file, "r", encoding="utf-8") as reader: 100 | tokens = reader.readlines() 101 | for index, token in enumerate(tokens): 102 | token = token.rstrip("\n") 103 | vocab[token] = index 104 | return vocab 105 | 106 | 107 | def whitespace_tokenize(text): 108 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 109 | text = text.strip() 110 | if not text: 111 | return [] 112 | tokens = text.split() 113 | return tokens 114 | 115 | 116 | class CtcTokenizer(PreTrainedTokenizer): 117 | r""" 118 | char-level tokenizer for chinese text correction 119 | 120 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. 121 | Users should refer to this superclass for more information regarding those methods. 122 | 123 | Args: 124 | vocab_file (:obj:`str`): 125 | File containing the vocabulary. 126 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 127 | Whether or not to lowercase the input when tokenizing. 128 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): 129 | Whether or not to do basic tokenization before WordPiece. 130 | never_split (:obj:`Iterable`, `optional`): 131 | Collection of tokens which will never be split during tokenization. Only has an effect when 132 | :obj:`do_basic_tokenize=True` 133 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): 134 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 135 | token instead. 136 | sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): 137 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 138 | sequence classification or for a text and a question for question answering. It is also used as the last 139 | token of a sequence built with special tokens. 140 | pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): 141 | The token used for padding, for example when batching sequences of different lengths. 142 | cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): 143 | The classifier token which is used when doing sequence classification (classification of the whole sequence 144 | instead of per-token classification). 
It is the first token of the sequence when built with special tokens. 145 | mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): 146 | The token used for masking values. This is the token used when training this model with masked language 147 | modeling. This is the token which the model will try to predict. 148 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 149 | Whether or not to tokenize Chinese characters. 150 | 151 | This should likely be deactivated for Japanese (see this `issue 152 | `__). 153 | strip_accents: (:obj:`bool`, `optional`): 154 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 155 | value for :obj:`lowercase` (as in the original BERT). 156 | """ 157 | 158 | vocab_files_names = VOCAB_FILES_NAMES 159 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 160 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 161 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 162 | 163 | def __init__( 164 | self, 165 | vocab_file, 166 | do_lower_case=True, 167 | do_basic_tokenize=True, 168 | never_split=None, 169 | unk_token="[UNK]", 170 | sep_token="[SEP]", 171 | pad_token="[PAD]", 172 | cls_token="[CLS]", 173 | mask_token="[MASK]", 174 | tokenize_chinese_chars=True, 175 | strip_accents=None, 176 | **kwargs 177 | ): 178 | super().__init__( 179 | do_lower_case=do_lower_case, 180 | do_basic_tokenize=do_basic_tokenize, 181 | never_split=never_split, 182 | unk_token=unk_token, 183 | sep_token=sep_token, 184 | pad_token=pad_token, 185 | cls_token=cls_token, 186 | mask_token=mask_token, 187 | tokenize_chinese_chars=tokenize_chinese_chars, 188 | strip_accents=strip_accents, 189 | **kwargs, 190 | ) 191 | 192 | if not os.path.isfile(vocab_file): 193 | raise ValueError( 194 | f"Can't find a vocabulary file at path '{vocab_file}'. 
To load the vocabulary from a Google pretrained " 195 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" 196 | ) 197 | self.vocab = load_vocab(vocab_file) 198 | self.ids_to_tokens = collections.OrderedDict( 199 | [(ids, tok) for tok, ids in self.vocab.items()]) 200 | self.do_basic_tokenize = do_basic_tokenize 201 | if do_basic_tokenize: 202 | self.basic_tokenizer = BasicTokenizer( 203 | do_lower_case=do_lower_case, 204 | never_split=never_split, 205 | tokenize_chinese_chars=tokenize_chinese_chars, 206 | strip_accents=strip_accents, 207 | ) 208 | self.wordpiece_tokenizer = WordpieceTokenizer( 209 | vocab=self.vocab, unk_token=self.unk_token) 210 | 211 | @property 212 | def do_lower_case(self): 213 | return self.basic_tokenizer.do_lower_case 214 | 215 | @property 216 | def vocab_size(self): 217 | return len(self.vocab) 218 | 219 | def get_vocab(self): 220 | return dict(self.vocab, **self.added_tokens_encoder) 221 | 222 | def _tokenize(self, text): 223 | split_tokens = [] 224 | if self.do_basic_tokenize: 225 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 226 | 227 | # If the token is part of the never_split set 228 | if token in self.basic_tokenizer.never_split: 229 | split_tokens.append(token) 230 | else: 231 | split_tokens += self.wordpiece_tokenizer.tokenize(token) 232 | else: 233 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 234 | return split_tokens 235 | 236 | def _convert_token_to_id(self, token): 237 | """Converts a token (str) in an id using the vocab.""" 238 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 239 | 240 | def _convert_id_to_token(self, index): 241 | """Converts an index (integer) in a token (str) using the vocab.""" 242 | return self.ids_to_tokens.get(index, self.unk_token) 243 | 244 | def convert_tokens_to_string(self, tokens): 245 | """Converts a sequence of tokens (string) in a single string.""" 246 | out_string = " ".join(tokens).replace(" ##", "").strip() 247 | return out_string 248 | 249 | def build_inputs_with_special_tokens( 250 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 251 | ) -> List[int]: 252 | """ 253 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 254 | adding special tokens. A BERT sequence has the following format: 255 | 256 | - single sequence: ``[CLS] X [SEP]`` 257 | - pair of sequences: ``[CLS] A [SEP] B [SEP]`` 258 | 259 | Args: 260 | token_ids_0 (:obj:`List[int]`): 261 | List of IDs to which the special tokens will be added. 262 | token_ids_1 (:obj:`List[int]`, `optional`): 263 | Optional second list of IDs for sequence pairs. 264 | 265 | Returns: 266 | :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 267 | """ 268 | if token_ids_1 is None: 269 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 270 | cls = [self.cls_token_id] 271 | sep = [self.sep_token_id] 272 | return cls + token_ids_0 + sep + token_ids_1 + sep 273 | 274 | def get_special_tokens_mask( 275 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 276 | ) -> List[int]: 277 | """ 278 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 279 | special tokens using the tokenizer ``prepare_for_model`` method. 280 | 281 | Args: 282 | token_ids_0 (:obj:`List[int]`): 283 | List of IDs. 
284 | token_ids_1 (:obj:`List[int]`, `optional`): 285 | Optional second list of IDs for sequence pairs. 286 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 287 | Whether or not the token list is already formatted with special tokens for the model. 288 | 289 | Returns: 290 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 291 | """ 292 | 293 | if already_has_special_tokens: 294 | return super().get_special_tokens_mask( 295 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True 296 | ) 297 | 298 | if token_ids_1 is not None: 299 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 300 | return [1] + ([0] * len(token_ids_0)) + [1] 301 | 302 | def create_token_type_ids_from_sequences( 303 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 304 | ) -> List[int]: 305 | """ 306 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence 307 | pair mask has the following format: 308 | 309 | :: 310 | 311 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 312 | | first sequence | second sequence | 313 | 314 | If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). 315 | 316 | Args: 317 | token_ids_0 (:obj:`List[int]`): 318 | List of IDs. 319 | token_ids_1 (:obj:`List[int]`, `optional`): 320 | Optional second list of IDs for sequence pairs. 321 | 322 | Returns: 323 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given 324 | sequence(s). 325 | """ 326 | sep = [self.sep_token_id] 327 | cls = [self.cls_token_id] 328 | if token_ids_1 is None: 329 | return len(cls + token_ids_0 + sep) * [0] 330 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 331 | 332 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 333 | index = 0 334 | if os.path.isdir(save_directory): 335 | vocab_file = os.path.join( 336 | save_directory, (filename_prefix + "-" if filename_prefix else "") + 337 | VOCAB_FILES_NAMES["vocab_file"] 338 | ) 339 | else: 340 | vocab_file = (filename_prefix + 341 | "-" if filename_prefix else "") + save_directory 342 | with open(vocab_file, "w", encoding="utf-8") as writer: 343 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 344 | if index != token_index: 345 | logger.warning( 346 | f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." 347 | " Please check that the vocabulary is not corrupted!" 
348 | ) 349 | index = token_index 350 | writer.write(token + "\n") 351 | index += 1 352 | return (vocab_file,) 353 | 354 | def __call__(self, 355 | texts, 356 | max_len=None, 357 | return_tensors=None, 358 | return_length=False, 359 | return_batch=True): 360 | "预测tokenize, 按batch texts中最大的文本长度来pad, realise只需要input id, mask, length" 361 | 362 | if isinstance(texts, str): 363 | texts = [texts] 364 | cls_id, sep_id, pad_id, unk_id = self.vocab['[CLS]'], self.vocab[ 365 | '[SEP]'], self.vocab['[PAD]'], self.vocab['[UNK]'] 366 | input_ids, attention_mask, token_type_ids, length = [], [], [], [] 367 | if max_len is None: 368 | max_len = max([len(text) for text in texts]) + 2 # 注意+2 cls, sep 369 | 370 | for text in texts: 371 | true_input_id = [self.vocab.get( 372 | c, unk_id) for c in text][:max_len-2] 373 | pad_len = (max_len-len(true_input_id)-2) 374 | input_id = [cls_id] + true_input_id + [sep_id] + [pad_id] * pad_len 375 | a_mask = [1] * (len(true_input_id) + 2) + [0] * pad_len 376 | token_type_id = [0] * max_len 377 | input_ids.append(input_id) 378 | attention_mask.append(a_mask) 379 | token_type_ids.append(token_type_id) 380 | length.append(sum(a_mask)) 381 | 382 | rtn_dict = {'input_ids': input_ids, 383 | 'attention_mask': attention_mask, 384 | 'token_type_ids': token_type_ids, 385 | } 386 | if return_length: 387 | rtn_dict['length'] = length 388 | 389 | if return_tensors == 'pt': 390 | for k, v in rtn_dict.items(): 391 | rtn_dict[k] = torch.LongTensor(v) 392 | if not return_batch: 393 | for i,v in rtn_dict.items(): 394 | rtn_dict[i] = rtn_dict[i][0] 395 | return rtn_dict 396 | 397 | 398 | class BasicTokenizer(object): 399 | """ 400 | Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). 401 | 402 | Args: 403 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 404 | Whether or not to lowercase the input when tokenizing. 405 | never_split (:obj:`Iterable`, `optional`): 406 | Collection of tokens which will never be split during tokenization. Only has an effect when 407 | :obj:`do_basic_tokenize=True` 408 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 409 | Whether or not to tokenize Chinese characters. 410 | 411 | This should likely be deactivated for Japanese (see this `issue 412 | `__). 413 | strip_accents: (:obj:`bool`, `optional`): 414 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 415 | value for :obj:`lowercase` (as in the original BERT). 416 | """ 417 | 418 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None): 419 | if never_split is None: 420 | never_split = [] 421 | self.do_lower_case = do_lower_case 422 | self.never_split = set(never_split) 423 | self.tokenize_chinese_chars = tokenize_chinese_chars 424 | self.strip_accents = strip_accents 425 | 426 | def tokenize(self, text, never_split=None): 427 | """ 428 | Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see 429 | WordPieceTokenizer. 430 | 431 | Args: 432 | **never_split**: (`optional`) list of str 433 | Kept for backward compatibility purposes. Now implemented directly at the base class level (see 434 | :func:`PreTrainedTokenizer.tokenize`) List of token not to split. 435 | """ 436 | # union() returns a new set by concatenating the two sets. 
437 | never_split = self.never_split.union( 438 | set(never_split)) if never_split else self.never_split 439 | text = self._clean_text(text) 440 | 441 | # This was added on November 1st, 2018 for the multilingual and Chinese 442 | # models. This is also applied to the English models now, but it doesn't 443 | # matter since the English models were not trained on any Chinese data 444 | # and generally don't have any Chinese data in them (there are Chinese 445 | # characters in the vocabulary because Wikipedia does have some Chinese 446 | # words in the English Wikipedia.). 447 | if self.tokenize_chinese_chars: 448 | text = self._tokenize_chinese_chars(text) 449 | orig_tokens = whitespace_tokenize(text) 450 | split_tokens = [] 451 | for token in orig_tokens: 452 | if token not in never_split: 453 | if self.do_lower_case: 454 | token = token.lower() 455 | if self.strip_accents is not False: 456 | token = self._run_strip_accents(token) 457 | elif self.strip_accents: 458 | token = self._run_strip_accents(token) 459 | split_tokens.extend(self._run_split_on_punc(token, never_split)) 460 | 461 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 462 | return output_tokens 463 | 464 | def _run_strip_accents(self, text): 465 | """Strips accents from a piece of text.""" 466 | text = unicodedata.normalize("NFD", text) 467 | output = [] 468 | for char in text: 469 | cat = unicodedata.category(char) 470 | if cat == "Mn": 471 | continue 472 | output.append(char) 473 | return "".join(output) 474 | 475 | def _run_split_on_punc(self, text, never_split=None): 476 | """Splits punctuation on a piece of text.""" 477 | if never_split is not None and text in never_split: 478 | return [text] 479 | chars = list(text) 480 | i = 0 481 | start_new_word = True 482 | output = [] 483 | while i < len(chars): 484 | char = chars[i] 485 | if _is_punctuation(char): 486 | output.append([char]) 487 | start_new_word = True 488 | else: 489 | if start_new_word: 490 | output.append([]) 491 | start_new_word = False 492 | output[-1].append(char) 493 | i += 1 494 | 495 | return ["".join(x) for x in output] 496 | 497 | def _tokenize_chinese_chars(self, text): 498 | """Adds whitespace around any CJK character.""" 499 | output = [] 500 | for char in text: 501 | cp = ord(char) 502 | if self._is_chinese_char(cp): 503 | output.append(" ") 504 | output.append(char) 505 | output.append(" ") 506 | else: 507 | output.append(char) 508 | return "".join(output) 509 | 510 | def _is_chinese_char(self, cp): 511 | """Checks whether CP is the codepoint of a CJK character.""" 512 | # This defines a "chinese character" as anything in the CJK Unicode block: 513 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 514 | # 515 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 516 | # despite its name. The modern Korean Hangul alphabet is a different block, 517 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 518 | # space-separated words, so they are not treated specially and handled 519 | # like the all of the other languages. 
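        # The ranges below cover CJK Unified Ideographs, Extensions A-E, and the two
        # CJK Compatibility Ideographs blocks.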
520 | if ( 521 | (cp >= 0x4E00 and cp <= 0x9FFF) 522 | or (cp >= 0x3400 and cp <= 0x4DBF) # 523 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 524 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 525 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 526 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 527 | or (cp >= 0xF900 and cp <= 0xFAFF) 528 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 529 | ): # 530 | return True 531 | 532 | return False 533 | 534 | def _clean_text(self, text): 535 | """Performs invalid character removal and whitespace cleanup on text.""" 536 | output = [] 537 | for char in text: 538 | cp = ord(char) 539 | if cp == 0 or cp == 0xFFFD or _is_control(char): 540 | continue 541 | if _is_whitespace(char): 542 | output.append(" ") 543 | else: 544 | output.append(char) 545 | return "".join(output) 546 | 547 | 548 | class WordpieceTokenizer(object): 549 | """Runs WordPiece tokenization.""" 550 | 551 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 552 | self.vocab = vocab 553 | self.unk_token = unk_token 554 | self.max_input_chars_per_word = max_input_chars_per_word 555 | 556 | def tokenize(self, text): 557 | """ 558 | Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform 559 | tokenization using the given vocabulary. 560 | 561 | For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`. 562 | 563 | Args: 564 | text: A single token or whitespace separated tokens. This should have 565 | already been passed through `BasicTokenizer`. 566 | 567 | Returns: 568 | A list of wordpiece tokens. 569 | """ 570 | 571 | output_tokens = [] 572 | for token in whitespace_tokenize(text): 573 | chars = list(token) 574 | if len(chars) > self.max_input_chars_per_word: 575 | output_tokens.append(self.unk_token) 576 | continue 577 | 578 | is_bad = False 579 | start = 0 580 | sub_tokens = [] 581 | while start < len(chars): 582 | end = len(chars) 583 | cur_substr = None 584 | while start < end: 585 | substr = "".join(chars[start:end]) 586 | if start > 0: 587 | substr = "##" + substr 588 | if substr in self.vocab: 589 | cur_substr = substr 590 | break 591 | end -= 1 592 | if cur_substr is None: 593 | is_bad = True 594 | break 595 | sub_tokens.append(cur_substr) 596 | start = end 597 | 598 | if is_bad: 599 | output_tokens.append(self.unk_token) 600 | else: 601 | output_tokens.extend(sub_tokens) 602 | return output_tokens 603 | 604 | 605 | if __name__ == '__main__': 606 | 607 | model_dir = 'pretrained_model/chinese-roberta-wwm-ext' 608 | tokenizer = CtcTokenizer.from_pretrained(model_dir) 609 | inputs = tokenizer(['撒打算大阿斯顿', '撒打算大'], max_len=128) 610 | print(inputs) 611 | -------------------------------------------------------------------------------- /src/baseline/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import time 6 | from math import ceil 7 | from typing import List 8 | import json 9 | import numpy as np 10 | import torch 11 | from rich.progress import track 12 | from src import logger 13 | from src.baseline.dataset import DatasetCTC 14 | from src.baseline.modeling import ModelingCtcBert 15 | from torch.cuda.amp import GradScaler 16 | from torch.cuda.amp import autocast 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | from torch.utils.data import DataLoader, DistributedSampler 19 | from transformers import AdamW 20 | from src.baseline.tokenizer import CtcTokenizer 21 | from 
typing import Optional 22 | from transformers.optimization import get_linear_schedule_with_warmup 23 | from torch.nn.utils import clip_grad_norm 24 | from src.metric import ctc_f1 25 | 26 | 27 | class TrainerCtc: 28 | def __init__(self, 29 | in_model_dir: str, 30 | out_model_dir: str, 31 | epochs: int, 32 | batch_size: int, 33 | learning_rate: float, 34 | max_seq_len: int, 35 | train_fp: str, 36 | dev_fp: str = None, 37 | test_fp: str = None, 38 | random_seed_num: int = 42, 39 | check_val_every_n_epoch: Optional[float] = 0.5, 40 | early_stop_times: Optional[int] = 100, 41 | freeze_embedding: bool = False, 42 | warmup_steps: int = -1, 43 | max_grad_norm: Optional[float] = None, 44 | dev_data_ratio: Optional[float] = 0.2, 45 | with_train_epoch_metric: bool = False, 46 | training_mode: str = 'normal', 47 | loss_ignore_id = -100, 48 | ctc_label_vocab_dir: str = 'src/baseline/ctc_vocab', 49 | amp: Optional[bool] = True, 50 | ddp_nodes_num: Optional[int] = 1, 51 | ddp_local_rank: Optional[int] = -1, 52 | **kwargs 53 | ): 54 | 55 | """ 56 | # in_model_dir 预训练模型目录 57 | # out_model_dir 输出模型目录 58 | # epochs 训练轮数 59 | # batch_size batch文本数 60 | # max_seq_len 最大句子长度 61 | # learning_rate 学习率 62 | # train_fp 训练集文件 63 | # test_fp 测试集文件 64 | # dev_data_ratio 没有验证集时,会从训练集按照比例分割出验证集 65 | # random_seed_num 随机种子 66 | # warmup_steps 预热steps 67 | # check_val_every_n_epoch 每几轮对验证集进行指标计算 68 | # training_mode 训练模式 包括 ddp,dp, normal,分别代表分布式,并行,普通训练 69 | # amp 是否开启混合精度 70 | # freeze_embedding 是否冻结bert embed层 71 | """ 72 | 73 | current_time = time.strftime("_%YY%mM%dD%HH", time.localtime()) 74 | self.in_model_dir = in_model_dir 75 | self.out_model_dir = os.path.join(out_model_dir, '')[ 76 | :-1] + current_time 77 | 78 | self.epochs = epochs 79 | self.batch_size = batch_size 80 | self.learning_rate = learning_rate 81 | self.max_seq_len = max_seq_len 82 | self.random_seed_num = random_seed_num 83 | self.freeze_embedding = freeze_embedding 84 | self.train_fp = train_fp 85 | self.dev_fp = dev_fp 86 | self.test_fp = test_fp 87 | self.ctc_label_vocab_dir = ctc_label_vocab_dir 88 | self.check_val_every_n_epoch = check_val_every_n_epoch 89 | self.early_stop_times = early_stop_times 90 | self.dev_data_ratio = dev_data_ratio 91 | self._loss_ignore_id = loss_ignore_id 92 | assert training_mode in ('normal', 'dp', 'ddp') # 普通,数据并行,分布式训练 93 | self.training_mode = training_mode 94 | self.ddp_nodes_num = ddp_nodes_num 95 | self.ddp_local_rank = int(ddp_local_rank) 96 | self.dev_data_ratio = dev_data_ratio 97 | self.amp = amp 98 | self._warmup_steps = warmup_steps 99 | self._max_grad_norm = max_grad_norm 100 | self.with_train_epoch_metric = with_train_epoch_metric 101 | if not os.path.exists(self.out_model_dir) and self.ddp_local_rank in (-1, 0): 102 | os.mkdir(self.out_model_dir) 103 | 104 | if self.amp: 105 | self.scaler = GradScaler() # auto mixed precision 106 | self.fit_seed(self.random_seed_num) 107 | self.tokenizer = CtcTokenizer.from_pretrained( 108 | self.in_model_dir) 109 | self.train_ds, self.dev_ds, self.test_ds = self.load_data() 110 | self.model, self.optimizer, self.scheduler = self.load_suite() 111 | 112 | self._id2dtag, self._dtag2id, self._id2ctag, self._ctag2id = self.load_label_vocab() 113 | 114 | self._keep_id_in_ctag = self._ctag2id['$KEEP'] 115 | 116 | @staticmethod 117 | def load_texts_from_fp(file_path): 118 | trg_texts, src_texts = [], [] 119 | 120 | if '.txt' in file_path: 121 | for line in open(file_path, 'r', encoding='utf-8'): 122 | line = line.strip().split('\t') 123 | if line: 124 | # 
需注意txt文件中src和trg前后关系 125 | src_texts.append(line[0]) 126 | trg_texts.append(line[1]) 127 | elif '.json' in file_path: 128 | json_data = json.load(open(file_path, 'r', encoding='utf-8')) 129 | for line in json_data: 130 | src_texts.append(line['source']) 131 | trg_texts.append(line['target']) 132 | 133 | 134 | return src_texts, trg_texts 135 | 136 | def load_label_vocab(self): 137 | dtag_fp = os.path.join(self.ctc_label_vocab_dir, 'ctc_detect_tags.txt') 138 | ctag_fp = os.path.join(self.ctc_label_vocab_dir, 139 | 'ctc_correct_tags.txt') 140 | 141 | id2dtag = [line.strip() for line in open(dtag_fp, encoding='utf8')] 142 | d_tag2id = {v: i for i, v in enumerate(id2dtag)} 143 | 144 | id2ctag = [line.strip() for line in open(ctag_fp, encoding='utf8')] 145 | c_tag2id = {v: i for i, v in enumerate(id2ctag)} 146 | logger.info('d_tag num: {}, d_tags:{}'.format(len(id2dtag), d_tag2id)) 147 | return id2dtag, d_tag2id, id2ctag, c_tag2id 148 | 149 | def load_data(self) -> List[DataLoader]: 150 | 151 | # 加载train-dataset 152 | train_src_texts, train_trg_texts = self.load_texts_from_fp( 153 | self.train_fp) 154 | 155 | train_ds = DatasetCTC( 156 | in_model_dir=self.in_model_dir, 157 | src_texts=train_src_texts, 158 | trg_texts=train_trg_texts, 159 | max_seq_len=self.max_seq_len, 160 | ) 161 | 162 | if self.dev_fp is not None: 163 | dev_src_texts, dev_trg_texts = self.load_texts_from_fp( 164 | self.dev_fp) 165 | dev_ds = DatasetCTC( 166 | in_model_dir=self.in_model_dir, 167 | src_texts=dev_src_texts, 168 | trg_texts=dev_trg_texts, 169 | max_seq_len=self.max_seq_len, 170 | ) 171 | else: 172 | # 如果没有dev set,则从训练集切分 173 | _dev_size = max(int(len(train_ds) * self.dev_data_ratio), 1) 174 | _train_size = len(train_ds) - _dev_size 175 | train_ds, dev_ds = torch.utils.data.random_split( 176 | train_ds, [_train_size, _dev_size]) 177 | 178 | if self.test_fp is not None: 179 | test_src_texts, test_trg_texts = self.load_texts_from_fp( 180 | self.test_fp) 181 | test_ds = DatasetCTC( 182 | in_model_dir=self.in_model_dir, 183 | src_texts=test_src_texts, 184 | trg_texts=test_trg_texts, 185 | max_seq_len=self.max_seq_len, 186 | ) 187 | 188 | else: 189 | test_ds = None 190 | 191 | self._train_size = len(train_ds) 192 | self._dev_size = len(dev_ds) 193 | self._test_size = len(test_ds) if test_ds is not None else 0 194 | 195 | self._train_steps = ceil( 196 | self._train_size / self.batch_size) # 训练总step num 197 | self._dev_steps = ceil(self._dev_size / self.batch_size) # 训练总step num 198 | self._test_steps = ceil( 199 | self._test_size / self.batch_size) # 训练总step num 200 | 201 | 202 | 203 | # 如果是分布式训练,则步数要除以总节点数 204 | self._train_steps = ceil(self._train_steps / self.ddp_nodes_num) 205 | self._dev_steps = ceil(self._dev_steps / self.ddp_nodes_num) 206 | self._test_steps = ceil(self._test_steps / self.ddp_nodes_num) 207 | 208 | self.check_val_every_n_steps = ceil( 209 | self.check_val_every_n_epoch * self._train_steps) # 每多少个step进行验证 210 | 211 | # if self.check_val_every_n_steps < 10: 212 | # self.check_val_every_n_steps = 10 213 | 214 | logger.info('_train_size:{}'.format(self._train_size)) 215 | logger.info('_dev_size:{}'.format(self._dev_size)) 216 | logger.info('_test_size:{}'.format(self._test_size)) 217 | 218 | logger.info('Total Steps of one epoch : {}'.format(self._train_steps)) 219 | logger.info('Evaluation every {} steps'.format( 220 | self.check_val_every_n_steps)) 221 | 222 | if self.ddp_local_rank != -1: 223 | # 如果使用分布式训练, 对train_ds进行DistributedSampler 224 | train_ds = torch.utils.data.dataloader.DataLoader( 225 | 
train_ds, sampler=DistributedSampler(train_ds), batch_size=self.batch_size, num_workers=8) 226 | 227 | dev_ds = torch.utils.data.dataloader.DataLoader( 228 | dev_ds, batch_size=self.batch_size, shuffle=False, num_workers=8) 229 | 230 | if test_ds is not None: 231 | test_ds = torch.utils.data.dataloader.DataLoader( 232 | test_ds, batch_size=self.batch_size, shuffle=False, num_workers=8) 233 | 234 | else: 235 | train_ds = torch.utils.data.dataloader.DataLoader( 236 | train_ds, batch_size=self.batch_size, shuffle=True, num_workers=8) 237 | dev_ds = torch.utils.data.dataloader.DataLoader( 238 | dev_ds, batch_size=self.batch_size, shuffle=False, num_workers=8) 239 | if test_ds is not None: 240 | test_ds = torch.utils.data.dataloader.DataLoader( 241 | test_ds, batch_size=self.batch_size, shuffle=False, num_workers=8) 242 | 243 | return [train_ds, dev_ds, test_ds] 244 | 245 | def load_suite(self): 246 | "model" 247 | model = ModelingCtcBert.from_pretrained( 248 | self.in_model_dir) 249 | 250 | if self.freeze_embedding: 251 | embedding_name_list = ('embeddings.word_embeddings.weight', 252 | 'embeddings.position_embeddings.weight', 253 | 'embeddings.token_type_embeddings.weight') 254 | for named_para in model.named_parameters(): 255 | named_para[1].requires_grad = False if named_para[ 256 | 0] in embedding_name_list else True 257 | 258 | "optimizer" 259 | # bert常用权重衰减 260 | model_params = list(model.named_parameters()) 261 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 262 | optimizer_grouped_parameters = [{ 263 | 'params': [ 264 | p for n, p in model_params 265 | if not any(nd in n for nd in no_decay) 266 | ], 267 | 'weight_decay': 268 | 0.01 269 | }, { 270 | 'params': 271 | [p for n, p in model_params if any(nd in n for nd in no_decay)], 272 | 'weight_decay': 273 | 0.0 274 | }] 275 | 276 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate) 277 | "scheduler" 278 | scheduler = get_linear_schedule_with_warmup( 279 | optimizer=optimizer, 280 | num_warmup_steps=self._warmup_steps, 281 | num_training_steps=len(self.train_ds)*self.epochs 282 | ) if self._warmup_steps != -1 else None 283 | return model, optimizer, scheduler 284 | 285 | def save_model(self, out_model_dir): 286 | "保存模型" 287 | model_to_save = self.model.module if hasattr( 288 | self.model, 'module') else self.model 289 | if self.ddp_local_rank in (-1, 0): 290 | if not os.path.exists(out_model_dir): 291 | 292 | os.mkdir(out_model_dir) 293 | model_to_save.save_pretrained(out_model_dir) 294 | self.tokenizer.save_pretrained(out_model_dir) 295 | logger.info('=========== New Model saved at {} ============='.format( 296 | out_model_dir)) 297 | 298 | @staticmethod 299 | def fit_seed(random_seed_num): 300 | "固定随机种子 保证每次结果一样" 301 | np.random.seed(random_seed_num) 302 | torch.manual_seed(random_seed_num) 303 | torch.cuda.manual_seed_all(random_seed_num) 304 | torch.backends.cudnn.deterministic = True 305 | 306 | def train(self): 307 | """[summary] 308 | Args: 309 | wait_cuda_memory (bool, optional): []. Defaults to False. 
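        Runs the training loop: the dev set is evaluated every `check_val_every_n_steps` steps
        (and again at each epoch boundary), a checkpoint named after the current epoch, step and
        F1 scores is written after every evaluation, and training stops early once
        `early_stop_times` consecutive evaluations fail to improve on the best dev F1.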
310 | Returns: 311 | [type]: [description] 312 | """ 313 | 314 | 315 | self.equip_cuda() 316 | best_eval_score = 0 317 | ith_early_stop_time = 0 318 | final_eval_scores_for_early_stop = [] 319 | steps = ceil(self._train_size / self.batch_size) 320 | epoch_end_flag = False # 每个epoch再验证 321 | for epoch in range(1, self.epochs + 1): 322 | if self.ddp_local_rank != -1: 323 | self.train_ds.sampler.set_epoch(epoch) 324 | if self.with_train_epoch_metric: 325 | epoch_preds, epoch_gold_labels = [], [] 326 | else: 327 | epoch_preds, epoch_gold_labels = None, None 328 | epoch_c_loss = 0 329 | for step, batch_ds in track(enumerate(self.train_ds), 330 | description='Training', 331 | total=self._train_steps): 332 | step += 1 333 | 334 | # 训练过程可能有些许数据出错,跳过 335 | try: 336 | batch_c_loss, batch_gold, batch_pred = self.train_step( 337 | batch_ds, return_for_epoch_metric=self.with_train_epoch_metric) 338 | except RuntimeError as e: 339 | logger.error('ignore training step error!!') 340 | logger.exception(e) 341 | continue 342 | 343 | if self.with_train_epoch_metric: 344 | epoch_preds += batch_pred 345 | epoch_gold_labels += batch_gold 346 | 347 | epoch_c_loss += batch_c_loss 348 | if (step % self.check_val_every_n_steps == 0 or epoch_end_flag) and self.ddp_local_rank in (-1, 0): 349 | # 到达验证步数,则开始验证,保存模型,记录最大的dev指标 350 | logger.info('[Start Evaluating]: local rank {}'.format( 351 | self.ddp_local_rank)) 352 | epoch_end_flag = False 353 | eval_epoch_c_loss, eval_epoch_precision, eval_epoch_recall, eval_epoch_f1 = self.evaluate( 354 | dataset_type='dev') 355 | log_text = '[Evaluating] Epoch {}/{}, Step {}/{}, ' \ 356 | 'epoch_c_loss:{}, epoch_precision:{}, epoch_recall:{}, epoch_f1:{}, ' 357 | logger.info( 358 | log_text.format(epoch, self.epochs, step, steps, 359 | eval_epoch_c_loss, eval_epoch_precision, 360 | eval_epoch_recall, eval_epoch_f1, 361 | )) 362 | if self.test_ds is not None: 363 | 364 | test_epoch_c_loss, test_epoch_precision, test_epoch_recall, test_epoch_f1 = self.evaluate( 365 | dataset_type='test') 366 | 367 | log_text = '[Testing] Epoch {}/{}, Step {}/{}, ' \ 368 | 'epoch_c_loss:{}, epoch_precision:{}, epoch_recall:{}, epoch_f1:{}' 369 | logger.info( 370 | log_text.format(epoch, self.epochs, step, steps, 371 | test_epoch_c_loss, test_epoch_precision, 372 | test_epoch_recall, test_epoch_f1, 373 | )) 374 | 375 | if eval_epoch_f1 >= 0: 376 | # 377 | if eval_epoch_f1 > best_eval_score: 378 | best_eval_score = eval_epoch_f1 379 | # 重置early stop次数 380 | ith_early_stop_time = 0 381 | final_eval_scores_for_early_stop = [] 382 | else: 383 | # 验证集指标在下降,记录次数,为提前结束做准备。 384 | ith_early_stop_time += 1 385 | final_eval_scores_for_early_stop.append( 386 | eval_epoch_f1) 387 | if ith_early_stop_time >= self.early_stop_times: 388 | logger.info( 389 | '[Early Stop], final eval_score:{}'.format( 390 | final_eval_scores_for_early_stop)) 391 | return 392 | if self.test_ds is not None: 393 | test_f1_str = str(round(test_epoch_f1 * 100, 394 | 2)).replace('.', '_') + '%' 395 | else: 396 | test_f1_str = 'None' 397 | dev_f1_str = str(round(eval_epoch_f1 * 100, 398 | 2)).replace('.', '_') + '%' 399 | metric_str = 'epoch{},step{},testf1_{},devf1_{}'.format(epoch, step, 400 | test_f1_str, dev_f1_str) 401 | saved_dir = os.path.join( 402 | self.out_model_dir, metric_str) 403 | if self.ddp_local_rank in (-1, 0): 404 | self.save_model(saved_dir) 405 | 406 | if eval_epoch_f1 >= 1: 407 | # 验证集指标达到100% 408 | logger.info( 409 | 'Devset f1-score has reached to 1.0, check testset f1') 410 | if self.test_ds is not None and 
test_epoch_f1>=1: 411 | logger.info( 412 | 'Testset f1-score has reached to 1.0, stop training') 413 | return 414 | 415 | if self.with_train_epoch_metric: 416 | epoch_src = [self._keep_id_in_ctag]*len(epoch_src) 417 | (d_precision, d_recall, d_f1), (c_precision, c_recall, c_f1) = ctc_f1( 418 | src_texts=[epoch_src], trg_texts=[epoch_gold_labels], pred_texts=[epoch_preds]) 419 | 420 | else: 421 | epoch_precision, epoch_recall, epoch_f1 = None, None, None 422 | 423 | if self.ddp_local_rank in (-1, 0): 424 | logger.info('Epoch End..') 425 | epoch_end_flag = True 426 | log_text = '[Training epoch] Epoch {}/{},' \ 427 | 'epoch_c_loss:{}, epoch_precision:{}, epoch_recall:{}, epoch_f1:{}' 428 | logger.info( 429 | log_text.format(epoch, self.epochs, epoch_c_loss, 430 | epoch_precision, epoch_recall, epoch_f1)) 431 | 432 | return 1 433 | 434 | def equip_cuda(self): 435 | 436 | if torch.cuda.is_available(): 437 | self.model.cuda() 438 | # self.criterion.cuda() 439 | device_count = torch.cuda.device_count() 440 | devices_ids = list(range(device_count)) 441 | if self.training_mode == 'dp' and device_count > 1: 442 | self.model = torch.nn.DataParallel(self.model, 443 | device_ids=devices_ids) 444 | logger.info('DP training, use cuda list:{}'.format( 445 | devices_ids)) 446 | elif self.ddp_local_rank != -1: 447 | self.model = DDP(self.model, device_ids=[int( 448 | self.ddp_local_rank)], output_device=int(self.ddp_local_rank), find_unused_parameters=True) 449 | logger.info('DDP training, use cuda list:{}'.format( 450 | devices_ids)) 451 | else: 452 | logger.info('Use single cuda to train') 453 | else: 454 | logger.info('Use cpu to train') 455 | 456 | def train_step(self, batch_ds, return_for_epoch_metric=True): 457 | 458 | self.model.train() 459 | 460 | if torch.cuda.is_available(): 461 | for k, v in batch_ds.items(): 462 | batch_ds[k] = v.cuda() 463 | 464 | self.optimizer.zero_grad() 465 | 466 | if self.amp and torch.cuda.is_available(): 467 | # 混合精度模式 468 | with autocast(): 469 | detect_outputs, correct_output, batch_loss = self.model( 470 | input_ids=batch_ds['input_ids'], 471 | attention_mask=batch_ds['attention_mask'], 472 | token_type_ids=batch_ds['token_type_ids'], 473 | detect_labels=batch_ds['d_tags'], 474 | correct_labels=batch_ds['c_tags'], 475 | ) 476 | batch_loss = batch_loss.mean() 477 | if self._max_grad_norm is None: 478 | self.scaler.scale(batch_loss).backward() 479 | self.scaler.step(self.optimizer) 480 | self.scaler.update() 481 | else: 482 | self.scaler.scale(batch_loss).backward() 483 | # Unscales the gradients of optimizer's assigned params in-place 484 | self.scaler.unscale_(self.optimizer) 485 | # Since the gradients of optimizer's assigned params are unscaled, clips as usual: 486 | torch.nn.utils.clip_grad_norm_( 487 | self.model.parameters(), self._max_grad_norm) 488 | # optimizer's gradients are already unscaled, so scaler.step does not unscale them, 489 | # although it still skips optimizer.step() if the gradients contain infs or NaNs. 490 | self.scaler.step(self.optimizer) 491 | 492 | # Updates the scale for next iteration. 
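                    # (Added note: GradScaler.update() shrinks the scale if this step produced
                    # inf/NaN gradients and had to be skipped, and periodically grows it again
                    # after a run of successful steps, so training recovers from overflow
                    # automatically.)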
493 | self.scaler.update() 494 | else: 495 | # 常规模式 496 | detect_outputs, correct_output, batch_loss = self.model( 497 | input_ids=batch_ds['input_ids'], 498 | attention_mask=batch_ds['attention_mask'], 499 | token_type_ids=batch_ds['token_type_ids'], 500 | detect_labels=batch_ds['d_tags'], 501 | correct_labels=batch_ds['c_tags'], 502 | ) 503 | batch_loss = batch_loss.mean() 504 | if self._max_grad_norm is None: 505 | batch_loss.backward() 506 | self.optimizer.step() 507 | else: 508 | batch_loss.backward() 509 | clip_grad_norm(self.model.parameters(), self._max_grad_norm) 510 | self.optimizer.step() 511 | 512 | # scheduler 513 | if self._warmup_steps != -1: 514 | self.scheduler.step() 515 | 516 | if return_for_epoch_metric: 517 | batch_gold = batch_ds['c_tags'].view(-1).tolist() 518 | batch_pred = torch.argmax(correct_output, 519 | dim=-1).view(-1).tolist() 520 | 521 | seq_true_idx = np.argwhere(batch_gold != self._loss_ignore_id) 522 | batch_gold = batch_gold[seq_true_idx].squeeze() 523 | batch_pred = batch_pred[seq_true_idx].squeeze() 524 | 525 | return batch_loss.item(), list(batch_gold), list(batch_pred) 526 | else: 527 | 528 | return batch_loss.item(), None, None 529 | 530 | @torch.no_grad() 531 | def evaluate(self, dataset_type='dev'): 532 | # 分布式训练时, 外层调用前会确认节点为-1,0时, 才会做验证 533 | self.model.eval() 534 | epoch_loss = 0 535 | epoch_preds, epoch_gold_labels, epoch_src = [], [], [] 536 | ds = self.test_ds if dataset_type == 'test' else self.dev_ds 537 | for batch_ds in ds: 538 | if torch.cuda.is_available(): 539 | for k, v in batch_ds.items(): 540 | batch_ds[k] = v.cuda() 541 | if self.amp and torch.cuda.is_available(): 542 | with autocast(): 543 | detect_outputs, correct_output, batch_loss = self.model( 544 | input_ids=batch_ds['input_ids'], 545 | attention_mask=batch_ds['attention_mask'], 546 | token_type_ids=batch_ds['token_type_ids'], 547 | detect_labels=batch_ds['d_tags'], 548 | correct_labels=batch_ds['c_tags'], 549 | ) 550 | else: 551 | detect_outputs, correct_output, batch_loss = self.model( 552 | input_ids=batch_ds['input_ids'], 553 | attention_mask=batch_ds['attention_mask'], 554 | token_type_ids=batch_ds['token_type_ids'], 555 | detect_labels=batch_ds['d_tags'], 556 | correct_labels=batch_ds['c_tags'], 557 | ) 558 | batch_loss = batch_loss.mean() 559 | 560 | # correct 561 | 562 | batch_gold = batch_ds['c_tags'].view(-1).cpu().numpy() 563 | batch_pred = torch.argmax(correct_output, 564 | dim=-1).view(-1).cpu().numpy() 565 | batch_src = batch_ds['input_ids'].view(-1).cpu().numpy() 566 | 567 | seq_true_idx = np.argwhere(batch_gold != self._loss_ignore_id) # 获取非pad部分的标签 568 | 569 | batch_gold = batch_gold[seq_true_idx].squeeze() 570 | batch_pred = batch_pred[seq_true_idx].squeeze() 571 | batch_src = batch_src[seq_true_idx].squeeze() 572 | 573 | epoch_src += list(batch_src) 574 | 575 | epoch_gold_labels += list(batch_gold) 576 | epoch_preds += list(batch_pred) 577 | 578 | epoch_loss += batch_loss.item() 579 | 580 | "因为输出和输入空间不一样,所以计算指标要对应输出空间,原字符对应输出空间的keep" 581 | epoch_src = [self._keep_id_in_ctag]*len(epoch_src) 582 | (d_precision, d_recall, d_f1), (c_precision, c_recall, c_f1) = ctc_f1( 583 | src_texts=[epoch_src], trg_texts=[epoch_gold_labels], pred_texts=[epoch_preds]) 584 | 585 | return epoch_loss, c_precision, c_recall, c_f1 586 | -------------------------------------------------------------------------------- /src/corrector.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import List 3 | 4 | from 
src.baseline.predictor import PredictorCtc 5 | 6 | 7 | class Corrector: 8 | def __init__(self, in_model_dir:str): 9 | """_summary_ 10 | 11 | Args: 12 | in_model_dir (str): 训练好的模型目录 13 | """ 14 | self._predictor = PredictorCtc( 15 | in_model_dir=in_model_dir, 16 | ctc_label_vocab_dir='src/baseline/ctc_vocab', 17 | use_cuda=True, 18 | cuda_id=None, 19 | ) 20 | 21 | 22 | def __call__(self, texts:List[str]) -> List[str]: 23 | pred_outputs = self._predictor.predict(texts) 24 | pred_texts = [PredictorCtc.output2text(output) for output in pred_outputs] 25 | return pred_texts 26 | 27 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | from src.corrector import Corrector 3 | from src.metric import final_f1_score 4 | 5 | 6 | def evaluate(in_model_dir,json_data_file, log_fp='logs/f1_score.log'): 7 | """输入模型目录,数据, 计算模型在该数据下的指标 8 | 9 | """ 10 | 11 | json_data = json.load(open(json_data_file, 'r', encoding='utf-8')) 12 | src_texts, trg_texts = [], [] 13 | for line in json_data: 14 | src_texts.append(line['source']) 15 | trg_texts.append(line['target']) 16 | 17 | corrector = Corrector(in_model_dir=in_model_dir) 18 | pred_texts = corrector(texts=src_texts) 19 | f1_score = final_f1_score(src_texts=src_texts, 20 | pred_texts=pred_texts, 21 | trg_texts=trg_texts, 22 | log_fp=log_fp) 23 | 24 | return f1_score -------------------------------------------------------------------------------- /src/metric.py: -------------------------------------------------------------------------------- 1 | import json 2 | from difflib import SequenceMatcher 3 | 4 | from src import logger 5 | 6 | 7 | def f1(precision, recall): 8 | if precision + recall == 0: 9 | return 0 10 | return round(2 * precision * recall / (precision + recall), 4) 11 | 12 | 13 | def compute_label_nums(src_text, trg_text, pred_text, log_error_to_fp=None): 14 | assert len(src_text) == len(trg_text) == len( 15 | pred_text), 'src_text:{}, trg_text:{}, pred_text:{}'.format(src_text, trg_text, pred_text) 16 | pred_num, detect_num, correct_num, ref_num = 0, 0, 0, 0 17 | 18 | for j in range(len(trg_text)): 19 | src_char, trg_char, pred_char = src_text[j], trg_text[j], pred_text[j] 20 | if src_char != trg_char: 21 | ref_num += 1 22 | if src_char != pred_char: 23 | detect_num += 1 24 | elif log_error_to_fp is not None and pred_char != trg_char and pred_char == src_char: 25 | log_text = '漏报\t{}\t{}\t{}\t{}\t{}\t{}\n'.format( 26 | src_text, trg_text, src_char, trg_char, pred_char, j) 27 | log_error_to_fp.write(log_text) 28 | 29 | if src_char != pred_char: 30 | pred_num += 1 31 | if pred_char == trg_char: 32 | correct_num += 1 33 | elif log_error_to_fp is not None and pred_char != trg_char and src_char == trg_char: 34 | log_text = '误报\t{}\t{}\t{}\t{}\t{}\t{}\n'.format( 35 | src_text, trg_text, src_char, trg_char, pred_char, j) 36 | log_error_to_fp.write(log_text) 37 | elif log_error_to_fp is not None and pred_char != trg_char and src_char != trg_char: 38 | log_text = '错报(检对报错)\t{}\t{}\t{}\t{}\t{}\t{}\n'.format( 39 | src_text, trg_text, src_char, trg_char, pred_char, j) 40 | log_error_to_fp.write(log_text) 41 | 42 | return (pred_num, detect_num, correct_num, ref_num) 43 | 44 | 45 | def ctc_f1(src_texts, trg_texts, pred_texts, log_error_to_fp=None): 46 | """训练过程中字级别序列标注任务的F1计算 47 | 48 | Args: 49 | src_texts ([type]): [源文本] 50 | trg_texts ([type]): [目标文本] 51 | pred_texts ([type]): [预测文本] 52 | log_error_to_fp : 文本路径 53 | 54 
| Returns: 55 | [type]: [description] 56 | """ 57 | if isinstance(src_texts, str): 58 | src_texts = [src_texts] 59 | if isinstance(trg_texts, str): 60 | trg_texts = [trg_texts] 61 | if isinstance(pred_texts, str): 62 | pred_texts = [pred_texts] 63 | lines_length = len(trg_texts) 64 | assert len(src_texts) == lines_length == len( 65 | pred_texts), 'keep equal length' 66 | all_pred_num, all_detect_num, all_correct_num, all_ref_num = 0, 0, 0, 0 67 | if log_error_to_fp is not None: 68 | f = open(log_error_to_fp, 'w', encoding='utf-8') 69 | f.write('type\tsrc_text\ttrg_text\tsrc_char\ttrg_char\tpred_char\tchar_index\n') 70 | else: 71 | f = None 72 | 73 | all_nums = [compute_label_nums(src_texts[i], trg_texts[i], pred_texts[i], f) 74 | for i in range(lines_length)] 75 | if log_error_to_fp is not None: 76 | f.close() 77 | for i in all_nums: 78 | all_pred_num += i[0] 79 | all_detect_num += i[1] 80 | all_correct_num += i[2] 81 | all_ref_num += i[3] 82 | 83 | d_precision = round(all_detect_num/all_pred_num, 84 | 4) if all_pred_num != 0 else 0 85 | d_recall = round(all_detect_num/all_ref_num, 4) if all_ref_num != 0 else 0 86 | c_precision = round(all_correct_num/all_pred_num, 87 | 4) if all_pred_num != 0 else 0 88 | c_recall = round(all_correct_num/all_ref_num, 4) if all_ref_num != 0 else 0 89 | 90 | d_f1, c_f1 = f1(d_precision, d_recall), f1(c_precision, c_recall) 91 | 92 | logger.info('====== [Char Level] ======') 93 | logger.info('d_precsion:{}%, d_recall:{}%, d_f1:{}%'.format( 94 | d_precision*100, d_recall*100, d_f1*100)) 95 | logger.info('c_precsion:{}%, c_recall:{}%, c_f1:{}%'.format( 96 | c_precision*100, c_recall*100, c_f1*100)) 97 | logger.info('error_char_num: {}'.format(all_ref_num)) 98 | return (d_precision, d_recall, d_f1), (c_precision, c_recall, c_f1) 99 | 100 | 101 | def ctc_comp_f1_sentence_level(src_texts, pred_texts, trg_texts): 102 | "计算负样本的 句子级 纠正级别 F1" 103 | correct_ref_num, correct_pred_num, correct_recall_num, correct_f1 = 0, 0, 0, 0 104 | for src_text, pred_text, trg_text in zip(src_texts, pred_texts, trg_texts): 105 | if src_text != pred_text: 106 | correct_pred_num += 1 107 | if src_text != trg_text: 108 | correct_ref_num += 1 109 | if src_text != trg_text and pred_text == trg_text: 110 | correct_recall_num += 1 111 | 112 | assert correct_ref_num > 0, '文本中未发现错误,无法计算指标,该指标只计算含有错误的样本。' 113 | 114 | correct_precision = 0 if correct_recall_num == 0 else correct_recall_num / correct_pred_num 115 | correct_recall = 0 if correct_recall_num == 0 else correct_recall_num / correct_ref_num 116 | correct_f1 = f1(correct_precision, correct_recall) 117 | 118 | return correct_precision, correct_recall, correct_f1 119 | 120 | 121 | def ctc_comp_f1_token_level(src_texts, pred_texts, trg_texts): 122 | "字级别,负样本 检测级别*0.8+纠正级别*0.2 f1" 123 | def compute_detect_correct_label_list(src_text, trg_text): 124 | detect_ref_list, correct_ref_list = [], [] 125 | diffs = SequenceMatcher(None, src_text, trg_text).get_opcodes() 126 | for (tag, src_i1, src_i2, trg_i1, trg_i2) in diffs: 127 | 128 | if tag == 'replace': 129 | assert src_i2 - src_i1 == trg_i2 - trg_i1 130 | for count, src_i in enumerate(range(src_i1, src_i2)): 131 | trg_token = trg_text[trg_i1+count] 132 | detect_ref_list.append(src_i) 133 | correct_ref_list.append((src_i, trg_token)) 134 | 135 | elif tag == 'delete': 136 | trg_token = 'D'*(src_i2-src_i1) 137 | detect_ref_list.append(src_i1) 138 | correct_ref_list.append((src_i1, trg_token)) 139 | 140 | elif tag == 'insert': 141 | trg_token = trg_text[trg_i1:trg_i2] 142 | 
detect_ref_list.append(src_i1) 143 | correct_ref_list.append((src_i1, trg_token)) 144 | 145 | return detect_ref_list, correct_ref_list 146 | 147 | # 字级别 148 | detect_ref_num, detect_pred_num, detect_recall_num, detect_f1 = 0, 0, 0, 0 149 | correct_ref_num, correct_pred_num, correct_recall_num, correct_f1 = 0, 0, 0, 0 150 | 151 | for src_text, pred_text, trg_text in zip(src_texts, pred_texts, trg_texts): 152 | # 先统计检测和纠正标签 153 | try: 154 | detect_ref_list, correct_ref_list = compute_detect_correct_label_list( 155 | src_text, trg_text) 156 | except Exception as e: 157 | # 可能Eval dataset有个别错误,暂时跳过 158 | continue 159 | try: 160 | # 处理bad case 161 | detect_pred_list, correct_pred_list = compute_detect_correct_label_list( 162 | src_text, pred_text) 163 | except Exception as e: 164 | logger.exception(e) 165 | detect_pred_list, correct_pred_list = [], [] 166 | 167 | 168 | detect_ref_num += len(detect_ref_list) 169 | detect_pred_num += len(detect_pred_list) 170 | detect_recall_num += len(set(detect_ref_list) 171 | & set(detect_pred_list)) 172 | 173 | correct_ref_num += len(correct_ref_list) 174 | correct_pred_num += len(correct_pred_list) 175 | correct_recall_num += len(set(correct_ref_list) 176 | & set(correct_pred_list)) 177 | 178 | assert correct_ref_num > 0, '文本中未发现错误,无法计算指标,该指标只计算含有错误的样本。' 179 | 180 | detect_precision = 0 if detect_pred_num == 0 else detect_recall_num / detect_pred_num 181 | detect_recall = 0 if detect_ref_num == 0 else detect_recall_num / detect_ref_num 182 | 183 | correct_precision = 0 if detect_pred_num == 0 else correct_recall_num / correct_pred_num 184 | correct_recall = 0 if detect_ref_num == 0 else correct_recall_num / correct_ref_num 185 | 186 | detect_f1 = f1(detect_precision, detect_recall) 187 | correct_f1 = f1(correct_precision, correct_recall) 188 | 189 | final_f1 = detect_f1*0.8+correct_f1*0.2 190 | 191 | return final_f1, [detect_precision, detect_recall, detect_f1], [correct_precision, correct_recall, correct_f1] 192 | 193 | 194 | def final_f1_score(src_texts, 195 | pred_texts, 196 | trg_texts, 197 | log_fp='logs/f1_score.log'): 198 | """"最终输出结果F1计算,综合了句级F1和字级F1" 199 | 200 | Args: 201 | src_texts (_type_): 源文本 202 | pred_texts (_type_): 预测文本 203 | trg_texts (_type_): 目标文本 204 | log_fp (str, optional): _description_. Defaults to 'logs/f1_score.log'. 
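        The returned score combines both granularities: final_f1 = 0.8 * token_level_f1
        + 0.2 * sentence_level_f1, where token_level_f1 itself weights detection F1 and
        correction F1 at 0.8 and 0.2.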
205 | 206 | Returns: 207 | _type_: _description_ 208 | """ 209 | 210 | 211 | 212 | token_level_f1, detect_metrics, correct_metrcis = ctc_comp_f1_token_level( 213 | src_texts, pred_texts, trg_texts) 214 | sent_level_p, sent_level_r, sent_level_f1 = ctc_comp_f1_sentence_level( 215 | src_texts, pred_texts, trg_texts) 216 | final_f1 = round(0.8*token_level_f1 + sent_level_f1*0.2, 4) 217 | 218 | json_data = { 219 | 220 | 'token_level:[detect_precision, detect_recall, detect_f1]': detect_metrics, 221 | 'token_level:[correct_precision, correct_recall, correct_f1] ': correct_metrcis, 222 | 'token_level:f1': token_level_f1, 223 | 224 | 'sentence_level:[correct_precision, correct_recall]': [sent_level_p, sent_level_r], 225 | 'sentence_level:f1': sent_level_f1, 226 | 227 | 'final_f1': final_f1 228 | } 229 | _log_fp = open(log_fp, 'w', encoding='utf-8') 230 | json.dump(json_data, _log_fp, indent=4) 231 | logger.info('final f1:{}'.format(final_f1)) 232 | logger.info('f1 logfile saved at:{}'.format(log_fp)) 233 | return final_f1 -------------------------------------------------------------------------------- /src/prepare_for_upload.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from src.corrector import Corrector 4 | 5 | 6 | def prepare_for_uploadfile(in_model_dir, 7 | in_json_file, 8 | out_json_file='data/test_output.json'): 9 | 10 | json_data_list = json.load(open(in_json_file, 'r', encoding='utf-8')) 11 | src_texts = [ json_data['source'] for json_data in json_data_list] 12 | corrector = Corrector(in_model_dir=in_model_dir) 13 | pred_texts = corrector(texts=src_texts) 14 | output_json_data = [ {'id':json_data['id'], 'inference': pred_text} for json_data, pred_text in zip(json_data_list, pred_texts)] 15 | 16 | out_json_file = open(out_json_file, 'w', encoding='utf-8') 17 | json.dump(output_json_data, out_json_file, ensure_ascii=False, indent=4) 18 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import inspect 4 | import os 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from auto_argparse import parse_args_and_run 10 | from torch.multiprocessing import spawn 11 | 12 | from src.baseline.trainer import TrainerCtc 13 | 14 | 15 | 16 | def ddp_train_wrapper(ddp_local_rank, 17 | train_kwargs 18 | ): 19 | "Distributed Data Parallel Training" 20 | # setup ddp env 21 | os.environ['MASTER_ADDR'] = 'localhost' 22 | os.environ['MASTER_PORT'] = '12355' 23 | dist.init_process_group('nccl', rank=ddp_local_rank, 24 | world_size=train_kwargs['ddp_nodes_num']) 25 | torch.cuda.set_device(ddp_local_rank) 26 | train_kwargs['ddp_local_rank'] = ddp_local_rank 27 | trainer = TrainerCtc(**train_kwargs) 28 | trainer.train() 29 | # clear ddp env 30 | dist.destroy_process_group() 31 | 32 | 33 | def train_entrance(in_model_dir: str = 'pretrained_model/chinese-roberta-wwm-ext', 34 | out_model_dir: str = 'model/ctc', 35 | epochs: int = 10, 36 | batch_size: int = 64, 37 | learning_rate: float = 5e-5, 38 | max_seq_len: int = 128, 39 | train_fp: str = 'data/example.txt', 40 | dev_fp: str = None, 41 | test_fp: str = None, 42 | random_seed_num: int = 42, 43 | check_val_every_n_epoch: Optional[float] = 0.5, 44 | early_stop_times: Optional[int] = 100, 45 | freeze_embedding: bool = False, 46 | warmup_steps: int = -1, 47 | max_grad_norm: 
Optional[float] = None, 48 | dev_data_ratio: Optional[float] = 0.2, 49 | with_train_epoch_metric: bool = False, 50 | training_mode: str = 'normal', 51 | amp: Optional[bool] = True): 52 | """_summary_ 53 | 54 | Args: 55 | # in_model_dir 预训练模型目录 56 | # out_model_dir 输出模型目录 57 | # epochs 训练轮数 58 | # batch_size batch文本数 59 | # max_seq_len 最大句子长度 60 | # learning_rate 学习率 61 | # train_fp 训练集文件 62 | # test_fp 测试集文件 63 | # dev_data_ratio 没有验证集时,会从训练集按照比例分割出验证集 64 | # random_seed_num 随机种子 65 | # check_val_every_n_epoch 每几轮对验证集进行指标计算 66 | # training_mode 训练模式 包括 ddp,dp, normal,分别代表分布式,并行,普通训练 67 | # amp 是否开启混合精度 68 | # freeze_embedding 是否冻结bert embed层 69 | """ 70 | 71 | signature = inspect.signature(train_entrance) 72 | train_kwargs = {} 73 | for param in signature.parameters.values(): 74 | train_kwargs[param.name] = eval(param.name) 75 | 76 | if training_mode in ('normal', 'dp'): 77 | trainer = TrainerCtc(**train_kwargs) 78 | trainer.train() 79 | 80 | elif training_mode == 'ddp': 81 | ddp_nodes_num = torch.cuda.device_count() 82 | train_kwargs['ddp_nodes_num'] = ddp_nodes_num 83 | spawn(ddp_train_wrapper, 84 | args=(train_kwargs,), 85 | nprocs=ddp_nodes_num, 86 | join=True) 87 | 88 | 89 | if __name__ == '__main__': 90 | parse_args_and_run(train_entrance) 91 | --------------------------------------------------------------------------------
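A minimal usage sketch, for orientation only: it shows how the pieces above are meant to be chained once a model has been trained. The checkpoint directory and the labelled json path below are placeholders; `in_model_dir` must point to a directory written by `TrainerCtc.save_model`, the json file must contain `source`/`target` fields as expected by `src/evaluate.py`, and a CUDA device is assumed because `Corrector` enables `use_cuda`. Run it from the repository root so that `src` is importable.

```
# Illustrative usage sketch, run from the repository root.
from src.corrector import Corrector
from src.evaluate import evaluate

# Placeholder: a checkpoint directory produced by TrainerCtc.save_model.
in_model_dir = 'model/ctc_train_2022Y08M01D00H/epoch3,step1000,testf1_None,devf1_35_87%'

# 1) Correct raw sentences.
corrector = Corrector(in_model_dir=in_model_dir)
print(corrector(['今天的天气真好,就是风有点大。']))

# 2) Compute the competition-style F1 on a labelled json file (placeholder path).
f1 = evaluate(in_model_dir=in_model_dir,
              json_data_file='data/comp_data/preliminary_a_data/preliminary_val.json',
              log_fp='logs/f1_score.log')
print(f1)
```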