├── asserts ├── images │ └── model.png └── MDCSpell:A Multi-task Detector-Corrector Framework for Chinese.pdf ├── README.md └── MDCSpell.ipynb /asserts/images/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iioSnail/MDCSpell_pytorch/HEAD/asserts/images/model.png -------------------------------------------------------------------------------- /asserts/MDCSpell:A Multi-task Detector-Corrector Framework for Chinese.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iioSnail/MDCSpell_pytorch/HEAD/asserts/MDCSpell:A Multi-task Detector-Corrector Framework for Chinese.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MDCSpell_pytorch 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/iioSnail/MDCSpell_pytorch/blob/main/MDCSpell.ipynb) 4 | 5 | [论文地址](https://aclanthology.org/2022.findings-acl.98/) :https://aclanthology.org/2022.findings-acl.98/ 6 | 7 | [MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction](https://aclanthology.org/2022.findings-acl.98/) 论文的**非官方**Pytorch实现。 由于作者并没有公开代码,所以我就尝试自己实现一个,最终我的实验结果如下表: 8 | 9 | | Dataset | Model | D_Precision | D_Recall | D_F1 | C_Prec | C_Rec | C_F1 | 10 | |--|--|--|--|--|--|--|--| 11 | | SIGHAN 13 | MDCSpell | 89.1 | 78.3| 83.4 | 87.5| 76.8 | 81.8 | 12 | | SIGHAN 13 | MDCSpell(复现) | 80.2 | 79.9| 80.0 | 77.2| 76.9 | 77.1 | 13 | | SIGHAN 14 | MDCSpell | 70.2 | 68.8| 69.5 | 69.0| 67.7 | 68.3 | 14 | | SIGHAN 14 | MDCSpell(复现) | 82.8 | 66.6| 73.8 | 79.9| 64.3 | 71.2 | 15 | | SIGHAN 15 | MDCSpell | 80.8 | 80.6| 80.7 | 78.4| 78.2 | 78.3 | 16 | | SIGHAN 15 | MDCSpell(复现) | 86.7 | 76.1| 81.1 | 72.5| 82.7 | 77.3 | 17 | 18 | 这里是我训练了2个epoch的结果,与作者的结论相差不大。如果我增加训练次数的话,也许可以和作者的结果达到一致。 19 | 20 | 21 | 数据集地址:[百度网盘](https://pan.baidu.com/s/1x67LPiYAjLKhO1_2CI6aOA?pwd=skda) ; [Google Drive](https://drive.google.com/drive/folders/1dC09i57lobL91lEbpebDuUBS0fGz-LAk) 22 | 23 | 我的模型:[百度网盘](https://pan.baidu.com/s/1yxuQY8V3ZrmmYcJvTRkDmQ?pwd=pffg) 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /MDCSpell.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%% md\n" 8 | } 9 | }, 10 | "source": [ 11 | "# 本文内容" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "pycharm": { 18 | "name": "#%% md\n" 19 | } 20 | }, 21 | "source": [ 22 | "本文为MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction论文的Pytorch实现。\n", 23 | "\n", 24 | "[论文地址](https://aclanthology.org/2022.findings-acl.98/): https://aclanthology.org/2022.findings-acl.98/\n", 25 | "\n", 26 | "论文年份:2022\n", 27 | "\n", 28 | "[论文笔记](https://blog.csdn.net/zhaohongfei_358/article/details/126973451):https://blog.csdn.net/zhaohongfei_358/article/details/126973451\n", 29 | "\n", 30 | "论文大致内容:作者基于Transformer和BERT设计了一个多任务的网络来进行CSC(Chinese Spell Checking)任务(中文拼写纠错)。多任务分别是找出哪个字是错的和对错字进行纠正。\n", 31 | "\n", 32 | "由于作者并没有公开代码,所以我就尝试自己实现一个,最终我的实验结果如下表:\n", 33 | "\n", 34 | "| Dataset | Model | D_Precision | D_Recall | D_F1 | C_Prec | C_Rec | C_F1 |\n", 35 | "|--|--|--|--|--|--|--|--|\n", 36 | "| SIGHAN 13 | MDCSpell | 
89.1 | 78.3| 83.4 | 87.5| 76.8 | 81.8 |\n", 37 | "| SIGHAN 13 | MDCSpell(复现) | 80.2 | 79.9| 80.0 | 77.2| 76.9 | 77.1 |\n", 38 | "| SIGHAN 14 | MDCSpell | 70.2 | 68.8| 69.5 | 69.0| 67.7 | 68.3 |\n", 39 | "| SIGHAN 14 | MDCSpell(复现) | 82.8 | 66.6| 73.8 | 79.9| 64.3 | 71.2 |\n", 40 | "| SIGHAN 15 | MDCSpell | 80.8 | 80.6| 80.7 | 78.4| 78.2 | 78.3 |\n", 41 | "| SIGHAN 15 | MDCSpell(复现) | 86.7 | 76.1| 81.1 | 72.5| 82.7 | 77.3 |\n", 42 | "\n", 43 | "这里是我训练了2个epoch的结果,与作者的结论相差不大。如果我增加训练次数的话,也许可以和作者的结果达到一致。" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "pycharm": { 50 | "name": "#%% md\n" 51 | } 52 | }, 53 | "source": [ 54 | "# 环境配置" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 1, 60 | "metadata": { 61 | "pycharm": { 62 | "name": "#%%\n" 63 | } 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "try:\n", 68 | " import transformers\n", 69 | "except:\n", 70 | " !pip install transformers" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 2, 76 | "metadata": { 77 | "pycharm": { 78 | "name": "#%%\n" 79 | } 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "import os\n", 84 | "import copy\n", 85 | "import pickle\n", 86 | "\n", 87 | "import torch\n", 88 | "import transformers\n", 89 | "\n", 90 | "from torch import nn\n", 91 | "from torch.utils.data import Dataset, DataLoader\n", 92 | "from transformers import AutoModel, AutoTokenizer\n", 93 | "from tqdm import tqdm" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 3, 99 | "metadata": { 100 | "pycharm": { 101 | "name": "#%%\n" 102 | } 103 | }, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/plain": [ 108 | "'1.12.1+cu113'" 109 | ] 110 | }, 111 | "execution_count": 3, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "torch.__version__" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 4, 123 | "metadata": { 124 | "pycharm": { 125 | "name": "#%%\n" 126 | } 127 | }, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "'4.21.3'" 133 | ] 134 | }, 135 | "execution_count": 4, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "transformers.__version__" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": { 147 | "pycharm": { 148 | "name": "#%% md\n" 149 | } 150 | }, 151 | "source": [ 152 | "# 全局变量" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 5, 158 | "metadata": { 159 | "pycharm": { 160 | "name": "#%%\n" 161 | } 162 | }, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "Device: cuda\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "# 句子的长度,作者并没有说明。我这里就按经验取一个\n", 174 | "max_length = 128\n", 175 | "# 作者使用的batch_size\n", 176 | "batch_size = 32\n", 177 | "# epoch数,作者并没有具体说明,按经验取一个\n", 178 | "epochs = 10\n", 179 | "\n", 180 | "# 每${log_after_step}步,打印一次日志\n", 181 | "log_after_step = 20\n", 182 | "\n", 183 | "# 模型存放的位置。\n", 184 | "model_path = './drive/MyDrive/models/MDCSpell/'\n", 185 | "os.makedirs(model_path, exist_ok=True)\n", 186 | "model_path = model_path + 'MDCSpell-model.pt'\n", 187 | "\n", 188 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 189 | "print(\"Device:\", device)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "# 模型构建" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": { 202 | 
"pycharm": { 203 | "name": "#%% md\n" 204 | } 205 | }, 206 | "source": [ 207 | "" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": { 213 | "pycharm": { 214 | "name": "#%% md\n" 215 | } 216 | }, 217 | "source": [ 218 | "---\n", 219 | "\n", 220 | "**Correction Network** 的数据流向如下:\n", 221 | "\n", 222 | "1.将token序列 `[CLS] 遇 到 逆 竟 [SEP]` 送给Word Embedding模块进行embeddings,得到向量$\\{e_{CLS}^w, e_1^w, e_2^w, e_3^w, e_4^w,e_{SEP}^w\\}$。\n", 223 | "\n", 224 | "> 个人认为此时的embedding仅仅是Word Embeding,并不包含Position Embedding和Segment Embedding。\n", 225 | "\n", 226 | "2.之后将$\\{e_{CLS}^w, e_1^w, e_2^w, e_3^w, e_4^w,e_{SEP}^w\\}$向量送入BERT,增加Position Embedding和Segment Embedding,得到 $\\{e_C, e_1, e_2, e_3, e_4, e_S\\}$。\n", 227 | "\n", 228 | "3.在BERT内部,会经历多层的TransformerEncoder,最终的得到输出向量 $H^c=\\{h_C^c, h_1^c, h_2^c, h_3^c, h_4^c, h_S^c\\}$.\n", 229 | "\n", 230 | "4.将BERT的输出 $H^c$ 和 隔壁Detection Network输出的 $H^d$ 进行融合,得到 $H = H^d+H^c$\n", 231 | "\n", 232 | "> 融合时并不对`[CLS]`和`[SEP]`进行融合\n", 233 | "\n", 234 | "5.将$H$送给全连接层(Dense Layer)做最后的预测。\n", 235 | "\n", 236 | "
\n", 237 | "\n", 238 | "**Correction Network模型细节**:\n", 239 | "\n", 240 | "1. **BERT**:作者使用的是具有12层Transformer Block的BERT-base版。\n", 241 | "2. **Dense Layer**:Dense Layer的输入通道为词向量维度,输出通道为词典大小。例如:词向量维度为768,词典大小为20000,则Dense Layer为`nn.Linear(768, 20000)`。\n", 242 | "3. **Dense Layer的初始化**:Dense Layer的权重使用的是Word Embedding的参数。Word Embedding负责把词index转成词向量,本质是`nn.Embedding(20000, 768)`,其权重矩阵的形状(20000, 768)恰好和`nn.Linear(768, 20000)`的权重形状一致,所以作者直接用Word Embedding的权重来初始化Dense Layer的参数(也就是原文所说的用Word Embedding的『转置』来初始化)。这样可以加速训练,且使模型变得稳定。\n", 243 | "\n", 244 | "---\n", 245 | "\n", 246 | "**Detection Network**的数据流向如下:\n", 247 | "\n", 248 | "1.输入为使用BERT得到的word Embedding $\\{e_1^w, e_2^w, e_3^w, e_4^w\\}$。虽然图里并不包含`[CLS]`和`[SEP]`的词向量,但个人认为不需要对其特殊处理,因为最后的预测也用不到这两个token。\n", 249 | "\n", 250 | "2.将$\\{e_1^w, e_2^w, e_3^w, e_4^w\\}$增加Position Embedding信息,得到$\\{e_1', e_2', e_3', e_4'\\}$\n", 251 | "\n", 252 | "> 在论文中说Detection Network使用的是向量$\\{e_1, e_2, e_3, e_4\\}$,其是word embedding+position embedding+segment embedding。这与图上是矛盾的,这里以图为准了。\n", 253 | "\n", 254 | "3.将$\\{e_1', e_2', e_3', e_4'\\}$向量送入Transformer Block,得到输出向量$H^d=\\{h_1^d, h_2^d, h_3^d, h_4^d\\}$\n", 255 | "\n", 256 | "4.一方面,将输出向量$H^d$送给隔壁的Correction Network进行融合;另一方面,将$H^d$送给后续的全连接层(Dense Layer)来判断哪个token是错误的。\n", 257 | "\n", 258 | "**Detection Network**的细节:\n", 259 | "\n", 260 | "1. **Transformer Block**:Transformer Block是2层的TransformerEncoder。\n", 261 | "2. **Transformer Block参数初始化**:Transformer Block参数初始化使用的是BERT的权重。\n", 262 | "3. **Dense Layer**:Dense Layer的输入通道为词向量大小,输出通道为1。使用Sigmoid来判别该token为错字的概率。\n" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 6, 268 | "metadata": { 269 | "pycharm": { 270 | "name": "#%%\n" 271 | } 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "class CorrectionNetwork(nn.Module):\n", 276 | "\n", 277 | "    def __init__(self):\n", 278 | "        super(CorrectionNetwork, self).__init__()\n", 279 | "        # BERT分词器,作者并没提到自己使用的是哪个中文版的bert,我这里就使用一个比较常用的\n", 280 | "        self.tokenizer = AutoTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n", 281 | "        # BERT\n", 282 | "        self.bert = AutoModel.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n", 283 | "        # BERT的word embedding,本质就是个nn.Embedding\n", 284 | "        self.word_embedding_table = self.bert.get_input_embeddings()\n", 285 | "        # 预测层。hidden_size是词向量的大小,len(self.tokenizer)是词典大小\n", 286 | "        self.dense_layer = nn.Linear(self.bert.config.hidden_size, len(self.tokenizer))\n", 287 | "\n", 288 | "    def forward(self, inputs, word_embeddings, detect_hidden_states):\n", 289 | "        \"\"\"\n", 290 | "        Correction Network的前向传递\n", 291 | "        :param inputs: inputs为tokenizer对中文文本的分词结果,\n", 292 | "                       里面包含了token对应的index、attention_mask等\n", 293 | "        :param word_embeddings: 使用BERT的word_embedding对token进行embedding后的结果\n", 294 | "        :param detect_hidden_states: Detection Network输出的hidden state\n", 295 | "        :return: Correction Network对各个token的预测结果。\n", 296 | "        \"\"\"\n", 297 | "        # 1. 使用bert进行前向传递\n", 298 | "        bert_outputs = self.bert(token_type_ids=inputs['token_type_ids'],\n", 299 | "                                 attention_mask=inputs['attention_mask'],\n", 300 | "                                 inputs_embeds=word_embeddings)\n", 301 | "        # 2. 将bert的hidden_state和Detection Network的hidden state进行融合。\n", 302 | "        hidden_states = bert_outputs['last_hidden_state'] + detect_hidden_states\n", 303 | "        # 3. 
最终使用全连接层进行token预测\n", 304 | " return self.dense_layer(hidden_states)\n", 305 | "\n", 306 | " def get_inputs_and_word_embeddings(self, sequences, max_length=128):\n", 307 | " \"\"\"\n", 308 | " 对中文序列进行分词和word embeddings处理\n", 309 | " :param sequences: 中文文本序列。例如: [\"鸡你太美\", \"哎呦,你干嘛!\"]\n", 310 | " :param max_length: 文本的最大长度,不足则进行填充,超出进行裁剪。\n", 311 | " :return: tokenizer的输出和word embeddings.\n", 312 | " \"\"\"\n", 313 | " inputs = self.tokenizer(sequences, padding='max_length', max_length=max_length, return_tensors='pt',\n", 314 | " truncation=True).to(device)\n", 315 | " # 使用BERT的work embeddings对token进行embedding,这里得到的embedding并不包含position embedding和segment embedding\n", 316 | " word_embeddings = self.word_embedding_table(inputs['input_ids'])\n", 317 | " return inputs, word_embeddings" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 7, 323 | "metadata": { 324 | "pycharm": { 325 | "name": "#%%\n" 326 | } 327 | }, 328 | "outputs": [], 329 | "source": [ 330 | "class DetectionNetwork(nn.Module):\n", 331 | "\n", 332 | " def __init__(self, position_embeddings, transformer_blocks, hidden_size):\n", 333 | " \"\"\"\n", 334 | " :param position_embeddings: bert的position_embeddings,本质是一个nn.Embedding\n", 335 | " :param transformer: BERT的前两层transformer_block,其是一个ModuleList对象\n", 336 | " \"\"\"\n", 337 | " super(DetectionNetwork, self).__init__()\n", 338 | " self.position_embeddings = position_embeddings\n", 339 | " self.transformer_blocks = transformer_blocks\n", 340 | "\n", 341 | " # 定义最后的预测层,预测哪个token是错误的\n", 342 | " self.dense_layer = nn.Sequential(\n", 343 | " nn.Linear(hidden_size, 1),\n", 344 | " nn.Sigmoid()\n", 345 | " )\n", 346 | "\n", 347 | " def forward(self, word_embeddings):\n", 348 | " # 获取token序列的长度,这里为128\n", 349 | " sequence_length = word_embeddings.size(1)\n", 350 | " # 生成position embedding\n", 351 | " position_embeddings = self.position_embeddings(torch.LongTensor(range(sequence_length)).to(device))\n", 352 | " # 融合work_embedding和position_embedding\n", 353 | " x = word_embeddings + position_embeddings\n", 354 | " # 将x一层一层的使用transformer encoder进行向后传递\n", 355 | " for transformer_layer in self.transformer_blocks:\n", 356 | " x = transformer_layer(x)[0]\n", 357 | "\n", 358 | " # 最终返回Detection Network输出的hidden states和预测结果\n", 359 | " hidden_states = x\n", 360 | " return hidden_states, self.dense_layer(hidden_states)" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 8, 366 | "metadata": { 367 | "pycharm": { 368 | "name": "#%%\n" 369 | } 370 | }, 371 | "outputs": [], 372 | "source": [ 373 | "class MDCSpellModel(nn.Module):\n", 374 | "\n", 375 | " def __init__(self):\n", 376 | " super(MDCSpellModel, self).__init__()\n", 377 | " # 构造Correction Network\n", 378 | " self.correction_network = CorrectionNetwork()\n", 379 | " self._init_correction_dense_layer()\n", 380 | "\n", 381 | " # 构造Detection Network\n", 382 | " # position embedding使用BERT的\n", 383 | " position_embeddings = self.correction_network.bert.embeddings.position_embeddings\n", 384 | " # 作者在论文中提到的,Detection Network的Transformer使用BERT的权重\n", 385 | " # 所以我这里直接克隆BERT的前两层Transformer来完成这个动作\n", 386 | " transformer = copy.deepcopy(self.correction_network.bert.encoder.layer[:2])\n", 387 | " # 提取BERT的词向量大小\n", 388 | " hidden_size = self.correction_network.bert.config.hidden_size\n", 389 | "\n", 390 | " # 构造Detection Network\n", 391 | " self.detection_network = DetectionNetwork(position_embeddings, transformer, hidden_size)\n", 392 | "\n", 393 | " def forward(self, sequences, 
max_length=128):\n", 394 | " # 先获取word embedding,Correction Network和Detection Network都要用\n", 395 | " inputs, word_embeddings = self.correction_network.get_inputs_and_word_embeddings(sequences, max_length)\n", 396 | " # Detection Network进行前向传递,获取输出的Hidden State和预测结果\n", 397 | " hidden_states, detection_outputs = self.detection_network(word_embeddings)\n", 398 | " # Correction Network进行前向传递,获取其预测结果\n", 399 | " correction_outputs = self.correction_network(inputs, word_embeddings, hidden_states)\n", 400 | " # 返回Correction Network 和 Detection Network 的预测结果。\n", 401 | " # 在计算损失时`[PAD]`token不需要参与计算,所以这里将`[PAD]`部分全都变为0\n", 402 | " return correction_outputs, detection_outputs.squeeze(2) * inputs['attention_mask']\n", 403 | "\n", 404 | " def _init_correction_dense_layer(self):\n", 405 | " \"\"\"\n", 406 | " 原论文中提到,使用Word Embedding的weight来对Correction Network进行初始化\n", 407 | " \"\"\"\n", 408 | " self.correction_network.dense_layer.weight.data = self.correction_network.word_embedding_table.weight.data" 409 | ] 410 | }, 411 | { 412 | "cell_type": "markdown", 413 | "metadata": { 414 | "pycharm": { 415 | "name": "#%% md\n" 416 | } 417 | }, 418 | "source": [ 419 | "定义好模型后,我们来简单的尝试一下:" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 9, 425 | "metadata": { 426 | "pycharm": { 427 | "name": "#%%\n" 428 | } 429 | }, 430 | "outputs": [ 431 | { 432 | "name": "stderr", 433 | "output_type": "stream", 434 | "text": [ 435 | "Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n", 436 | "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 437 | "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" 438 | ] 439 | }, 440 | { 441 | "name": "stdout", 442 | "output_type": "stream", 443 | "text": [ 444 | "correction_outputs shape: torch.Size([2, 128, 21128])\n", 445 | "detection_outputs shape: torch.Size([2, 128])\n" 446 | ] 447 | } 448 | ], 449 | "source": [ 450 | "model = MDCSpellModel().to(device)\n", 451 | "correction_outputs, detection_outputs = model([\"鸡你太美\", \"哎呦,你干嘛!\"])\n", 452 | "print(\"correction_outputs shape:\", correction_outputs.size())\n", 453 | "print(\"detection_outputs shape:\", detection_outputs.size())" 454 | ] 455 | }, 456 | { 457 | "cell_type": "markdown", 458 | "metadata": { 459 | "pycharm": { 460 | "name": "#%% md\n" 461 | } 462 | }, 463 | "source": [ 464 | "# 损失函数" 465 | ] 466 | }, 467 | { 468 | "cell_type": "markdown", 469 | "metadata": { 470 | "pycharm": { 471 | "name": "#%% md\n" 472 | } 473 | }, 474 | "source": [ 475 | "Correction Network和Detection Network使用的都是Cross Entropy。之后进行相加即可:\n", 476 | "\n", 477 | "$$\n", 478 | "L = \\lambda L^c + (1-\\lambda) L^d\n", 479 | "$$\n", 480 | "\n", 481 | "其中 $\\lambda \\in [0,1]$ 。作者通过实验得出 $\\lambda=0.85$ 时效果最好。" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": 10, 487 | "metadata": { 488 | "pycharm": { 489 | "name": "#%%\n" 490 | } 491 | }, 492 | "outputs": [], 493 | "source": [ 494 | "class MDCSpellLoss(nn.Module):\n", 495 | "\n", 496 | " def __init__(self, coefficient=0.85):\n", 497 | " super(MDCSpellLoss, self).__init__()\n", 498 | " # 定义Correction Network的Loss函数\n", 499 | " self.correction_criterion = nn.CrossEntropyLoss(ignore_index=0)\n", 500 | " # 定义Detection Network的Loss函数,因为是二分类,所以用Binary Cross Entropy\n", 501 | " self.detection_criterion = nn.BCELoss()\n", 502 | " # 权重系数\n", 503 | " self.coefficient = coefficient\n", 504 | "\n", 505 | " def forward(self, correction_outputs, correction_targets, detection_outputs, detection_targets):\n", 506 | " \"\"\"\n", 507 | " :param correction_outputs: Correction Network的输出,Shape为(batch_size, sequence_length, hidden_size)\n", 508 | " :param correction_targets: Correction Network的标签,Shape为(batch_size, sequence_length)\n", 509 | " :param detection_outputs: Detection Network的输出,Shape为(batch_size, sequence_length)\n", 510 | " :param detection_targets: Detection Network的标签,Shape为(batch_size, sequence_length)\n", 511 | " :return:\n", 512 | " \"\"\"\n", 513 | " # 计算Correction Network的loss,因为Shape维度为3,所以要把batch_size和sequence_length进行合并才能计算\n", 514 | " correction_loss = self.correction_criterion(correction_outputs.view(-1, correction_outputs.size(2)),\n", 515 | " correction_targets.view(-1))\n", 516 | " # 计算Detection Network的loss\n", 517 | " detection_loss = self.detection_criterion(detection_outputs, detection_targets)\n", 518 | " # 对两个loss进行加权平均\n", 519 | " return self.coefficient * correction_loss + (1 - self.coefficient) * detection_loss" 520 | ] 521 | }, 522 | { 523 | "cell_type": "markdown", 524 | "metadata": { 525 | "pycharm": { 526 | "name": "#%% md\n" 527 | } 528 | }, 529 | "source": [ 530 | "# 模型训练" 531 | ] 532 | }, 533 | { 534 | "cell_type": "markdown", 535 | "metadata": { 536 | "pycharm": { 537 | "name": "#%% md\n" 538 | } 539 | }, 540 | "source": [ 541 | "作者的训练方式:\n", 542 | "\n", 543 | "1. 
第一步,首先使用 [Wang271K(自己造的假数据)](https://github.com/wdimmy/Automatic-Corpus-Generation) 数据集进行训练。batch size为32, learning rate为2e-5\n", 544 | "\n", 545 | "2. 第二步,使用SIGHAN训练集进行fine-tune。 batch size为32,learning rate为1e-5\n", 546 | "\n", 547 | "> 作者并没有提到使用的是什么Optimizer,但看这个学习率,应该是Adam。\n", 548 | "\n", 549 | "> 在第一步,作者说的是使用了几乎3M个,但作者只提到过Wang271K这个数据集,我猜可能作者看错了,这个是0.3M条数据,而不是3M。" 550 | ] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "metadata": { 555 | "pycharm": { 556 | "name": "#%% md\n" 557 | } 558 | }, 559 | "source": [ 560 | "作者首先使用了Wang271K数据集进行对模型进行训练,然后又使用SIGHAN训练集对模型进行fine-tune。这里我就不进行fine-tune了,直接进行训练。我这里使用的是 [ReaLiSe](https://github.com/DaDaMrX/ReaLiSe)论文 处理好的数据集,其就是Wang271K和SIGHAN。\n", 561 | "\n", 562 | "[百度网盘链接](https://pan.baidu.com/s/1x67LPiYAjLKhO1_2CI6aOA?pwd=skda) :https://pan.baidu.com/s/1x67LPiYAjLKhO1_2CI6aOA?pwd=skda\n", 563 | "\n", 564 | "下载好直接解压即可。" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": 11, 570 | "metadata": { 571 | "pycharm": { 572 | "name": "#%%\n" 573 | } 574 | }, 575 | "outputs": [ 576 | { 577 | "name": "stderr", 578 | "output_type": "stream", 579 | "text": [ 580 | "'gdown' 不是内部或外部命令,也不是可运行的程序\n", 581 | "或批处理文件。\n" 582 | ] 583 | } 584 | ], 585 | "source": [ 586 | "!gdown '1dC09i57lobL91lEbpebDuUBS0fGz-LAk' --folder --output data" 587 | ] 588 | }, 589 | { 590 | "cell_type": "markdown", 591 | "metadata": { 592 | "pycharm": { 593 | "name": "#%% md\n" 594 | } 595 | }, 596 | "source": [ 597 | "## 构造Dataset" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 12, 603 | "metadata": { 604 | "pycharm": { 605 | "name": "#%%\n" 606 | } 607 | }, 608 | "outputs": [], 609 | "source": [ 610 | "class CSCDataset(Dataset):\n", 611 | "\n", 612 | " def __init__(self):\n", 613 | " super(CSCDataset, self).__init__()\n", 614 | " with open(\"data/trainall.times2.pkl\", mode='br') as f:\n", 615 | " train_data = pickle.load(f)\n", 616 | "\n", 617 | " self.train_data = train_data\n", 618 | "\n", 619 | " def __getitem__(self, index):\n", 620 | " src = self.train_data[index]['src']\n", 621 | " tgt = self.train_data[index]['tgt']\n", 622 | " return src, tgt\n", 623 | "\n", 624 | " def __len__(self):\n", 625 | " return len(self.train_data)" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": 13, 631 | "metadata": { 632 | "pycharm": { 633 | "name": "#%%\n" 634 | } 635 | }, 636 | "outputs": [], 637 | "source": [ 638 | "train_data = CSCDataset()" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": 14, 644 | "metadata": { 645 | "pycharm": { 646 | "name": "#%%\n" 647 | } 648 | }, 649 | "outputs": [ 650 | { 651 | "data": { 652 | "text/plain": [ 653 | "('纽约早盘作为基准的低硫轻油,五月份交割价攀升一点三四美元,来到每桶二十八点二五美元,而上周五曾下挫一豪元以上。',\n", 654 | " '纽约早盘作为基准的低硫轻油,五月份交割价攀升一点三四美元,来到每桶二十八点二五美元,而上周五曾下挫一美元以上。')" 655 | ] 656 | }, 657 | "execution_count": 14, 658 | "metadata": {}, 659 | "output_type": "execute_result" 660 | } 661 | ], 662 | "source": [ 663 | "train_data.__getitem__(0)" 664 | ] 665 | }, 666 | { 667 | "cell_type": "markdown", 668 | "metadata": { 669 | "pycharm": { 670 | "name": "#%% md\n" 671 | } 672 | }, 673 | "source": [ 674 | "## 构造Dataloader" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": 15, 680 | "metadata": { 681 | "pycharm": { 682 | "name": "#%%\n" 683 | } 684 | }, 685 | "outputs": [], 686 | "source": [ 687 | "tokenizer = AutoTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")" 688 | ] 689 | }, 690 | { 691 | "cell_type": "code", 692 | "execution_count": 16, 693 | 
"metadata": { 694 | "pycharm": { 695 | "name": "#%%\n" 696 | } 697 | }, 698 | "outputs": [], 699 | "source": [ 700 | "def collate_fn(batch):\n", 701 | " src, tgt = zip(*batch)\n", 702 | " src, tgt = list(src), list(tgt)\n", 703 | "\n", 704 | " src_tokens = tokenizer(src, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']\n", 705 | " tgt_tokens = tokenizer(tgt, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']\n", 706 | "\n", 707 | " correction_targets = tgt_tokens\n", 708 | " detection_targets = (src_tokens != tgt_tokens).float()\n", 709 | " return src, correction_targets, detection_targets, src_tokens # src_tokens在计算Correction的精准率时要用到" 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "execution_count": 17, 715 | "metadata": { 716 | "pycharm": { 717 | "name": "#%%\n" 718 | } 719 | }, 720 | "outputs": [], 721 | "source": [ 722 | "train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)" 723 | ] 724 | }, 725 | { 726 | "cell_type": "markdown", 727 | "metadata": { 728 | "pycharm": { 729 | "name": "#%% md\n" 730 | } 731 | }, 732 | "source": [ 733 | "## 训练" 734 | ] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "execution_count": 18, 739 | "metadata": { 740 | "pycharm": { 741 | "name": "#%%\n" 742 | } 743 | }, 744 | "outputs": [], 745 | "source": [ 746 | "criterion = MDCSpellLoss()\n", 747 | "optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)\n", 748 | "start_epoch = 0 # 从哪个epoch开始\n", 749 | "total_step = 0 # 一共更新了多少次参数" 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": 19, 755 | "metadata": { 756 | "pycharm": { 757 | "name": "#%%\n" 758 | } 759 | }, 760 | "outputs": [ 761 | { 762 | "name": "stdout", 763 | "output_type": "stream", 764 | "text": [ 765 | "恢复训练,epoch: 2\n" 766 | ] 767 | } 768 | ], 769 | "source": [ 770 | "# 恢复之前的训练\n", 771 | "if os.path.exists(model_path):\n", 772 | " if not torch.cuda.is_available():\n", 773 | " checkpoint = torch.load(model_path, map_location='cpu')\n", 774 | " else:\n", 775 | " checkpoint = torch.load(model_path)\n", 776 | " model.load_state_dict(checkpoint['model'])\n", 777 | " optimizer.load_state_dict(checkpoint['optimizer'])\n", 778 | " start_epoch = checkpoint['epoch']\n", 779 | " total_step = checkpoint['total_step']\n", 780 | " print(\"恢复训练,epoch:\", start_epoch)" 781 | ] 782 | }, 783 | { 784 | "cell_type": "code", 785 | "execution_count": 20, 786 | "metadata": { 787 | "pycharm": { 788 | "name": "#%%\n" 789 | } 790 | }, 791 | "outputs": [], 792 | "source": [ 793 | "model = model.to(device)\n", 794 | "model = model.train()" 795 | ] 796 | }, 797 | { 798 | "cell_type": "markdown", 799 | "metadata": { 800 | "pycharm": { 801 | "name": "#%% md\n" 802 | } 803 | }, 804 | "source": [ 805 | "训练这里代码量看起来很大,但实际大多都是计算recall和precision的代码。这里对于Detection的recall和precision的计算使用的是Detection Network的预测结果。" 806 | ] 807 | }, 808 | { 809 | "cell_type": "code", 810 | "execution_count": 21, 811 | "metadata": { 812 | "pycharm": { 813 | "name": "#%%\n" 814 | } 815 | }, 816 | "outputs": [ 817 | { 818 | "name": "stdout", 819 | "output_type": "stream", 820 | "text": [ 821 | "Epoch 2, Step 15/8882, Total Step 8900, loss 0.02403, detection recall 0.4118, detection precision 0.8247, correction recall 0.8192, correction precision 0.9485\n", 822 | "Epoch 2, Step 35/8882, Total Step 8920, loss 0.03479, detection recall 0.3658, detection precision 0.8055, correction recall 0.8029, correction precision 0.9125\n" 823 | ] 824 | }, 825 | 
{ 826 | "ename": "KeyboardInterrupt", 827 | "evalue": "", 828 | "output_type": "error", 829 | "traceback": [ 830 | "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", 831 | "\u001B[1;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)", 832 | "\u001B[1;32m\u001B[0m in \u001B[0;36m\u001B[1;34m\u001B[0m\n\u001B[0;32m 19\u001B[0m \u001B[0mcorrection_outputs\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mdetection_outputs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0msequences\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 20\u001B[0m \u001B[0mloss\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcriterion\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcorrection_outputs\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcorrection_targets\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mdetection_outputs\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mdetection_targets\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 21\u001B[1;33m \u001B[0mloss\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 22\u001B[0m \u001B[0moptimizer\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 23\u001B[0m \u001B[0moptimizer\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mzero_grad\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", 833 | "\u001B[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\_tensor.py\u001B[0m in \u001B[0;36mbackward\u001B[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[0;32m 394\u001B[0m \u001B[0mcreate_graph\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mcreate_graph\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 395\u001B[0m inputs=inputs)\n\u001B[1;32m--> 396\u001B[1;33m \u001B[0mtorch\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mautograd\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mgradient\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mretain_graph\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcreate_graph\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0minputs\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0minputs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 397\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 398\u001B[0m \u001B[1;32mdef\u001B[0m \u001B[0mregister_hook\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mhook\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", 834 | "\u001B[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\autograd\\__init__.py\u001B[0m in \u001B[0;36mbackward\u001B[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[0;32m 171\u001B[0m \u001B[1;31m# some Python versions print out the first line of a multi-line function\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 172\u001B[0m \u001B[1;31m# calls in the traceback and some print out the last line\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 173\u001B[1;33m 
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001B[0m\u001B[0;32m 174\u001B[0m \u001B[0mtensors\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mgrad_tensors_\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mretain_graph\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcreate_graph\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0minputs\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 175\u001B[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n", 835 | "\u001B[1;31mKeyboardInterrupt\u001B[0m: " 836 | ] 837 | } 838 | ], 839 | "source": [ 840 | "total_loss = 0. # 记录loss\n", 841 | "\n", 842 | "d_recall_numerator = 0 # Detection的Recall的分子\n", 843 | "d_recall_denominator = 0 # Detection的Recall的分母\n", 844 | "d_precision_numerator = 0 # Detection的precision的分子\n", 845 | "d_precision_denominator = 0 # Detection的precision的分母\n", 846 | "c_recall_numerator = 0 # Correction的Recall的分子\n", 847 | "c_recall_denominator = 0 # Correction的Recall的分母\n", 848 | "c_precision_numerator = 0 # Correction的precision的分子\n", 849 | "c_precision_denominator = 0 # Correction的precision的分母\n", 850 | "\n", 851 | "for epoch in range(start_epoch, epochs):\n", 852 | "\n", 853 | " step = 0\n", 854 | "\n", 855 | " for sequences, correction_targets, detection_targets, correction_inputs in train_loader:\n", 856 | " correction_targets, detection_targets = correction_targets.to(device), detection_targets.to(device)\n", 857 | " correction_inputs = correction_inputs.to(device)\n", 858 | " correction_outputs, detection_outputs = model(sequences)\n", 859 | " loss = criterion(correction_outputs, correction_targets, detection_outputs, detection_targets)\n", 860 | " loss.backward()\n", 861 | " optimizer.step()\n", 862 | " optimizer.zero_grad()\n", 863 | "\n", 864 | " step += 1\n", 865 | " total_step += 1\n", 866 | "\n", 867 | " total_loss += loss.detach().item()\n", 868 | "\n", 869 | " # 计算Detection的recall和precision指标\n", 870 | " # 大于0.5,认为是错误token,反之为正确token\n", 871 | " d_predicts = detection_outputs >= 0.5\n", 872 | " # 计算错误token中被网络正确预测到的数量\n", 873 | " d_recall_numerator += d_predicts[detection_targets == 1].sum().item()\n", 874 | " # 计算错误token的数量\n", 875 | " d_recall_denominator += (detection_targets == 1).sum().item()\n", 876 | " # 计算网络预测的错误token的数量\n", 877 | " d_precision_denominator += d_predicts.sum().item()\n", 878 | " # 计算网络预测的错误token中,有多少是真错误的token\n", 879 | " d_precision_numerator += (detection_targets[d_predicts == 1]).sum().item()\n", 880 | "\n", 881 | " # 计算Correction的recall和precision\n", 882 | " # 将输出映射成index,即将correction_outputs的Shape由(32, 128, 21128)变为(32,128)\n", 883 | " correction_outputs = correction_outputs.argmax(2)\n", 884 | " # 对于填充、[CLS]和[SEP]这三个token不校验\n", 885 | " correction_outputs[(correction_targets == 0) | (correction_targets == 101) | (correction_targets == 102)] = 0\n", 886 | " # correction_targets的[CLS]和[SEP]也要变为0\n", 887 | " correction_targets[(correction_targets == 101) | (correction_targets == 102)] = 0\n", 888 | " # Correction的预测结果,其中True表示预测正确,False表示预测错误或无需预测\n", 889 | " c_predicts = correction_outputs == correction_targets\n", 890 | " # 计算错误token中被网络正确纠正的token数量\n", 891 | " c_recall_numerator += c_predicts[detection_targets == 1].sum().item()\n", 892 | " # 计算错误token的数量\n", 893 | " c_recall_denominator += (detection_targets == 1).sum().item()\n", 894 | " # 计算网络纠正token的数量\n", 895 | " correction_inputs[(correction_inputs == 101) | (correction_inputs == 102)] = 0\n", 896 | " 
c_precision_denominator += (correction_outputs != correction_inputs).sum().item()\n", 897 | " # 计算在网络纠正的这些token中,有多少是真正被纠正对的\n", 898 | " c_precision_numerator += c_predicts[correction_outputs != correction_inputs].sum().item()\n", 899 | "\n", 900 | " if total_step % log_after_step == 0:\n", 901 | " loss = total_loss / log_after_step\n", 902 | " d_recall = d_recall_numerator / (d_recall_denominator + 1e-9)\n", 903 | " d_precision = d_precision_numerator / (d_precision_denominator + 1e-9)\n", 904 | " c_recall = c_recall_numerator / (c_recall_denominator + 1e-9)\n", 905 | " c_precision = c_precision_numerator / (c_precision_denominator + 1e-9)\n", 906 | "\n", 907 | " print(\"Epoch {}, \"\n", 908 | " \"Step {}/{}, \"\n", 909 | " \"Total Step {}, \"\n", 910 | " \"loss {:.5f}, \"\n", 911 | " \"detection recall {:.4f}, \"\n", 912 | " \"detection precision {:.4f}, \"\n", 913 | " \"correction recall {:.4f}, \"\n", 914 | " \"correction precision {:.4f}\".format(epoch, step, len(train_loader), total_step,\n", 915 | " loss,\n", 916 | " d_recall,\n", 917 | " d_precision,\n", 918 | " c_recall,\n", 919 | " c_precision))\n", 920 | "\n", 921 | " total_loss = 0.\n", 922 | " total_correct = 0\n", 923 | " total_num = 0\n", 924 | " d_recall_numerator = 0\n", 925 | " d_recall_denominator = 0\n", 926 | " d_precision_numerator = 0\n", 927 | " d_precision_denominator = 0\n", 928 | " c_recall_numerator = 0\n", 929 | " c_recall_denominator = 0\n", 930 | " c_precision_numerator = 0\n", 931 | " c_precision_denominator = 0\n", 932 | "\n", 933 | " torch.save({\n", 934 | " 'model': model.state_dict(),\n", 935 | " 'optimizer': optimizer.state_dict(),\n", 936 | " 'epoch': epoch + 1,\n", 937 | " 'total_step': total_step,\n", 938 | " }, model_path)" 939 | ] 940 | }, 941 | { 942 | "cell_type": "markdown", 943 | "metadata": { 944 | "pycharm": { 945 | "name": "#%% md\n" 946 | } 947 | }, 948 | "source": [ 949 | "# 模型评估" 950 | ] 951 | }, 952 | { 953 | "cell_type": "markdown", 954 | "metadata": { 955 | "pycharm": { 956 | "name": "#%% md\n" 957 | } 958 | }, 959 | "source": [ 960 | "模型评估使用了SIGHAN 2013,2014,2015三个数据集对模型进行评估。对于Detection的Precision和Recall的评估,使用的是Correction Network的结果,这和训练阶段有所不同,这是因为Detection Network只是帮助Correction Network训练的,其结果在使用时不具备参考价值。" 961 | ] 962 | }, 963 | { 964 | "cell_type": "code", 965 | "execution_count": 22, 966 | "metadata": { 967 | "pycharm": { 968 | "name": "#%%\n" 969 | } 970 | }, 971 | "outputs": [], 972 | "source": [ 973 | "model = model.eval()" 974 | ] 975 | }, 976 | { 977 | "cell_type": "code", 978 | "execution_count": 25, 979 | "metadata": { 980 | "pycharm": { 981 | "name": "#%%\n" 982 | } 983 | }, 984 | "outputs": [], 985 | "source": [ 986 | "def evaluation(test_data):\n", 987 | " d_recall_numerator = 0 # Detection的Recall的分子\n", 988 | " d_recall_denominator = 0 # Detection的Recall的分母\n", 989 | " d_precision_numerator = 0 # Detection的precision的分子\n", 990 | " d_precision_denominator = 0 # Detection的precision的分母\n", 991 | " c_recall_numerator = 0 # Correction的Recall的分子\n", 992 | " c_recall_denominator = 0 # Correction的Recall的分母\n", 993 | " c_precision_numerator = 0 # Correction的precision的分子\n", 994 | " c_precision_denominator = 0 # Correction的precision的分母\n", 995 | "\n", 996 | " prograss = tqdm(range(len(test_data)))\n", 997 | " for i in prograss:\n", 998 | " src, tgt = test_data[i]['src'], test_data[i]['tgt']\n", 999 | "\n", 1000 | " src_tokens = tokenizer(src, return_tensors='pt', max_length=128, truncation=True)['input_ids'][0][1:-1]\n", 1001 | " tgt_tokens = tokenizer(tgt, return_tensors='pt', 
max_length=128, truncation=True)['input_ids'][0][1:-1]\n", 1002 | "\n", 1003 | " # 正常情况下,src和tgt的长度应该是一致的\n", 1004 | " if len(src_tokens) != len(tgt_tokens):\n", 1005 | " print(\"第%d条数据异常\" % i)\n", 1006 | " continue\n", 1007 | "\n", 1008 | " correction_outputs, _ = model(src)\n", 1009 | " predict_tokens = correction_outputs[0][1:len(src_tokens) + 1].argmax(1).detach().cpu()\n", 1010 | "\n", 1011 | " # 计算错误token的数量\n", 1012 | " d_recall_denominator += (src_tokens != tgt_tokens).sum().item()\n", 1013 | " # 计算在这些错误token,有多少网络也认为它是错误的\n", 1014 | " d_recall_numerator += (predict_tokens != src_tokens)[src_tokens != tgt_tokens].sum().item()\n", 1015 | " # 计算网络找出的错误token的数量\n", 1016 | " d_precision_denominator += (predict_tokens != src_tokens).sum().item()\n", 1017 | " # 计算在网络找出的这些错误token中,有多少是真正错误的\n", 1018 | " d_precision_numerator += (src_tokens != tgt_tokens)[predict_tokens != src_tokens].sum().item()\n", 1019 | " # 计算Detection的recall、precision和f1-score\n", 1020 | " d_recall = d_recall_numerator / (d_recall_denominator + 1e-9)\n", 1021 | " d_precision = d_precision_numerator / (d_precision_denominator + 1e-9)\n", 1022 | " d_f1_score = 2 * (d_recall * d_precision) / (d_recall + d_precision + 1e-9)\n", 1023 | "\n", 1024 | " # 计算错误token的数量\n", 1025 | " c_recall_denominator += (src_tokens != tgt_tokens).sum().item()\n", 1026 | " # 计算在这些错误token中,有多少网络预测对了\n", 1027 | " c_recall_numerator += (predict_tokens == tgt_tokens)[src_tokens != tgt_tokens].sum().item()\n", 1028 | " # 计算网络找出的错误token的数量\n", 1029 | " c_precision_denominator += (predict_tokens != src_tokens).sum().item()\n", 1030 | " # 计算网络找出的错误token中,有多少是正确修正的\n", 1031 | " c_precision_numerator += (predict_tokens == tgt_tokens)[predict_tokens != src_tokens].sum().item()\n", 1032 | "\n", 1033 | " # 计算Correction的recall、precision和f1-score\n", 1034 | " c_recall = c_recall_numerator / (c_recall_denominator + 1e-9)\n", 1035 | " c_precision = c_precision_numerator / (c_precision_denominator + 1e-9)\n", 1036 | " c_f1_score = 2 * (c_recall * c_precision) / (c_recall + c_precision + 1e-9)\n", 1037 | "\n", 1038 | " prograss.set_postfix({\n", 1039 | " 'd_recall': d_recall,\n", 1040 | " 'd_precision': d_precision,\n", 1041 | " 'd_f1_score': d_f1_score,\n", 1042 | " 'c_recall': c_recall,\n", 1043 | " 'c_precision': c_precision,\n", 1044 | " 'c_f1_score': c_f1_score,\n", 1045 | " })" 1046 | ] 1047 | }, 1048 | { 1049 | "cell_type": "code", 1050 | "execution_count": 26, 1051 | "metadata": { 1052 | "pycharm": { 1053 | "name": "#%%\n" 1054 | } 1055 | }, 1056 | "outputs": [ 1057 | { 1058 | "name": "stderr", 1059 | "output_type": "stream", 1060 | "text": [ 1061 | "100%|██████████| 1000/1000 [00:11<00:00, 90.12it/s, d_recall=0.799, d_precision=0.802, d_f1_score=0.8, c_recall=0.769, c_precision=0.772, c_f1_score=0.771] \n" 1062 | ] 1063 | } 1064 | ], 1065 | "source": [ 1066 | "with open(\"data/test.sighan13.pkl\", mode='br') as f:\n", 1067 | " sighan13 = pickle.load(f)\n", 1068 | "evaluation(sighan13)" 1069 | ] 1070 | }, 1071 | { 1072 | "cell_type": "code", 1073 | "execution_count": 27, 1074 | "metadata": { 1075 | "pycharm": { 1076 | "name": "#%%\n" 1077 | } 1078 | }, 1079 | "outputs": [ 1080 | { 1081 | "name": "stderr", 1082 | "output_type": "stream", 1083 | "text": [ 1084 | "100%|██████████| 1062/1062 [00:12<00:00, 85.48it/s, d_recall=0.666, d_precision=0.828, d_f1_score=0.738, c_recall=0.643, c_precision=0.799, c_f1_score=0.712]\n" 1085 | ] 1086 | } 1087 | ], 1088 | "source": [ 1089 | "with open(\"data/test.sighan14.pkl\", mode='br') as f:\n", 1090 | " sighan14 = 
pickle.load(f)\n", 1091 | "evaluation(sighan14)" 1092 | ] 1093 | }, 1094 | { 1095 | "cell_type": "code", 1096 | "execution_count": 28, 1097 | "metadata": { 1098 | "pycharm": { 1099 | "name": "#%%\n" 1100 | } 1101 | }, 1102 | "outputs": [ 1103 | { 1104 | "name": "stderr", 1105 | "output_type": "stream", 1106 | "text": [ 1107 | "100%|██████████| 1100/1100 [00:11<00:00, 92.04it/s, d_recall=0.761, d_precision=0.867, d_f1_score=0.811, c_recall=0.725, c_precision=0.827, c_f1_score=0.773] \n" 1108 | ] 1109 | } 1110 | ], 1111 | "source": [ 1112 | "with open(\"data/test.sighan15.pkl\", mode='br') as f:\n", 1113 | " sighan15 = pickle.load(f)\n", 1114 | "evaluation(sighan15)" 1115 | ] 1116 | }, 1117 | { 1118 | "cell_type": "markdown", 1119 | "metadata": { 1120 | "pycharm": { 1121 | "name": "#%% md\n" 1122 | } 1123 | }, 1124 | "source": [ 1125 | "# 模型使用" 1126 | ] 1127 | }, 1128 | { 1129 | "cell_type": "markdown", 1130 | "metadata": { 1131 | "pycharm": { 1132 | "name": "#%% md\n" 1133 | } 1134 | }, 1135 | "source": [ 1136 | "最后,我们来真正的使用一下该模型,看下效果:" 1137 | ] 1138 | }, 1139 | { 1140 | "cell_type": "code", 1141 | "execution_count": 29, 1142 | "metadata": { 1143 | "pycharm": { 1144 | "name": "#%%\n" 1145 | } 1146 | }, 1147 | "outputs": [], 1148 | "source": [ 1149 | "def predict(text):\n", 1150 | " sequences = [text]\n", 1151 | " correction_outputs, _ = model(sequences)\n", 1152 | " tokens = correction_outputs[0][1:len(text) + 1].argmax(1)\n", 1153 | " return ''.join(tokenizer.convert_ids_to_tokens(tokens))" 1154 | ] 1155 | }, 1156 | { 1157 | "cell_type": "code", 1158 | "execution_count": 32, 1159 | "metadata": { 1160 | "pycharm": { 1161 | "name": "#%%\n" 1162 | } 1163 | }, 1164 | "outputs": [ 1165 | { 1166 | "data": { 1167 | "text/plain": [ 1168 | "'今天早上我吃了一个火聋果'" 1169 | ] 1170 | }, 1171 | "execution_count": 32, 1172 | "metadata": {}, 1173 | "output_type": "execute_result" 1174 | } 1175 | ], 1176 | "source": [ 1177 | "predict(\"今天早上我吃了以个火聋果\")" 1178 | ] 1179 | }, 1180 | { 1181 | "cell_type": "code", 1182 | "execution_count": 33, 1183 | "metadata": { 1184 | "pycharm": { 1185 | "name": "#%%\n" 1186 | } 1187 | }, 1188 | "outputs": [ 1189 | { 1190 | "data": { 1191 | "text/plain": [ 1192 | "'我是联系时长两年半的个人练习生蔡徐鲲,喜欢唱跳ra##p蓝球[SEP]'" 1193 | ] 1194 | }, 1195 | "execution_count": 33, 1196 | "metadata": {}, 1197 | "output_type": "execute_result" 1198 | } 1199 | ], 1200 | "source": [ 1201 | "predict(\"我是联系时长两年半的个人练习生蔡徐鲲,喜欢唱跳RAP蓝球\")" 1202 | ] 1203 | }, 1204 | { 1205 | "cell_type": "markdown", 1206 | "metadata": { 1207 | "pycharm": { 1208 | "name": "#%% md\n" 1209 | } 1210 | }, 1211 | "source": [ 1212 | "虽然在数据上模型表现还不错,但在真正使用场景上,效果还是不够好。中文文本纠错果然是一个比较难的任务 T_T !" 
1213 | ] 1214 | }, 1215 | { 1216 | "cell_type": "markdown", 1217 | "metadata": { 1218 | "pycharm": { 1219 | "name": "#%% md\n" 1220 | } 1221 | }, 1222 | "source": [ 1223 | "# 参考文献\n", 1224 | "\n", 1225 | "[MDCSpell论文](https://aclanthology.org/2022.findings-acl.98/): https://aclanthology.org/2022.findings-acl.98/\n", 1226 | "\n", 1227 | "[MDCSpell论文笔记](https://blog.csdn.net/zhaohongfei_358/article/details/126973451):https://blog.csdn.net/zhaohongfei_358/article/details/126973451" 1228 | ] 1229 | } 1230 | ], 1231 | "metadata": { 1232 | "kernelspec": { 1233 | "display_name": "Python 3", 1234 | "language": "python", 1235 | "name": "python3" 1236 | }, 1237 | "language_info": { 1238 | "codemirror_mode": { 1239 | "name": "ipython", 1240 | "version": 3 1241 | }, 1242 | "file_extension": ".py", 1243 | "mimetype": "text/x-python", 1244 | "name": "python", 1245 | "nbconvert_exporter": "python", 1246 | "pygments_lexer": "ipython3", 1247 | "version": "3.8.5" 1248 | } 1249 | }, 1250 | "nbformat": 4, 1251 | "nbformat_minor": 1 1252 | } --------------------------------------------------------------------------------
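补充:notebook 最后的 `predict` 函数依赖训练流程中已经加载好的 `model` 和 `tokenizer`。下面给出一个单独加载已保存 checkpoint 再做推理的最小示例(个人补充的参考写法,并非官方实现):它假设上面 notebook 中定义的 `MDCSpellModel` 与 `tokenizer` 已经可用,checkpoint 采用 notebook 中 `torch.save` 保存的格式(包含 `model`、`optimizer`、`epoch`、`total_step` 四个键),其中权重路径 `MDCSpell-model.pt` 仅作示意,需按实际存放位置修改。

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 构造模型并加载训练好的权重(路径仅作示意,按实际情况修改)
model = MDCSpellModel().to(device)
checkpoint = torch.load('MDCSpell-model.pt', map_location=device)
model.load_state_dict(checkpoint['model'])
model = model.eval()


def correct(text):
    """对单句中文文本进行纠错,逻辑与notebook中的predict一致"""
    with torch.no_grad():  # 推理阶段不需要计算梯度
        correction_outputs, _ = model([text])
    # 去掉[CLS]对应的位置,对每个token取概率最大的id
    token_ids = correction_outputs[0][1:len(text) + 1].argmax(1).cpu().tolist()
    return ''.join(tokenizer.convert_ids_to_tokens(token_ids))


print(correct("今天早上我吃了以个火聋果"))
```

和 notebook 中的 `predict` 一样,这里用 `len(text)` 截取输出,只适合纯中文输入(一个字对应一个 token);句子里含英文或数字时,token 数可能与字符数不一致(上文 "RAP" 被切成 "ra##p" 就是这种情况),需要改用实际 token 数来截取。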