├── .gitignore ├── .vscode ├── launch.json └── settings.json ├── README.md ├── data ├── mlb.pkl ├── stopwords │ └── 哈工大停用词表.txt ├── textcnn_results │ └── multi_cls.h5 ├── vocab.pkl ├── w2v │ └── sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5 └── 题库 │ └── baidu_95.csv └── textcnn ├── config.py ├── data_loader.py ├── multi_proc_utils.py ├── service_helper.py ├── service_test.py ├── textcnn_model.py ├── textcnn_predict.py ├── textcnn_service.py ├── textcnn_train.py └── textcnn_train_helper.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/jimx/anaconda3/bin/python" 3 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TextCNN_MultiClasses 2 | 多标签多分类的textCNN 3 | -------------------------------------------------------------------------------- /data/mlb.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JimXiongGM/TextCNN_MultiClasses/85afaa3a9b287892c2fa3847c51c09d93598530c/data/mlb.pkl -------------------------------------------------------------------------------- /data/stopwords/哈工大停用词表.txt: -------------------------------------------------------------------------------- 1 | ——— 2 | 》), 3 | )÷(1- 4 | ”, 5 | )、 6 | =( 7 | : 8 | → 9 | ℃ 10 | & 11 | * 12 | 一一 13 | ~~~~ 14 | ’ 15 | . 16 | 『 17 | .一 18 | ./ 19 | -- 20 | 』 21 | =″ 22 | 【 23 | [*] 24 | }> 25 | [⑤]] 26 | [①D] 27 | c] 28 | ng•P 29 | * 30 | // 31 | [ 32 | ] 33 | [②e] 34 | [②g] 35 | ={ 36 | } 37 | ,也 38 | ‘ 39 | A 40 | [①⑥] 41 | [②B] 42 | [①a] 43 | [④a] 44 | [①③] 45 | [③h] 46 | ③] 47 | 1. 48 | -- 49 | [②b] 50 | ’‘ 51 | ××× 52 | [①⑧] 53 | 0:2 54 | =[ 55 | [⑤b] 56 | [②c] 57 | [④b] 58 | [②③] 59 | [③a] 60 | [④c] 61 | [①⑤] 62 | [①⑦] 63 | [①g] 64 | ∈[ 65 | [①⑨] 66 | [①④] 67 | [①c] 68 | [②f] 69 | [②⑧] 70 | [②①] 71 | [①C] 72 | [③c] 73 | [③g] 74 | [②⑤] 75 | [②②] 76 | 一. 77 | [①h] 78 | .数 79 | [] 80 | [①B] 81 | 数/ 82 | [①i] 83 | [③e] 84 | [①①] 85 | [④d] 86 | [④e] 87 | [③b] 88 | [⑤a] 89 | [①A] 90 | [②⑧] 91 | [②⑦] 92 | [①d] 93 | [②j] 94 | 〕〔 95 | ][ 96 | :// 97 | ′∈ 98 | [②④ 99 | [⑤e] 100 | 12% 101 | b] 102 | ... 103 | ................... 104 | ⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯③ 105 | ZXFITL 106 | [③F] 107 | 」 108 | [①o] 109 | ]∧′=[ 110 | ∪φ∈ 111 | ′| 112 | {- 113 | ②c 114 | } 115 | [③①] 116 | R.L. 117 | [①E] 118 | Ψ 119 | -[*]- 120 | ↑ 121 | .日 122 | [②d] 123 | [② 124 | [②⑦] 125 | [②②] 126 | [③e] 127 | [①i] 128 | [①B] 129 | [①h] 130 | [①d] 131 | [①g] 132 | [①②] 133 | [②a] 134 | f] 135 | [⑩] 136 | a] 137 | [①e] 138 | [②h] 139 | [②⑥] 140 | [③d] 141 | [②⑩] 142 | e] 143 | 〉 144 | 】 145 | 元/吨 146 | [②⑩] 147 | 2.3% 148 | 5:0 149 | [①] 150 | :: 151 | [②] 152 | [③] 153 | [④] 154 | [⑤] 155 | [⑥] 156 | [⑦] 157 | [⑧] 158 | [⑨] 159 | ⋯⋯ 160 | —— 161 | ? 162 | 、 163 | 。 164 | “ 165 | ” 166 | 《 167 | 》 168 | ! 169 | , 170 | : 171 | ; 172 | ? 173 | . 174 | , 175 | . 176 | ' 177 | ? 178 | · 179 | ——— 180 | ── 181 | ? 182 | — 183 | < 184 | > 185 | ( 186 | ) 187 | 〔 188 | 〕 189 | [ 190 | ] 191 | ( 192 | ) 193 | - 194 | + 195 | 〜 196 | × 197 | / 198 | / 199 | ① 200 | ② 201 | ③ 202 | ④ 203 | ⑤ 204 | ⑥ 205 | ⑦ 206 | ⑧ 207 | ⑨ 208 | ⑩ 209 | Ⅲ 210 | В 211 | " 212 | ; 213 | # 214 | @ 215 | γ 216 | μ 217 | φ 218 | φ. 
219 | × 220 | Δ 221 | ■ 222 | ▲ 223 | sub 224 | exp 225 | sup 226 | sub 227 | Lex 228 | # 229 | % 230 | & 231 | ' 232 | + 233 | +ξ 234 | ++ 235 | - 236 | -β 237 | < 238 | <± 239 | <Δ 240 | <λ 241 | <φ 242 | << 243 | = 244 | = 245 | =☆ 246 | =- 247 | > 248 | >λ 249 | _ 250 | 〜± 251 | 〜+ 252 | [⑤f] 253 | [⑤d] 254 | [②i] 255 | ≈ 256 | [②G] 257 | [①f] 258 | LI 259 | ㈧ 260 | [- 261 | ...... 262 | 〉 263 | [③⑩] 264 | 第二 265 | 一番 266 | 一直 267 | 一个 268 | 一些 269 | 许多 270 | 种 271 | 有的是 272 | 也就是说 273 | 末##末 274 | 啊 275 | 阿 276 | 哎 277 | 哎呀 278 | 哎哟 279 | 唉 280 | 俺 281 | 俺们 282 | 按 283 | 按照 284 | 吧 285 | 吧哒 286 | 把 287 | 罢了 288 | 被 289 | 本 290 | 本着 291 | 比 292 | 比方 293 | 比如 294 | 鄙人 295 | 彼 296 | 彼此 297 | 边 298 | 别 299 | 别的 300 | 别说 301 | 并 302 | 并且 303 | 不比 304 | 不成 305 | 不单 306 | 不但 307 | 不独 308 | 不管 309 | 不光 310 | 不过 311 | 不仅 312 | 不拘 313 | 不论 314 | 不怕 315 | 不然 316 | 不如 317 | 不特 318 | 不惟 319 | 不问 320 | 不只 321 | 朝 322 | 朝着 323 | 趁 324 | 趁着 325 | 乘 326 | 冲 327 | 除 328 | 除此之外 329 | 除非 330 | 除了 331 | 此 332 | 此间 333 | 此外 334 | 从 335 | 从而 336 | 打 337 | 待 338 | 但 339 | 但是 340 | 当 341 | 当着 342 | 到 343 | 得 344 | 的 345 | 的话 346 | 等 347 | 等等 348 | 地 349 | 第 350 | 叮咚 351 | 对 352 | 对于 353 | 多 354 | 多少 355 | 而 356 | 而况 357 | 而且 358 | 而是 359 | 而外 360 | 而言 361 | 而已 362 | 尔后 363 | 反过来 364 | 反过来说 365 | 反之 366 | 非但 367 | 非徒 368 | 否则 369 | 嘎 370 | 嘎登 371 | 该 372 | 赶 373 | 个 374 | 各 375 | 各个 376 | 各位 377 | 各种 378 | 各自 379 | 给 380 | 根据 381 | 跟 382 | 故 383 | 故此 384 | 固然 385 | 关于 386 | 管 387 | 归 388 | 果然 389 | 果真 390 | 过 391 | 哈 392 | 哈哈 393 | 呵 394 | 和 395 | 何 396 | 何处 397 | 何况 398 | 何时 399 | 嘿 400 | 哼 401 | 哼唷 402 | 呼哧 403 | 乎 404 | 哗 405 | 还是 406 | 还有 407 | 换句话说 408 | 换言之 409 | 或 410 | 或是 411 | 或者 412 | 极了 413 | 及 414 | 及其 415 | 及至 416 | 即 417 | 即便 418 | 即或 419 | 即令 420 | 即若 421 | 即使 422 | 几 423 | 几时 424 | 己 425 | 既 426 | 既然 427 | 既是 428 | 继而 429 | 加之 430 | 假如 431 | 假若 432 | 假使 433 | 鉴于 434 | 将 435 | 较 436 | 较之 437 | 叫 438 | 接着 439 | 结果 440 | 借 441 | 紧接着 442 | 进而 443 | 尽 444 | 尽管 445 | 经 446 | 经过 447 | 就 448 | 就是 449 | 就是说 450 | 据 451 | 具体地说 452 | 具体说来 453 | 开始 454 | 开外 455 | 靠 456 | 咳 457 | 可 458 | 可见 459 | 可是 460 | 可以 461 | 况且 462 | 啦 463 | 来 464 | 来着 465 | 离 466 | 例如 467 | 哩 468 | 连 469 | 连同 470 | 两者 471 | 了 472 | 临 473 | 另 474 | 另外 475 | 另一方面 476 | 论 477 | 嘛 478 | 吗 479 | 慢说 480 | 漫说 481 | 冒 482 | 么 483 | 每 484 | 每当 485 | 们 486 | 莫若 487 | 某 488 | 某个 489 | 某些 490 | 拿 491 | 哪 492 | 哪边 493 | 哪儿 494 | 哪个 495 | 哪里 496 | 哪年 497 | 哪怕 498 | 哪天 499 | 哪些 500 | 哪样 501 | 那 502 | 那边 503 | 那儿 504 | 那个 505 | 那会儿 506 | 那里 507 | 那么 508 | 那么些 509 | 那么样 510 | 那时 511 | 那些 512 | 那样 513 | 乃 514 | 乃至 515 | 呢 516 | 能 517 | 你 518 | 你们 519 | 您 520 | 宁 521 | 宁可 522 | 宁肯 523 | 宁愿 524 | 哦 525 | 呕 526 | 啪达 527 | 旁人 528 | 呸 529 | 凭 530 | 凭借 531 | 其 532 | 其次 533 | 其二 534 | 其他 535 | 其它 536 | 其一 537 | 其余 538 | 其中 539 | 起 540 | 起见 541 | 起见 542 | 岂但 543 | 恰恰相反 544 | 前后 545 | 前者 546 | 且 547 | 然而 548 | 然后 549 | 然则 550 | 让 551 | 人家 552 | 任 553 | 任何 554 | 任凭 555 | 如 556 | 如此 557 | 如果 558 | 如何 559 | 如其 560 | 如若 561 | 如上所述 562 | 若 563 | 若非 564 | 若是 565 | 啥 566 | 上下 567 | 尚且 568 | 设若 569 | 设使 570 | 甚而 571 | 甚么 572 | 甚至 573 | 省得 574 | 时候 575 | 什么 576 | 什么样 577 | 使得 578 | 是 579 | 是的 580 | 首先 581 | 谁 582 | 谁知 583 | 顺 584 | 顺着 585 | 似的 586 | 虽 587 | 虽然 588 | 虽说 589 | 虽则 590 | 随 591 | 随着 592 | 所 593 | 所以 594 | 他 595 | 他们 596 | 他人 597 | 它 598 | 它们 599 | 她 600 | 她们 601 | 倘 602 | 倘或 603 | 倘然 604 | 倘若 605 | 倘使 606 | 腾 607 | 替 608 | 通过 609 | 同 610 | 同时 611 | 哇 612 | 万一 613 | 往 614 | 望 615 | 为 616 | 为何 617 | 为了 618 | 为什么 619 | 为着 620 | 喂 621 | 嗡嗡 622 | 我 623 | 我们 624 
| 呜 625 | 呜呼 626 | 乌乎 627 | 无论 628 | 无宁 629 | 毋宁 630 | 嘻 631 | 吓 632 | 相对而言 633 | 像 634 | 向 635 | 向着 636 | 嘘 637 | 呀 638 | 焉 639 | 沿 640 | 沿着 641 | 要 642 | 要不 643 | 要不然 644 | 要不是 645 | 要么 646 | 要是 647 | 也 648 | 也罢 649 | 也好 650 | 一 651 | 一般 652 | 一旦 653 | 一方面 654 | 一来 655 | 一切 656 | 一样 657 | 一则 658 | 依 659 | 依照 660 | 矣 661 | 以 662 | 以便 663 | 以及 664 | 以免 665 | 以至 666 | 以至于 667 | 以致 668 | 抑或 669 | 因 670 | 因此 671 | 因而 672 | 因为 673 | 哟 674 | 用 675 | 由 676 | 由此可见 677 | 由于 678 | 有 679 | 有的 680 | 有关 681 | 有些 682 | 又 683 | 于 684 | 于是 685 | 于是乎 686 | 与 687 | 与此同时 688 | 与否 689 | 与其 690 | 越是 691 | 云云 692 | 哉 693 | 再说 694 | 再者 695 | 在 696 | 在下 697 | 咱 698 | 咱们 699 | 则 700 | 怎 701 | 怎么 702 | 怎么办 703 | 怎么样 704 | 怎样 705 | 咋 706 | 照 707 | 照着 708 | 者 709 | 这 710 | 这边 711 | 这儿 712 | 这个 713 | 这会儿 714 | 这就是说 715 | 这里 716 | 这么 717 | 这么点儿 718 | 这么些 719 | 这么样 720 | 这时 721 | 这些 722 | 这样 723 | 正如 724 | 吱 725 | 之 726 | 之类 727 | 之所以 728 | 之一 729 | 只是 730 | 只限 731 | 只要 732 | 只有 733 | 至 734 | 至于 735 | 诸位 736 | 着 737 | 着呢 738 | 自 739 | 自从 740 | 自个儿 741 | 自各儿 742 | 自己 743 | 自家 744 | 自身 745 | 综上所述 746 | 总的来看 747 | 总的来说 748 | 总的说来 749 | 总而言之 750 | 总之 751 | 纵 752 | 纵令 753 | 纵然 754 | 纵使 755 | 遵照 756 | 作为 757 | 兮 758 | 呃 759 | 呗 760 | 咚 761 | 咦 762 | 喏 763 | 啐 764 | 喔唷 765 | 嗬 766 | 嗯 767 | 嗳 -------------------------------------------------------------------------------- /data/textcnn_results/multi_cls.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JimXiongGM/TextCNN_MultiClasses/85afaa3a9b287892c2fa3847c51c09d93598530c/data/textcnn_results/multi_cls.h5 -------------------------------------------------------------------------------- /data/vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JimXiongGM/TextCNN_MultiClasses/85afaa3a9b287892c2fa3847c51c09d93598530c/data/vocab.pkl -------------------------------------------------------------------------------- /data/w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5: -------------------------------------------------------------------------------- 1 | /home/g/data/pretrain_embedding/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5 -------------------------------------------------------------------------------- /textcnn/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,pathlib 3 | import torch 4 | 5 | """ 项目的根目录 """ 6 | root = pathlib.Path(os.path.abspath(__file__)).parent.parent 7 | 8 | class Config(object): 9 | def __init__(self): 10 | self.point_path = os.path.join(root,"data","题库/baidu_95.csv") 11 | self.stopwords_path = os.path.join(root,"data","stopwords/哈工大停用词表.txt") 12 | self.w2v_path = os.path.join(root,"data","w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5") 13 | self.save_dir = os.path.join(root,"data","textcnn_results") 14 | self.save_path = os.path.join(self.save_dir,"multi_cls.h5") 15 | self.mlb_path = os.path.join(root,"data","mlb.pkl") 16 | self.vocab_path = os.path.join(root,"data","vocab.pkl") 17 | self.max_len = 30 18 | 19 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | self.batch_size = 128 22 | self.embed_dim = 300 23 | self.filter_sizes = [2,3,4,5] 24 | self.num_filters = 128 25 | self.dense_units = 100 26 | self.dropout = 0.5 27 | self.learning_rate = 1e-4 28 | self.num_epochs = 10 29 | self.max_grad_norm = 2.0 30 | self.gamma = 0.9 31 | 
        self.require_improve = 500
--------------------------------------------------------------------------------
/textcnn/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | from tensorflow.keras.preprocessing.text import Tokenizer
4 | from tensorflow.keras.preprocessing.sequence import pad_sequences
5 | import os,re,jieba
6 | import pandas as pd
7 | import numpy as np
8 | from sklearn.model_selection import train_test_split
9 | from config import Config
10 | from multi_proc_utils import parallelize
11 | from sklearn.preprocessing import MultiLabelBinarizer as MLB
12 | from tqdm import tqdm
13 | 
14 | config = Config()
15 | 
16 | """ 加载停用词 """
17 | def load_stop_words(stop_word_path):
18 |     with open(stop_word_path, 'r', encoding='utf-8') as f1:
19 |         stop_words = f1.readlines()
20 |         stop_words = [stop_word.strip() for stop_word in stop_words]
21 |     return stop_words
22 | 
23 | stop_words = load_stop_words(config.stopwords_path)
24 | 
25 | """ 清洗文本 """
26 | def clean_sentence(line):
27 |     line = re.sub(
28 |         "[a-zA-Z0-9]|[\s+\-\|\!\/\[\]\{\}_,.$%^*(+\"\')]+|[::+——()?【】《》“”!,。?、~@#¥%……&*()]+|题目", '',line)
29 |     words = jieba.cut(line, cut_all=False)
30 |     return words
31 | 
32 | """ 进行分词 """
33 | def sentence_proc(sentence):
34 |     words = clean_sentence(sentence)
35 |     words = [word for word in words if word not in stop_words]
36 |     return ' '.join(words)
37 | 
38 | def proc(df):
39 |     df["content"] = df["content"].apply(sentence_proc)
40 |     return df
41 | 
42 | """ 计算输入长度,使其能涵盖95%样本 """
43 | def calcu_max_len(df):
44 |     df["lengths"] = df["content"].apply(lambda x:x.count(' ')+1)
45 |     max_lengths = max(df["lengths"])
46 |     for len_ in range(50,max_lengths,50):
47 |         bool_ = df["lengths"] < len_
48 |         cover_rate = sum(bool_.apply(int)) / len(bool_)
49 |         if cover_rate >= 0.95:
50 |             return len_
51 | 
52 | import pickle
53 | 
54 | """ 保存为pickle对象 """
55 | def save_pickle(s,file_path):
56 |     with open(file_path,'wb') as f:
57 |         pickle.dump(s,f,protocol=2)
58 | 
59 | """ 一:对样本和标签进行数值化,建立词表 """
60 | def build_dataset(config):
61 | 
62 |     """ 1:加载数据 """
63 |     print("\nLoading the dataset ... \n")
64 |     point_df = pd.read_csv(config.point_path,header=None)
65 |     point_df.dropna(inplace=True)
66 |     point_df = point_df.rename(columns={0:"label",1:"content"})
67 |     print(f"\nThe shape of the dataset : {point_df.shape}\n")
68 | 
69 |     """ 2:开多进程进行数据清洗和分词 """
70 |     print("\nCleaning text and segmenting ... \n")
71 |     point_df = parallelize(point_df,proc)
72 | 
73 |     """ 3:对样本进行 zero pad,并转化为id """
74 |     print("\nZero padding and transferring id ...\n")
75 |     text_tokenizer = Tokenizer(oov_token="<unk>")   # 集外词统一映射为 <unk>
76 |     text_tokenizer.fit_on_texts(point_df["content"])
77 |     corpus_x = text_tokenizer.texts_to_sequences(point_df["content"])
78 |     config.max_len = calcu_max_len(point_df)
79 |     corpus_x = pad_sequences(corpus_x,maxlen=config.max_len,padding="post",truncating="post")
80 | 
81 |     """ 4: 对多标签分类的标签进行数值化 """
82 |     print("\nNumeralizing the multiclass labels ... \n")
83 |     point_df["label"] = point_df["label"].apply(lambda x:x.split())
84 |     mlb = MLB()
85 |     corpus_y = mlb.fit_transform(point_df["label"])
86 |     config.num_classes = corpus_y.shape[1]
87 | 
88 |     """ 5: 用样本构建词表,加入 <pad> 的id """
89 |     print("\nBuilding the vocab ...\n")
90 |     word_index = text_tokenizer.word_index
91 |     vocab = dict({"<pad>":0}, **word_index)   # <pad> 对应 pad_sequences 的填充值 0
92 |     config.vocab_size = len(vocab)
93 | 
94 |     """ 保存好标签转化器,在模型预测时用 """
95 |     save_pickle(mlb, config.mlb_path)
96 | 
97 |     """ 保存好词表,模型预测时用 """
98 |     save_pickle(vocab,config.vocab_path)
99 | 
100 |     return corpus_x,corpus_y,vocab
101 | 
102 | """ 加载百度百科词向量 """
103 | def load_w2v(path):
104 |     with open(path, encoding="utf-8") as f2:
105 |         embed_index = {}
106 |         for i,line in tqdm(enumerate(f2.readlines()[:10000])):
107 |             if i == 0: continue
108 |             value = line.split()
109 |             word = value[0]
110 |             emb = np.asarray(value[1:], dtype="float32")
111 |             embed_index[word] = emb
112 |     return embed_index
113 | 
114 | """ 二: 加载预训练词向量,并与词表相对应 """
115 | def load_embed_matrix(vocab,config):
116 | 
117 |     """ 1: 加载百度百科词向量 """
118 |     print("\nLoading baidu baike word2vec ...\n")
119 |     embed_index = load_w2v(config.w2v_path)
120 | 
121 |     """ 2: 词向量矩阵与词表相对应 """
122 |     vocab_size = len(vocab)
123 |     embed_matrix = np.zeros((vocab_size,config.embed_dim))
124 |     for word,index in vocab.items():
125 |         vector = embed_index.get(word)
126 |         if vector is not None:
127 |             embed_matrix[index] = vector
128 | 
129 |     embed_matrix = torch.FloatTensor(embed_matrix)
130 | 
131 |     return embed_matrix
132 | 
133 | """ 划分数据集 """
134 | def split_dataset(x,y,size,device):
135 | 
136 |     train_x, valid_x, train_y, valid_y = train_test_split(x,y,test_size=size,random_state=10)
137 |     train_x, test_x, train_y, test_y = train_test_split(train_x, train_y, test_size=size,random_state=10)
138 | 
139 |     train_x = torch.LongTensor(train_x).to(device)
140 |     valid_x = torch.LongTensor(valid_x).to(device)
141 |     test_x = torch.LongTensor(test_x).to(device)
142 | 
143 |     train_y = torch.FloatTensor(train_y).to(device)
144 |     valid_y = torch.FloatTensor(valid_y).to(device)
145 |     test_y = torch.FloatTensor(test_y).to(device)
146 | 
147 |     return (train_x,train_y) , (valid_x,valid_y), (test_x, test_y)
148 | 
149 | """ 重写 dataset 类"""
150 | class MyDataset(data.Dataset):
151 |     def __init__(self, dataset):
152 |         self.X, self.Y = dataset
153 | 
154 |     def __getitem__(self, index):
155 |         x, y = self.X[index], self.Y[index]
156 |         return x, y
157 | 
158 |     def __len__(self):
159 |         return len(self.X)
160 | 
161 | """ 生成 迭代器 """
162 | def batch_iterator(dataset,batch_size):
163 |     dataset = MyDataset(dataset)
164 |     batcher = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
165 |     return batcher
166 | 
167 | """ 三: 划分数据集,并生成 batch 迭代器 """
168 | def batch_generator(x,y,size,config):
169 | 
170 |     """ 1: 划分数据集 """
171 |     print("\nSplitting the dataset ... 
\n") 172 | train_data, valid_data, test_data = split_dataset(x, y, size, config.device) 173 | 174 | """ 2: 生成 batch 迭代器 """ 175 | train_iter = batch_iterator(train_data, config.batch_size) 176 | valid_iter = batch_iterator(valid_data, config.batch_size) 177 | test_iter = batch_iterator(test_data, config.batch_size) 178 | 179 | return train_iter,valid_iter,test_iter 180 | 181 | """ 四: 计算类别权重,缓解类别不平衡问题 """ 182 | def calcu_class_weights(labels,config): 183 | labels = torch.FloatTensor(labels) 184 | 185 | freqs = torch.zeros_like(labels[0]) 186 | for y in labels: 187 | freqs += y 188 | 189 | weights = freqs / len(labels) 190 | weights = 1 / torch.log(1.01 + weights) 191 | 192 | weights = weights.to(config.device) 193 | return weights 194 | 195 | 196 | 197 | if __name__ == "__main__": 198 | 199 | corpus_x, corpus_y, vocab = build_dataset(config) 200 | 201 | class_weights = calcu_class_weights(corpus_y,config) 202 | 203 | embed_matrix = load_embed_matrix(vocab,config) 204 | 205 | train_iter, valid_iter, test_iter = batch_generator(corpus_x, corpus_y, 0.15, config) 206 | for x_batch, y_batch in train_iter: 207 | print(f"The shape of content batch is {x_batch.shape}") 208 | print(f"The shape of label batch is {y_batch.shape}") 209 | break -------------------------------------------------------------------------------- /textcnn/multi_proc_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Created by LuoJie at 11/17/19 3 | 4 | import pandas as pd 5 | import numpy as np 6 | from multiprocessing import cpu_count, Pool 7 | 8 | # cpu 数量 9 | cores = cpu_count() 10 | # 分块个数 11 | partitions = cores 12 | 13 | def parallelize(df, func): 14 | """ 15 | 多核并行处理模块 16 | :param df: DataFrame数据 17 | :param func: 预处理函数 18 | :return: 处理后的数据 19 | """ 20 | # 数据切分 21 | data_split = np.array_split(df, partitions) 22 | # 线程池 23 | pool = Pool(cores) 24 | # 数据分发 合并 25 | data = pd.concat(pool.map(func, data_split)) 26 | # 关闭线程池 27 | pool.close() 28 | # 执行完close后不会有新的进程加入到pool,join函数等待所有子进程结束 29 | pool.join() 30 | return data 31 | -------------------------------------------------------------------------------- /textcnn/service_helper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | from textcnn_predict import TextcnnPredict 4 | 5 | message_dic={'200':'正常', 6 | '300':'请求格式错误', 7 | '400':'模型预测失败'} 8 | 9 | 10 | class TextcnnServer: 11 | def __init__(self,device="gpu"): 12 | """ 13 | 把模型的预测函数初始化, 14 | 设置使用CPU还是GPU启动服务. 
15 | """ 16 | self.predict = TextcnnPredict(device).predict 17 | 18 | """ 把字典格式的请求数据,解析出来 """ 19 | def parse(self, app_data): 20 | request_id = app_data["request_id"] 21 | text = app_data["query"] 22 | return request_id, text 23 | 24 | """ 得到服务的调用结果,包括模型结果和服务的情况 """ 25 | def get_result(self,data): 26 | code = '200' 27 | try: 28 | request_id, text = self.parse(data) 29 | except Exception as e: 30 | print('error info : {}'.format(e)) 31 | code='300' 32 | request_id = "None" 33 | try: 34 | if code == '200': 35 | label = self.predict(text) 36 | elif code == '300': 37 | label = '高中' 38 | except Exception as e: 39 | print('error info : {}'.format(e)) 40 | label = '高中' 41 | code='400' 42 | 43 | result = {'label': label,'code':code,'message':message_dic[code],'request_id':request_id} 44 | return result 45 | 46 | if __name__ == "__main__": 47 | 48 | server = TextcnnServer(device="gpu") 49 | data = {"request_id": "ExamServer", 50 | "query" :"菠菜从土壤中吸收的氮元素可以用来合成()A.淀粉和纤维素B.葡萄糖和DNAC.核酸和蛋白质D.麦芽糖和脂肪酸"} 51 | print("\n The result is {}".format(server.get_result(data))) -------------------------------------------------------------------------------- /textcnn/service_test.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import json 3 | import socket 4 | import time 5 | import urllib.request 6 | from datetime import timedelta 7 | 8 | 9 | """ 记录花费的时间 """ 10 | def get_time_dif(start_time): 11 | end_time = time.time() 12 | time_dif = end_time - start_time 13 | return time_dif 14 | 15 | """ 测试服务的响应时间 """ 16 | def test_service(content, port): 17 | 18 | url = 'http://0.0.0.0:{}/ExamServer'.format(port) 19 | app_data = {"request_id": "ExamServer", "query": content} 20 | 21 | """ 转化为json格式 """ 22 | app_data=json.dumps(app_data).encode("utf-8") 23 | 24 | start_time = time.time() 25 | req = urllib.request.Request(url, app_data) 26 | try: 27 | """ 调用服务,得到结果 """ 28 | response = urllib.request.urlopen(req) 29 | response = response.read().decode("utf-8") 30 | 31 | """ 从json格式中解析出来 """ 32 | response = json.loads(response) 33 | except Exception as e: 34 | print(e) 35 | response = None 36 | 37 | """ 打印耗时 """ 38 | time_usage = get_time_dif(start_time) 39 | print("Time usage: {}".format(time_usage)) 40 | print(response) 41 | return time_usage 42 | 43 | if __name__=='__main__': 44 | 45 | """ 测试1000次,得到平均响应时间 """ 46 | time_usage = 0 47 | for i in range(1000): 48 | content = "菠菜从土壤中吸收的氮元素可以用来合成()A.淀粉和纤维素B.葡萄糖和DNAC.核酸和蛋白质D.麦芽糖和脂肪酸" 49 | time_usage += test_service(content, 6060) 50 | print("Time usage average is {}".format( time_usage / 1000)) -------------------------------------------------------------------------------- /textcnn/textcnn_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class TextCNN(nn.Module): 7 | def __init__(self,config): 8 | super(TextCNN,self).__init__() 9 | self.embedding = nn.Embedding.from_pretrained(config.embed_matrix,freeze=False) 10 | self.convs = nn.ModuleList( 11 | [nn.Conv2d(1, config.num_filters, (k, config.embed_dim)) for k in config.filter_sizes]) 12 | self.dropout = nn.Dropout(config.dropout) 13 | self.fc = nn.Sequential(nn.Linear(config.num_filters * len(config.filter_sizes), config.dense_units),nn.ReLU()) 14 | self.linear = nn.Linear(config.dense_units,config.num_classes) 15 | 16 | def conv_and_pool(self, x, conv): 17 | x = F.relu(conv(x)).squeeze(3) 18 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 19 | 
        return x
20 | 
21 |     def forward(self,x):
22 |         out = self.embedding(x)
23 |         out = out.unsqueeze(1)
24 |         out = torch.cat([self.conv_and_pool(out,conv) for conv in self.convs], 1)
25 |         out = self.dropout(out)
26 |         out = self.fc(out)
27 |         return self.linear(out)
--------------------------------------------------------------------------------
/textcnn/textcnn_predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from config import Config
4 | import pickle, re
5 | import jieba
6 | 
7 | """ 加载停用词 """
8 | def load_stop_words(stop_word_path):
9 |     with open(stop_word_path, 'r', encoding='utf-8') as f1:
10 |         stop_words = f1.readlines()
11 |         stop_words = [stop_word.strip() for stop_word in stop_words if stop_word.strip()]
12 |     return stop_words
13 | 
14 | config = Config()
15 | stop_words = load_stop_words(config.stopwords_path)
16 | 
17 | """ 清洗文本 """
18 | def clean_sentence(line):
19 |     line = re.sub(
20 |         "[a-zA-Z0-9]|[\s+\-\|\!\/\[\]\{\}_,.$%^*(+\"\')]+|[::+——()?【】《》“”!,。?、~@#¥%……&*()]+|题目", '',line)
21 |     words = jieba.lcut(line, cut_all=False)
22 |     return words
23 | 
24 | """ 进行分词 """
25 | def sentence_proc(sentence):
26 |     words = clean_sentence(sentence)
27 |     words = [word for word in words if word not in stop_words]
28 |     return ' '.join(words)
29 | 
30 | def load_pickle(file_path):
31 |     """
32 |     用于加载 python的pickle对象
33 |     """
34 |     return pickle.load(open(file_path,'rb'))
35 | 
36 | class Vocab:
37 |     def __init__(self,vocab_path):
38 |         """
39 |         加载词表,用于把词转化为ID
40 |         """
41 |         self.word2id = load_pickle(vocab_path)
42 | 
43 |     """ 如果词不在词表中,就转化为 '<unk>' """
44 |     def w2i(self, word):
45 |         if word not in self.word2id:
46 |             return self.word2id["<unk>"]
47 |         return self.word2id[word]
48 | 
49 | class TextcnnPredict:
50 |     def __init__(self, device="gpu"):
51 |         """
52 |         模型预测,可以选择使用gpu还是cpu
53 |         : param: self.mlb 用于将预测结果转化为多标签
54 |         """
55 |         self.config = Config()
56 |         self.device = device
57 |         self.vocab = Vocab(self.config.vocab_path)
58 |         self.mlb = load_pickle(self.config.mlb_path)
59 |         self.model = self.load_model()
60 |         self.model.eval()
61 | 
62 |     """ 加载为gpu还是cpu版的模型 """
63 |     def load_model(self):
64 |         if self.device == "cpu":
65 |             model = torch.load(self.config.save_path, map_location="cpu")
66 |         else:
67 |             model = torch.load(self.config.save_path)
68 |         return model
69 | 
70 |     """ 把试题进行分词,并转化为id,进行pad """
71 |     def text_to_ids(self,sentence):
72 |         words = sentence_proc(sentence).split()
73 |         words = words[: self.config.max_len]
74 |         ids = [self.vocab.w2i(w) for w in words]
75 | 
76 |         """ 按最大长度进行pad,<pad> 的id与训练时 pad_sequences 的填充值一致 """
77 |         ids += [self.vocab.w2i("<pad>")] * (self.config.max_len - len(words))
78 |         if self.device == "cpu":
79 |             ids = torch.LongTensor([ids])
80 |         else:
81 |             ids = torch.LongTensor([ids]).to(self.config.device)
82 |         return ids
83 | 
84 |     def predict(self,text):
85 | 
86 |         with torch.no_grad():
87 | 
88 |             ids = self.text_to_ids(text)
89 |             outputs = self.model(ids)
90 | 
91 |             """ 用sigmoid函数转化为概率分布 """
92 |             outputs = torch.sigmoid(outputs)
93 |             outputs = outputs.data.cpu().numpy()
94 | 
95 |             """ 将概率转化为数值化标签,再转化为多分类的标签 """
96 |             predicts = np.where(outputs > 0.5, 1, 0)
97 |             labels = self.mlb.inverse_transform(predicts)
98 | 
99 |         return labels[0]
100 | 
101 | if __name__ == "__main__":
102 | 
103 |     text = "菠菜从土壤中吸收的氮元素可以用来合成()A.淀粉和纤维素B.葡萄糖和DNAC.核酸和蛋白质D.麦芽糖和脂肪酸"
104 |     real_label = "高中 生物 分子与细胞 组成细胞的化学元素 组成细胞的化合物"
105 | 
106 |     model = TextcnnPredict(device="cpu")
107 |     predict_label = model.predict(text)
108 |     print("\nPredicted label is %s \n" % " ".join(predict_label))
--------------------------------------------------------------------------------
/textcnn/textcnn_service.py:
--------------------------------------------------------------------------------
1 | from sanic import Sanic
2 | from sanic import response
3 | import json
4 | import time
5 | from sanic.exceptions import NotFound
6 | from service_helper import TextcnnServer
7 | 
8 | """ 定义 ip和端口号 """
9 | app = Sanic(__name__)
10 | ip, port = "0.0.0.0", 6060
11 | server = TextcnnServer(device="gpu")   # 服务启动时只加载一次模型,避免每个请求都重新加载
12 | 
13 | """ 路由 (ExamServer) 错误时,返回错误信息 """
14 | @app.exception(NotFound)
15 | async def url_404(request, excep):
16 |     return response.json({"Error": str(excep)})
17 | 
18 | """ 定义路由(ExamServer)和请求方式(POST) """
19 | @app.route('/ExamServer',methods=['POST'])
20 | async def model_server(request):
21 |     try:
22 |         request_json = request.body
23 |         input_json = json.loads(request_json.decode('utf8'))
24 |         result = server.get_result(input_json)
25 |     except Exception as e:
26 |         result = {"code": "400", "message": "预测失败", "Error": str(e)}
27 |     return response.json(result)
28 | 
29 | 
30 | if __name__ == '__main__':
31 |     app.run(host=ip,port=port)
--------------------------------------------------------------------------------
/textcnn/textcnn_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from textcnn_model import TextCNN
4 | from data_loader import build_dataset, calcu_class_weights, load_embed_matrix, batch_generator, config
5 | from textcnn_train_helper import init_network, train_model
6 | import numpy as np
7 | import os
8 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
9 | 
10 | 
11 | """ 统计模型的参数 """
12 | def count_params(model):
13 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
14 | 
15 | """ 设置随机数种子 """
16 | def set_manual_seed(seed):
17 |     np.random.seed(seed)
18 |     torch.manual_seed(seed)
19 |     torch.cuda.manual_seed(seed)
20 |     torch.backends.cudnn.deterministic = True
21 | 
22 | """ 模型训练 """
23 | def train(config):
24 | 
25 |     set_manual_seed(10)
26 | 
27 |     """ 1: 文本清洗和分词,构建词表 """
28 |     print("Preparing the batch data ... \n")
29 |     corpus_x, corpus_y, vocab = build_dataset(config)
30 | 
31 |     """ 2:计算类别权重,缓解类别不平衡问题 """
32 |     class_weights = calcu_class_weights(corpus_y, config)
33 |     config.class_weights = class_weights
34 | 
35 |     """ 3:加载预训练的词向量 """
36 |     embed_matrix = load_embed_matrix(vocab, config)
37 |     config.embed_matrix = embed_matrix
38 | 
39 |     """ 4: 划分数据集和生成batch迭代器 """
40 |     train_iter, valid_iter, test_iter = batch_generator(corpus_x,corpus_y,0.15,config)
41 | 
42 |     """ 5:模型初始化 """
43 |     print("Building the textcnn model ... \n")
44 |     model = TextCNN(config)
45 |     print(f'The model has {count_params(model):,} trainable parameters\n')
46 | 
47 |     model.to(config.device)
48 | 
49 |     """ 6:开始训练模型 """
50 |     print("Start the training ... 
\n") 51 | init_network(model) 52 | train_model(config, model, train_iter, valid_iter, test_iter) 53 | 54 | if __name__ == "__main__": 55 | train(config) -------------------------------------------------------------------------------- /textcnn/textcnn_train_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import time,os 5 | from sklearn.metrics import f1_score,accuracy_score 6 | from datetime import timedelta 7 | 8 | """ 记录训练时间 """ 9 | def get_time_dif(start_time): 10 | end_time = time.time() 11 | time_dif = end_time - start_time 12 | return timedelta(seconds=int(round(time_dif))) 13 | 14 | """ 网络的参数进行xavier初始化,embedding 层不需要 """ 15 | def init_network(model, method="xavier", exclude="embedding"): 16 | for name, w in model.named_parameters(): 17 | if exclude in name: 18 | continue 19 | if "weight" in name: 20 | nn.init.xavier_normal_(w) 21 | elif "bias" in name: 22 | nn.init.constant_(w, 0) 23 | 24 | def train_model(config, model, train_iter, valid_iter, test_iter): 25 | start_time = time.time() 26 | model.train() 27 | 28 | """ 定义损失函数,并传入类别权重 """ 29 | # criterion = nn.BCEWithLogitsLoss(pos_weight=config.class_weights) 30 | criterion = nn.BCEWithLogitsLoss(pos_weight=config.class_weights) 31 | 32 | """ 定义优化器,进行梯度裁剪,和学习率衰减 """ 33 | optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) 34 | # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.max_grad_norm) 35 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=config.gamma) 36 | 37 | """ 用early stop 防止过拟合 """ 38 | total_batch = 0 39 | valid_best_f1 = float('-inf') 40 | last_improve = 0 41 | flag = False 42 | save_path = os.path.join(config.save_dir,"multi_cls.h5") 43 | 44 | for epoch in range(config.num_epochs): 45 | scheduler.step() 46 | print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs)) 47 | 48 | """ 梯度清零,计算loss,反向传播 """ 49 | for i, (trains, labels) in enumerate(train_iter): 50 | outputs = model(trains) 51 | model.zero_grad() 52 | 53 | """ 训练时,计算损失带权重 """ 54 | loss = criterion(outputs, labels) 55 | loss.backward() 56 | 57 | optimizer.step() 58 | 59 | """ 每训练10个batch就验证一次,如果有提升,就保存并测试一次 """ 60 | if total_batch % 10 == 0: 61 | 62 | valid_f1_macro, valid_f1_micro, valid_loss = evaluate(config, model, valid_iter, test=False) 63 | 64 | """ 以f1微平均作为early stop的监控指标 """ 65 | if valid_f1_micro > valid_best_f1: 66 | evaluate(config, model, test_iter, test=True) 67 | valid_best_f1 = valid_f1_micro 68 | torch.save(model, save_path) 69 | improve = '*' 70 | last_improve = total_batch 71 | else: 72 | improve = '' 73 | 74 | time_dif = get_time_dif(start_time) 75 | msg = 'Iter: {} | Train Loss: {:.4f} | Val Loss: {:.4f} | Val F1-macro: {:.4f} | Val F1-micro: {:.4f} | Time: {} | {}' 76 | print(msg.format(total_batch, loss.item(), valid_loss, valid_f1_macro, valid_f1_micro, time_dif, improve)) 77 | 78 | model.train() 79 | 80 | total_batch += 1 81 | if total_batch - last_improve > config.require_improve: 82 | """ 验证集loss超过500batch没下降,结束训练 """ 83 | print("No optimization for a long time, auto-stopping...") 84 | flag = True 85 | break 86 | if flag: 87 | break 88 | 89 | """ 评估函数 """ 90 | def evaluate(config, model, data_iter, test=False): 91 | 92 | """ 验证时切换到 evaluate 模式,验证结束再切换为 train 模式 """ 93 | model.eval() 94 | 95 | criterion = nn.BCEWithLogitsLoss() 96 | 97 | loss_total = 0 98 | labels_all = [] 99 | predicts_all = [] 100 | 101 | """ 验证和预测时都不需要计算梯度 """ 102 | with torch.no_grad(): 103 | 
104 | for texts, labels in data_iter: 105 | outputs = model(texts) 106 | 107 | """ 测试和验证时,计算损失不用带权重 """ 108 | loss = criterion(outputs, labels) 109 | loss_total += loss 110 | 111 | """ 把Tensor数格式转化为numpy格式 """ 112 | labels = labels.data.cpu().numpy() 113 | outputs = outputs.data.cpu().numpy() 114 | 115 | """ 转化为多分类的标签 """ 116 | predicts = np.where(outputs > 0.5, 1, 0) 117 | 118 | labels_all += labels.tolist() 119 | predicts_all += predicts.tolist() 120 | 121 | labels_all = np.array(labels_all,dtype=int) 122 | predicts_all = np.array(predicts_all,dtype=int) 123 | 124 | """ 计算f1值(宏平均和微平均) """ 125 | f1_score_macro = f1_score(labels_all, predicts_all,average='macro') 126 | f1_score_micro = f1_score(labels_all, predicts_all,average='micro') 127 | accuracy = accuracy_score(labels_all, predicts_all) 128 | 129 | if test: 130 | print("1: Accuracy of model is {:.4f}\n".format(accuracy)) 131 | print("2: F1-macro of model is {:.4f}\n".format(f1_score_macro)) 132 | print("3: F1-micro of model is {:.4f}\n".format(f1_score_micro)) 133 | 134 | return f1_score_macro, f1_score_micro, loss_total / len(data_iter) --------------------------------------------------------------------------------
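A minimal end-to-end usage sketch (illustrative only, not a file in the repository): it reuses the entry points defined above — TextcnnPredict from textcnn_predict.py and the /ExamServer route from textcnn_service.py — and assumes that training (python textcnn/textcnn_train.py) has already produced data/textcnn_results/multi_cls.h5, data/mlb.pkl and data/vocab.pkl. Training in turn expects data/题库/baidu_95.csv to be a headerless two-column CSV with space-separated labels in column 0 and the question text in column 1, as read by build_dataset in data_loader.py.

# quick_start.py -- hypothetical helper, run from the textcnn/ directory so that
# config.py and textcnn_predict.py are importable.
import json
import urllib.request

from textcnn_predict import TextcnnPredict

question = "菠菜从土壤中吸收的氮元素可以用来合成()A.淀粉和纤维素B.葡萄糖和DNAC.核酸和蛋白质D.麦芽糖和脂肪酸"

# 1) Offline prediction: returns a tuple of label strings, e.g. ("高中", "生物", ...).
predictor = TextcnnPredict(device="cpu")
print(predictor.predict(question))

# 2) Calling the running service (start it first with: python textcnn_service.py);
#    the request format matches service_test.py.
payload = json.dumps({"request_id": "demo", "query": question}).encode("utf-8")
req = urllib.request.Request("http://0.0.0.0:6060/ExamServer", payload)
print(json.loads(urllib.request.urlopen(req).read().decode("utf-8")))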