├── .gitignore ├── README.md ├── data ├── multinomial_naive_bayes_model └── stopwords.txt ├── library ├── ambiguity.dic └── default.dic ├── pom.xml └── src ├── main ├── java │ └── com │ │ └── fullstackyang │ │ └── nlp │ │ └── classifier │ │ ├── feature │ │ ├── ChiSquaredStrategy.java │ │ ├── Feature.java │ │ ├── FeatureSelection.java │ │ └── IGStrategy.java │ │ ├── model │ │ ├── Category.java │ │ ├── Doc.java │ │ ├── Term.java │ │ └── TrainSet.java │ │ ├── naivebayes │ │ ├── NaiveBayesClassifier.java │ │ ├── NaiveBayesKnowledgeBase.java │ │ ├── NaiveBayesLearner.java │ │ └── NaiveBayesModels.java │ │ └── utils │ │ ├── Calculator.java │ │ ├── FileUtils.java │ │ └── nlp │ │ ├── AnsjSegmentor.java │ │ ├── HanLPSegmentor.java │ │ ├── MyStopWords.java │ │ ├── NLPTools.java │ │ └── TermFilter.java └── resources │ └── logback.xml └── test └── java └── com └── fullstackyang └── nlp └── classifier └── model └── TestClassifier.java /.gitignore: -------------------------------------------------------------------------------- 1 | *.war 2 | *.ear 3 | 4 | .idea 5 | *.iml 6 | out 7 | gen 8 | 9 | target/ 10 | hs_err_pid* 11 | 12 | logs/ 13 | 14 | #测试集 15 | test-trainset/ 16 | #训练集 17 | trainset/ 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # article-classifier 2 | 基于朴素贝叶斯实现的一款微信公众号文章分类器 3 | ## 运行环境 4 | - 本项目需要在Java 8环境下运行 5 | - 项目根目录下data/是用来存放训练好的模型文件,以及停用词表,当然你可以根据需要修改这些路径,模型文件路径在naivebayes.NaiveBayesModels中修改,停用词表路径在utils.MyStopWord中修改 6 | - 关于分词器 7 | - 系统中使用了两个分词器,Ansj(默认)[https://github.com/NLPchina/ansj_seg] 和HanLp [https://github.com/hankcs/HanLP], 这里表示感谢。 8 | - 根目录下的library/是Ansj所需要的文件 9 | - 如有其他分词,则可以实现NLPTools中的接口,并在实现时转化为Term对象,同时确认已经实现了统计词频的功能,另外词性是用来过滤噪声词的,若分词器未提供词性标注功能,默认可以全部标注为n,然后使用其他过滤方法。 10 | 11 | ## 如何训练 12 | 训练集的目录结构如下: 13 | ``` 14 | trainset/ 15 | |--养生/ 16 | |--文档1 17 | |--文档2 18 | |--历史/ 19 | |--文档3 20 | |--文档4 21 | |--文档4 22 | ... 23 | |--游戏 24 | ... 25 | ``` 26 | 每个子目录的名称将被取出,定义为类别的名称 27 | 28 | 训练模型时调用naivebayes.NaiveBayesLearner的主函数即可,传入训练集的路径,同时可以设定特征选择的方法(ChiSquaredStrategy或IGStrategy),以及朴素贝叶斯的模型(Bernoulli和Multinomial) 29 | ``` 30 | public static void main(String[] args) { 31 | TrainSet trainSet = new TrainSet(System.getProperty("user.dir") + "/trainset/"); 32 | 33 | log.info("特征选择开始..."); 34 | FeatureSelection featureSelection = new FeatureSelection(new ChiSquaredStrategy(trainSet.getCategorySet(), trainSet.getTotalDoc())); 35 | List features = featureSelection.select(trainSet.getDocs()); 36 | log.info("特征选择完成,特征数:[" + features.size() + "]"); 37 | 38 | NaiveBayesModels model = NaiveBayesModels.Multinomial; 39 | NaiveBayesLearner learner = new NaiveBayesLearner(model, trainSet, Sets.newHashSet(features)); 40 | learner.statistics().build().write(model.getModelPath()); 41 | log.info("模型文件写入完成,路径:" + model.getModelPath()); 42 | } 43 | ``` 44 | 训练时,根据需要调整JVM参数-Xmx,笔者的训练集大于3万篇文档,设置-Xmx2000m,训练结束时,模型文件生成到data/目录下,目前提交了一个已经训练好的模型文件,可以直接使用。 45 | 46 | ## 如何预测分类 47 | 初始化分类器,调用predict方法 48 | ``` 49 | @Test 50 | public void test() { 51 | NaiveBayesClassifier classifier = new NaiveBayesClassifier(NaiveBayesModels.Multinomial); 52 | String text = "明日赛事推荐:切尔西巴萨冤家路窄,恒大申花再战亚冠"; 53 | String category = classifier.predict(text); 54 | System.out.println(category); 55 | } 56 | ``` -------------------------------------------------------------------------------- /data/stopwords.txt: -------------------------------------------------------------------------------- 1 | $ 2 | 0 3 | 1 4 | 2 5 | 3 6 | 4 7 | 5 8 | 6 9 | 7 10 | 8 11 | 9 12 | ? 13 | _ 14 | “ 15 | ” 16 | 、 17 | 。 18 | 《 19 | 》 20 | 一 21 | 一些 22 | 一何 23 | 一切 24 | 一则 25 | 一方面 26 | 一旦 27 | 一来 28 | 一样 29 | 一般 30 | 一转眼 31 | 万一 32 | 上 33 | 上下 34 | 下 35 | 不 36 | 不仅 37 | 不但 38 | 不光 39 | 不单 40 | 不只 41 | 不外乎 42 | 不如 43 | 不妨 44 | 不尽 45 | 不尽然 46 | 不得 47 | 不怕 48 | 不惟 49 | 不成 50 | 不拘 51 | 不料 52 | 不是 53 | 不比 54 | 不然 55 | 不特 56 | 不独 57 | 不管 58 | 不至于 59 | 不若 60 | 不论 61 | 不过 62 | 不问 63 | 与 64 | 与其 65 | 与其说 66 | 与否 67 | 与此同时 68 | 且 69 | 且不说 70 | 且说 71 | 两者 72 | 个 73 | 个别 74 | 临 75 | 为 76 | 为了 77 | 为什么 78 | 为何 79 | 为止 80 | 为此 81 | 为着 82 | 乃 83 | 乃至 84 | 乃至于 85 | 么 86 | 之 87 | 之一 88 | 之所以 89 | 之类 90 | 乌乎 91 | 乎 92 | 乘 93 | 也 94 | 也好 95 | 也罢 96 | 了 97 | 二来 98 | 于 99 | 于是 100 | 于是乎 101 | 云云 102 | 云尔 103 | 些 104 | 亦 105 | 人 106 | 人们 107 | 人家 108 | 什么 109 | 什么样 110 | 今 111 | 介于 112 | 仍 113 | 仍旧 114 | 从 115 | 从此 116 | 从而 117 | 他 118 | 他人 119 | 他们 120 | 以 121 | 以上 122 | 以为 123 | 以便 124 | 以免 125 | 以及 126 | 以故 127 | 以期 128 | 以来 129 | 以至 130 | 以至于 131 | 以致 132 | 们 133 | 任 134 | 任何 135 | 任凭 136 | 似的 137 | 但 138 | 但凡 139 | 但是 140 | 何 141 | 何以 142 | 何况 143 | 何处 144 | 何时 145 | 余外 146 | 作为 147 | 你 148 | 你们 149 | 使 150 | 使得 151 | 例如 152 | 依 153 | 依据 154 | 依照 155 | 便于 156 | 俺 157 | 俺们 158 | 倘 159 | 倘使 160 | 倘或 161 | 倘然 162 | 倘若 163 | 借 164 | 假使 165 | 假如 166 | 假若 167 | 傥然 168 | 像 169 | 儿 170 | 先不先 171 | 光是 172 | 全体 173 | 全部 174 | 兮 175 | 关于 176 | 其 177 | 其一 178 | 其中 179 | 其二 180 | 其他 181 | 其余 182 | 其它 183 | 其次 184 | 具体地说 185 | 具体说来 186 | 兼之 187 | 内 188 | 再 189 | 再其次 190 | 再则 191 | 再有 192 | 再者 193 | 再者说 194 | 再说 195 | 冒 196 | 冲 197 | 况且 198 | 几 199 | 几时 200 | 凡 201 | 凡是 202 | 凭 203 | 凭借 204 | 出于 205 | 出来 206 | 分别 207 | 则 208 | 则甚 209 | 别 210 | 别人 211 | 别处 212 | 别是 213 | 别的 214 | 别管 215 | 别说 216 | 到 217 | 前后 218 | 前此 219 | 前者 220 | 加之 221 | 加以 222 | 即 223 | 即令 224 | 即使 225 | 即便 226 | 即如 227 | 即或 228 | 即若 229 | 却 230 | 去 231 | 又 232 | 又及 233 | 及 234 | 及其 235 | 及至 236 | 反之 237 | 反而 238 | 反过来 239 | 反过来说 240 | 受到 241 | 另 242 | 另一方面 243 | 另外 244 | 另悉 245 | 只 246 | 只当 247 | 只怕 248 | 只是 249 | 只有 250 | 只消 251 | 只要 252 | 只限 253 | 叫 254 | 叮咚 255 | 可 256 | 可以 257 | 可是 258 | 可见 259 | 各 260 | 各个 261 | 各位 262 | 各种 263 | 各自 264 | 同 265 | 同时 266 | 后 267 | 后者 268 | 向 269 | 向使 270 | 向着 271 | 吓 272 | 吗 273 | 否则 274 | 吧 275 | 吧哒 276 | 吱 277 | 呀 278 | 呃 279 | 呕 280 | 呗 281 | 呜 282 | 呜呼 283 | 呢 284 | 呵 285 | 呵呵 286 | 呸 287 | 呼哧 288 | 咋 289 | 和 290 | 咚 291 | 咦 292 | 咧 293 | 咱 294 | 咱们 295 | 咳 296 | 哇 297 | 哈 298 | 哈哈 299 | 哉 300 | 哎 301 | 哎呀 302 | 哎哟 303 | 哗 304 | 哟 305 | 哦 306 | 哩 307 | 哪 308 | 哪个 309 | 哪些 310 | 哪儿 311 | 哪天 312 | 哪年 313 | 哪怕 314 | 哪样 315 | 哪边 316 | 哪里 317 | 哼 318 | 哼唷 319 | 唉 320 | 唯有 321 | 啊 322 | 啐 323 | 啥 324 | 啦 325 | 啪达 326 | 啷当 327 | 喂 328 | 喏 329 | 喔唷 330 | 喽 331 | 嗡 332 | 嗡嗡 333 | 嗬 334 | 嗯 335 | 嗳 336 | 嘎 337 | 嘎登 338 | 嘘 339 | 嘛 340 | 嘻 341 | 嘿 342 | 嘿嘿 343 | 因 344 | 因为 345 | 因了 346 | 因此 347 | 因着 348 | 因而 349 | 固然 350 | 在 351 | 在下 352 | 在于 353 | 地 354 | 基于 355 | 处在 356 | 多 357 | 多么 358 | 多少 359 | 大 360 | 大家 361 | 她 362 | 她们 363 | 好 364 | 如 365 | 如上 366 | 如上所述 367 | 如下 368 | 如何 369 | 如其 370 | 如同 371 | 如是 372 | 如果 373 | 如此 374 | 如若 375 | 始而 376 | 孰料 377 | 孰知 378 | 宁 379 | 宁可 380 | 宁愿 381 | 宁肯 382 | 它 383 | 它们 384 | 对 385 | 对于 386 | 对待 387 | 对方 388 | 对比 389 | 将 390 | 小 391 | 尔 392 | 尔后 393 | 尔尔 394 | 尚且 395 | 就 396 | 就是 397 | 就是了 398 | 就是说 399 | 就算 400 | 就要 401 | 尽 402 | 尽管 403 | 尽管如此 404 | 岂但 405 | 己 406 | 已 407 | 已矣 408 | 巴 409 | 巴巴 410 | 并 411 | 并且 412 | 并非 413 | 庶乎 414 | 庶几 415 | 开外 416 | 开始 417 | 归 418 | 归齐 419 | 当 420 | 当地 421 | 当然 422 | 当着 423 | 彼 424 | 彼时 425 | 彼此 426 | 往 427 | 待 428 | 很 429 | 得 430 | 得了 431 | 怎 432 | 怎么 433 | 怎么办 434 | 怎么样 435 | 怎奈 436 | 怎样 437 | 总之 438 | 总的来看 439 | 总的来说 440 | 总的说来 441 | 总而言之 442 | 恰恰相反 443 | 您 444 | 惟其 445 | 慢说 446 | 我 447 | 我们 448 | 或 449 | 或则 450 | 或是 451 | 或曰 452 | 或者 453 | 截至 454 | 所 455 | 所以 456 | 所在 457 | 所幸 458 | 所有 459 | 才 460 | 才能 461 | 打 462 | 打从 463 | 把 464 | 抑或 465 | 拿 466 | 按 467 | 按照 468 | 换句话说 469 | 换言之 470 | 据 471 | 据此 472 | 接着 473 | 故 474 | 故此 475 | 故而 476 | 旁人 477 | 无 478 | 无宁 479 | 无论 480 | 既 481 | 既往 482 | 既是 483 | 既然 484 | 时候 485 | 是 486 | 是以 487 | 是的 488 | 曾 489 | 替 490 | 替代 491 | 最 492 | 有 493 | 有些 494 | 有关 495 | 有及 496 | 有时 497 | 有的 498 | 望 499 | 朝 500 | 朝着 501 | 本 502 | 本人 503 | 本地 504 | 本着 505 | 本身 506 | 来 507 | 来着 508 | 来自 509 | 来说 510 | 极了 511 | 果然 512 | 果真 513 | 某 514 | 某个 515 | 某些 516 | 某某 517 | 根据 518 | 欤 519 | 正值 520 | 正如 521 | 正巧 522 | 正是 523 | 此 524 | 此地 525 | 此处 526 | 此外 527 | 此时 528 | 此次 529 | 此间 530 | 毋宁 531 | 每 532 | 每当 533 | 比 534 | 比及 535 | 比如 536 | 比方 537 | 没奈何 538 | 沿 539 | 沿着 540 | 漫说 541 | 焉 542 | 然则 543 | 然后 544 | 然而 545 | 照 546 | 照着 547 | 犹且 548 | 犹自 549 | 甚且 550 | 甚么 551 | 甚或 552 | 甚而 553 | 甚至 554 | 甚至于 555 | 用 556 | 用来 557 | 由 558 | 由于 559 | 由是 560 | 由此 561 | 由此可见 562 | 的 563 | 的确 564 | 的话 565 | 直到 566 | 相对而言 567 | 省得 568 | 看 569 | 眨眼 570 | 着 571 | 着呢 572 | 矣 573 | 矣乎 574 | 矣哉 575 | 离 576 | 竟而 577 | 第 578 | 等 579 | 等到 580 | 等等 581 | 简言之 582 | 管 583 | 类如 584 | 紧接着 585 | 纵 586 | 纵令 587 | 纵使 588 | 纵然 589 | 经 590 | 经过 591 | 结果 592 | 给 593 | 继之 594 | 继后 595 | 继而 596 | 综上所述 597 | 罢了 598 | 者 599 | 而 600 | 而且 601 | 而况 602 | 而后 603 | 而外 604 | 而已 605 | 而是 606 | 而言 607 | 能 608 | 能否 609 | 腾 610 | 自 611 | 自个儿 612 | 自从 613 | 自各儿 614 | 自后 615 | 自家 616 | 自己 617 | 自打 618 | 自身 619 | 至 620 | 至于 621 | 至今 622 | 至若 623 | 致 624 | 般的 625 | 若 626 | 若夫 627 | 若是 628 | 若果 629 | 若非 630 | 莫不然 631 | 莫如 632 | 莫若 633 | 虽 634 | 虽则 635 | 虽然 636 | 虽说 637 | 被 638 | 要 639 | 要不 640 | 要不是 641 | 要不然 642 | 要么 643 | 要是 644 | 譬喻 645 | 譬如 646 | 让 647 | 许多 648 | 论 649 | 设使 650 | 设或 651 | 设若 652 | 诚如 653 | 诚然 654 | 该 655 | 说来 656 | 诸 657 | 诸位 658 | 诸如 659 | 谁 660 | 谁人 661 | 谁料 662 | 谁知 663 | 贼死 664 | 赖以 665 | 赶 666 | 起 667 | 起见 668 | 趁 669 | 趁着 670 | 越是 671 | 距 672 | 跟 673 | 较 674 | 较之 675 | 边 676 | 过 677 | 还 678 | 还是 679 | 还有 680 | 还要 681 | 这 682 | 这一来 683 | 这个 684 | 这么 685 | 这么些 686 | 这么样 687 | 这么点儿 688 | 这些 689 | 这会儿 690 | 这儿 691 | 这就是说 692 | 这时 693 | 这样 694 | 这次 695 | 这般 696 | 这边 697 | 这里 698 | 进而 699 | 连 700 | 连同 701 | 逐步 702 | 通过 703 | 遵循 704 | 遵照 705 | 那 706 | 那个 707 | 那么 708 | 那么些 709 | 那么样 710 | 那些 711 | 那会儿 712 | 那儿 713 | 那时 714 | 那样 715 | 那般 716 | 那边 717 | 那里 718 | 都 719 | 鄙人 720 | 鉴于 721 | 针对 722 | 阿 723 | 除 724 | 除了 725 | 除外 726 | 除开 727 | 除此之外 728 | 除非 729 | 随 730 | 随后 731 | 随时 732 | 随着 733 | 难道说 734 | 非但 735 | 非徒 736 | 非特 737 | 非独 738 | 靠 739 | 顺 740 | 顺着 741 | 首先 742 | 干嘛 743 | ! 744 | , 745 | : 746 | ; 747 | ? -------------------------------------------------------------------------------- /library/ambiguity.dic: -------------------------------------------------------------------------------- 1 | 习近平 nr 2 | 李民 nr 工作 vn 3 | 三个 m 和尚 n 4 | 的确 d 定 v 不 v 5 | 大 a 和尚 n 6 | 张三 nr 和 c 7 | 动漫 n 游戏 n 8 | 邓颖超 nr 生前 t -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | com.fullstackyang.nlp 8 | article-classifier 9 | 1.1-SNAPSHOT 10 | 11 | 12 | 1.8 13 | UTF-8 14 | 21.0 15 | 1.16.18 16 | 3.8.0 17 | 18 | 19 | 20 | 21 | org.slf4j 22 | slf4j-api 23 | 1.7.25 24 | 25 | 26 | ch.qos.logback 27 | logback-core 28 | 1.1.11 29 | 30 | 31 | ch.qos.logback 32 | logback-classic 33 | 1.1.11 34 | 35 | 36 | com.google.guava 37 | guava 38 | ${google.guava.version} 39 | 40 | 41 | org.projectlombok 42 | lombok 43 | ${lombok.version} 44 | 45 | 46 | com.hankcs 47 | hanlp 48 | portable-1.5.3 49 | 50 | 51 | org.ansj 52 | ansj_seg 53 | 5.1.5 54 | 55 | 56 | junit 57 | junit 58 | 4.12 59 | test 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | . 68 | 69 | data/** 70 | library/** 71 | 72 | 73 | 74 | 75 | 76 | org.apache.maven.plugins 77 | maven-compiler-plugin 78 | 3.5.1 79 | 80 | ${maven.compiler.target} 81 | ${maven.compiler.target} 82 | ${maven.compiler.encoding} 83 | 84 | 85 | 86 | org.apache.maven.plugins 87 | maven-surefire-plugin 88 | 2.19 89 | 90 | 91 | org.apache.maven.surefire 92 | surefire-junit47 93 | 2.19 94 | 95 | 96 | 97 | true 98 | 99 | 100 | 101 | org.apache.maven.plugins 102 | maven-jar-plugin 103 | 2.4 104 | 105 | 106 | **/*.properties 107 | 108 | 109 | 110 | 111 | org.apache.maven.plugins 112 | maven-resources-plugin 113 | 2.7 114 | 115 | ${maven.compiler.encoding} 116 | 117 | 118 | 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/feature/ChiSquaredStrategy.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.feature; 2 | 3 | import com.fullstackyang.nlp.classifier.feature.FeatureSelection.Strategy; 4 | import com.fullstackyang.nlp.classifier.model.Category; 5 | import com.fullstackyang.nlp.classifier.utils.Calculator; 6 | import lombok.AllArgsConstructor; 7 | 8 | import java.util.Collection; 9 | import java.util.Comparator; 10 | 11 | @AllArgsConstructor 12 | public class ChiSquaredStrategy implements Strategy { 13 | 14 | private final Collection categories; 15 | 16 | private final int total; 17 | 18 | @Override 19 | public Feature estimate(Feature feature) { 20 | 21 | class ContingencyTable { 22 | private final int A, B, C, D; 23 | 24 | private ContingencyTable(Feature feature, Category category) { 25 | A = feature.getDocCountByCategory(category); 26 | B = feature.getFeatureCount() - A; 27 | C = category.getDocCount() - A; 28 | D = total - A - B - C; 29 | } 30 | } 31 | 32 | Double chisquared = categories.stream() 33 | .map(c -> new ContingencyTable(feature, c)) 34 | .map(ct -> Calculator.chisquare(ct.A, ct.B, ct.C, ct.D)) 35 | .max(Comparator.comparingDouble(Double::valueOf)).get(); 36 | feature.setScore(chisquared); 37 | return feature; 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/feature/Feature.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.feature; 2 | 3 | import com.fullstackyang.nlp.classifier.model.Category; 4 | import com.fullstackyang.nlp.classifier.model.Term; 5 | import com.google.common.collect.Maps; 6 | import lombok.*; 7 | import lombok.extern.slf4j.Slf4j; 8 | 9 | import java.util.Map; 10 | import java.util.Objects; 11 | 12 | @Slf4j 13 | @Data 14 | @ToString(of = {"term", "score"}) 15 | @EqualsAndHashCode(of = {"term"}) 16 | public class Feature { 17 | 18 | private final Term term; 19 | 20 | private double score; 21 | 22 | @Setter(AccessLevel.NONE) 23 | private Map categoryDocCounter; 24 | 25 | @Setter(AccessLevel.NONE) 26 | private Map categoryTermCounter; 27 | 28 | public Feature(Term term) { 29 | this.term = term; 30 | } 31 | 32 | public Feature(Term term, Category category) { 33 | this.term = term; 34 | this.categoryDocCounter = Maps.newHashMap(); 35 | this.categoryDocCounter.put(category, 1); 36 | 37 | this.categoryTermCounter = Maps.newHashMap(); 38 | this.categoryTermCounter.put(category, term.getTf()); 39 | 40 | } 41 | 42 | public Feature merge(Feature feature) { 43 | if (this.term.equals(feature.getTerm())) { 44 | this.term.setTf(this.term.getTf() + feature.getTerm().getTf()); 45 | feature.getCategoryDocCounter() 46 | .forEach((k, v) -> categoryDocCounter.merge(k, v, (oldValue, newValue) -> oldValue + newValue)); 47 | feature.getCategoryTermCounter() 48 | .forEach((k, v) -> categoryTermCounter.merge(k, v, (oldValue, newValue) -> oldValue + newValue)); 49 | } 50 | return this; 51 | } 52 | 53 | 54 | /** 55 | * 所有包含Feature的文档的数量 56 | * @return 57 | */ 58 | int getFeatureCount() { 59 | return categoryDocCounter.values().stream().mapToInt(Integer::intValue).sum(); 60 | } 61 | 62 | public int getDocCountByCategory(Category category) { 63 | return categoryDocCounter.getOrDefault(category, 0); 64 | } 65 | 66 | public int getTermCountByCategory(Category category) { 67 | return categoryTermCounter.getOrDefault(category, 0); 68 | } 69 | 70 | 71 | } 72 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/feature/FeatureSelection.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.feature; 2 | 3 | import com.fullstackyang.nlp.classifier.model.*; 4 | import com.google.common.collect.Maps; 5 | import lombok.AllArgsConstructor; 6 | import lombok.extern.slf4j.Slf4j; 7 | 8 | import java.util.*; 9 | import java.util.function.Function; 10 | import java.util.stream.Stream; 11 | 12 | import static java.util.Comparator.comparing; 13 | import static java.util.stream.Collectors.*; 14 | 15 | @Slf4j 16 | @AllArgsConstructor 17 | public class FeatureSelection { 18 | 19 | interface Strategy { 20 | Feature estimate(Feature feature); 21 | } 22 | 23 | private final Strategy strategy; 24 | 25 | private final static int FEATURE_SIZE = 20000; 26 | 27 | public List select(List docs) { 28 | return createFeatureSpace(docs.stream()) 29 | .stream() 30 | .map(strategy::estimate) 31 | .filter(f -> f.getTerm().getWord().length() > 1) 32 | .sorted(comparing(Feature::getScore).reversed()) 33 | .limit(FEATURE_SIZE) 34 | .collect(toList()); 35 | 36 | } 37 | 38 | private Collection createFeatureSpace(Stream docs) { 39 | 40 | @AllArgsConstructor 41 | class FeatureCounter { 42 | 43 | private final Map featureMap; 44 | 45 | private FeatureCounter accumulate(Doc doc) { 46 | Map temp = doc.getTerms().parallelStream() 47 | .map(t -> new Feature(t, doc.getCategory())) 48 | .collect(toMap(Feature::getTerm, Function.identity())); 49 | 50 | if (!featureMap.isEmpty()) 51 | featureMap.values().forEach(f -> temp.merge(f.getTerm(), f, Feature::merge)); 52 | return new FeatureCounter(temp); 53 | } 54 | 55 | private FeatureCounter combine(FeatureCounter featureCounter) { 56 | Map temp = Maps.newHashMap(featureMap); 57 | featureCounter.featureMap.values().forEach(f -> temp.merge(f.getTerm(), f, Feature::merge)); 58 | return new FeatureCounter(temp); 59 | } 60 | } 61 | 62 | FeatureCounter counter = docs.parallel() 63 | .reduce(new FeatureCounter(Maps.newHashMap()), 64 | FeatureCounter::accumulate, 65 | FeatureCounter::combine); 66 | 67 | 68 | return counter.featureMap.values(); 69 | } 70 | 71 | } 72 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/feature/IGStrategy.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.feature; 2 | 3 | import com.fullstackyang.nlp.classifier.model.Category; 4 | import com.fullstackyang.nlp.classifier.utils.Calculator; 5 | import com.google.common.collect.Lists; 6 | import lombok.AllArgsConstructor; 7 | import lombok.extern.slf4j.Slf4j; 8 | 9 | import java.util.*; 10 | 11 | import static java.util.stream.Collectors.toList; 12 | 13 | @Slf4j 14 | @AllArgsConstructor 15 | public class IGStrategy implements FeatureSelection.Strategy { 16 | 17 | // 所有分类 18 | private final Collection categories; 19 | 20 | //总文档数 21 | private final int total; 22 | 23 | 24 | public Feature estimate(Feature feature) { 25 | double totalEntropy = calcTotalEntropy(); 26 | double conditionalEntrogy = calcConditionEntropy(feature); 27 | feature.setScore(totalEntropy - conditionalEntrogy); 28 | return feature; 29 | } 30 | 31 | private double calcTotalEntropy() { 32 | return Calculator.entropy(categories.stream().map(c -> (double) c.getDocCount() / total).collect(toList())); 33 | } 34 | 35 | private double calcConditionEntropy(Feature feature) { 36 | int featureCount = feature.getFeatureCount(); 37 | double Pfeature = (double) featureCount / total; 38 | 39 | Map> Pcondition = categories.parallelStream().collect(() -> new HashMap>() {{ 40 | put(true, Lists.newArrayList()); 41 | put(false, Lists.newArrayList()); 42 | }}, (map, category) -> { 43 | int countDocWithFeature = feature.getDocCountByCategory(category); 44 | //出现该特征词且属于类别key的文档数量/出现该特征词的文档总数量 45 | map.get(true).add((double) countDocWithFeature / featureCount); 46 | //未出现该特征词且属于类别key的文档数量/未出现该特征词的文档总数量 47 | map.get(false).add((double) (category.getDocCount() - countDocWithFeature) / (total - featureCount)); 48 | }, 49 | (map1, map2) -> { 50 | map1.get(true).addAll(map2.get(true)); 51 | map1.get(false).addAll(map2.get(false)); 52 | } 53 | ); 54 | return Calculator.conditionalEntrogy(Pfeature, Pcondition.get(true), Pcondition.get(false)); 55 | 56 | } 57 | } 58 | 59 | 60 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/model/Category.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.model; 2 | 3 | import com.fullstackyang.nlp.classifier.utils.Calculator; 4 | import lombok.*; 5 | 6 | import java.io.Serializable; 7 | import java.util.concurrent.atomic.AtomicInteger; 8 | 9 | @Data 10 | @AllArgsConstructor 11 | @EqualsAndHashCode(of = "name") 12 | @ToString(of = {"name", "docCount","termCount"}) 13 | public class Category { 14 | 15 | private final String name; 16 | 17 | private final String path; 18 | 19 | private final int docCount; 20 | 21 | private int termCount; 22 | 23 | public Category(String name, String path, int docCount) { 24 | this.name = name; 25 | this.path = path; 26 | this.docCount = docCount; 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/model/Doc.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.model; 2 | 3 | import lombok.*; 4 | 5 | import java.util.List; 6 | 7 | @Data 8 | @AllArgsConstructor 9 | @EqualsAndHashCode(of = "id") 10 | public class Doc { 11 | 12 | private String id; 13 | 14 | private final Category category; 15 | 16 | private final List terms; 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/model/Term.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.model; 2 | 3 | import com.google.common.collect.Maps; 4 | import lombok.*; 5 | 6 | import java.util.Map; 7 | import java.util.concurrent.atomic.AtomicInteger; 8 | 9 | @Data 10 | @EqualsAndHashCode(of = {"word"})//暂不考虑同一个词不同词性的情况 11 | public class Term { 12 | 13 | private final String word; 14 | 15 | private final String POS; 16 | 17 | private int tf; 18 | 19 | public Term(String word, String POS, int tf) { 20 | this.word = word.toLowerCase(); 21 | this.POS = POS; 22 | this.tf = tf; 23 | } 24 | 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/model/TrainSet.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.model; 2 | 3 | import com.fullstackyang.nlp.classifier.utils.FileUtils; 4 | import com.fullstackyang.nlp.classifier.utils.nlp.NLPTools; 5 | import com.fullstackyang.nlp.classifier.utils.nlp.TermFilter; 6 | import com.google.common.io.PatternFilenameFilter; 7 | import lombok.Getter; 8 | import lombok.extern.slf4j.Slf4j; 9 | 10 | import java.io.File; 11 | import java.util.*; 12 | 13 | import static java.util.stream.Collectors.*; 14 | 15 | @Slf4j 16 | public class TrainSet { 17 | 18 | @Getter 19 | private Set categorySet; 20 | 21 | 22 | @Getter 23 | private List docs; 24 | 25 | private static final NLPTools nlpTools = NLPTools.instance(); 26 | 27 | private static final PatternFilenameFilter filenameFilter = new PatternFilenameFilter("(\\w+)\\.txt$"); 28 | 29 | /** 30 | * 总文档数 31 | */ 32 | @Getter 33 | private int totalDoc; 34 | 35 | /** 36 | * 总词数 37 | */ 38 | @Getter 39 | private int totalTerm; 40 | 41 | public TrainSet(String path) { 42 | File root = new File(path); 43 | if (root.listFiles() == null) { 44 | log.error("未发现训练集"); 45 | return; 46 | } 47 | 48 | log.info("开始读取训练集..."); 49 | this.categorySet = createCategorySet(root); 50 | log.info("类别集合创建完成!"); 51 | this.docs = categorySet.parallelStream() 52 | .map(c -> createDocs(c, new File(c.getPath()).listFiles(filenameFilter))) 53 | .flatMap(Collection::stream).collect(toList()); 54 | 55 | log.info("所有训练语料读取完成!开始统计..."); 56 | this.totalDoc = categorySet.stream().mapToInt(Category::getDocCount).sum(); 57 | this.totalTerm = docs.parallelStream().map(Doc::getTerms).flatMap(List::stream).mapToInt(Term::getTf).sum(); 58 | log.info("统计完成, 总文档数:" + totalDoc + ", 总类别数:" + categorySet.size() + ", 总字词数:" + totalTerm); 59 | 60 | log.info("各类别文档数分布:"); 61 | categorySet.stream() 62 | .sorted(Comparator.comparing(Category::getDocCount).reversed()) 63 | .map(c -> c.getName() + "/" + c.getDocCount()) 64 | .forEach(log::info); 65 | } 66 | 67 | 68 | private Set createCategorySet(File root) { 69 | return Arrays.stream(root.listFiles()) 70 | .filter(File::isDirectory) 71 | .map(f -> new Category(f.getName(), f.getAbsolutePath(), f.listFiles(filenameFilter).length)) 72 | .collect(toSet()); 73 | } 74 | 75 | private List createDocs(final Category category, File[] files) { 76 | return Arrays.stream(files).parallel() 77 | .map(f -> new Doc(f.getName(), category,getTerms(f.getAbsolutePath()))) 78 | .collect(toList()); 79 | } 80 | 81 | private List getTerms(String path) { 82 | return nlpTools.segment(FileUtils.readAll(path)).stream().filter(TermFilter::filter).distinct().collect(toList()); 83 | } 84 | 85 | 86 | public Optional getCategory(String name) { 87 | return categorySet.stream().filter(c -> c.getName().equals(name)).findFirst(); 88 | } 89 | 90 | 91 | } 92 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/naivebayes/NaiveBayesClassifier.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.naivebayes; 2 | 3 | import com.fullstackyang.nlp.classifier.model.Term; 4 | import com.fullstackyang.nlp.classifier.utils.Calculator; 5 | import com.fullstackyang.nlp.classifier.utils.nlp.NLPTools; 6 | import lombok.AllArgsConstructor; 7 | 8 | import java.util.Comparator; 9 | import java.util.List; 10 | import java.util.Set; 11 | 12 | import static java.util.stream.Collectors.toList; 13 | 14 | public class NaiveBayesClassifier { 15 | 16 | interface Model { 17 | String getModelPath(); 18 | 19 | List getConditionProbability(String category, List terms, final NaiveBayesKnowledgeBase knowledgeBase); 20 | } 21 | 22 | private final Model model; 23 | 24 | private final NaiveBayesKnowledgeBase knowledgeBase; 25 | 26 | public NaiveBayesClassifier() { 27 | this(NaiveBayesModels.Multinomial); 28 | } 29 | 30 | public NaiveBayesClassifier(Model model) { 31 | this.model = model; 32 | this.knowledgeBase = new NaiveBayesKnowledgeBase(model.getModelPath()); 33 | } 34 | 35 | public String predict(String content) { 36 | Set allFeatures = knowledgeBase.getFeatures().keySet(); 37 | List terms = NLPTools.instance().segment(content).stream() 38 | .filter(t -> allFeatures.contains(t.getWord())) 39 | .distinct() 40 | .collect(toList()); 41 | 42 | @AllArgsConstructor 43 | class Result { 44 | final String category; 45 | final double probability; 46 | } 47 | 48 | Result result = knowledgeBase.getCategories().keySet().stream() 49 | .map(c -> new Result(c, Calculator.Ppost(knowledgeBase.getCategoryProbability(c), 50 | model.getConditionProbability(c, terms, knowledgeBase)))) 51 | .max(Comparator.comparingDouble(r -> r.probability)).orElse(new Result("unkown", 0.0)); 52 | return result.category; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/naivebayes/NaiveBayesKnowledgeBase.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.naivebayes; 2 | 3 | import com.fullstackyang.nlp.classifier.feature.Feature; 4 | import com.fullstackyang.nlp.classifier.model.Term; 5 | import com.fullstackyang.nlp.classifier.utils.FileUtils; 6 | import com.google.common.collect.Lists; 7 | import lombok.*; 8 | import lombok.extern.slf4j.Slf4j; 9 | 10 | import java.io.File; 11 | import java.util.*; 12 | import java.util.function.Function; 13 | import java.util.stream.IntStream; 14 | 15 | import static java.util.stream.Collectors.*; 16 | 17 | @Slf4j 18 | @Data 19 | @NoArgsConstructor 20 | public class NaiveBayesKnowledgeBase { 21 | 22 | @Getter(AccessLevel.PACKAGE) 23 | private Map features; 24 | 25 | @Getter(AccessLevel.PACKAGE) 26 | private Map categories; 27 | 28 | public NaiveBayesKnowledgeBase(String modelPath) { 29 | log.info("加载文件,正在初始化..."); 30 | List lines = FileUtils.readLines(modelPath); 31 | this.categories = parseCategorySummary(lines.get(0)); 32 | this.features = lines.stream().skip(1) 33 | .map(this::parseFeatureSummariy) 34 | .filter(Objects::nonNull) 35 | .collect(toMap(FeatureSummary::getWord, Function.identity())); 36 | log.info("初始化完成!"); 37 | } 38 | 39 | private Map parseCategorySummary(String line) { 40 | if (!line.contains(" ") || !line.contains(":")) { 41 | log.error("格式有误"); 42 | return null; 43 | } 44 | 45 | return Arrays.stream(line.split(" ")) 46 | .filter(str -> str.contains(":")) 47 | .map(str -> str.split(":")) 48 | .collect(toMap(arr -> arr[0], 49 | arr -> Double.parseDouble(arr[1]), 50 | (u, v) -> { 51 | throw new IllegalStateException(String.format("Duplicate key %s", u)); 52 | }, 53 | LinkedHashMap::new)); 54 | } 55 | 56 | private FeatureSummary parseFeatureSummariy(String line) { 57 | try { 58 | return new FeatureSummary(line, categories.keySet()); 59 | } catch (Exception e) { 60 | e.printStackTrace(); 61 | } 62 | return null; 63 | } 64 | 65 | 66 | double getCategoryProbability(String category) { 67 | return categories.getOrDefault(category, 0.0); 68 | } 69 | 70 | double getPconditionByWord(String category, String word) { 71 | return features.containsKey(word) ? features.get(word).getPconditionByCategory(category) : 0.0; 72 | } 73 | 74 | public void write(String path) { 75 | FileUtils.write(new File(path), this.toString()); 76 | } 77 | 78 | 79 | public String toString() { 80 | StringBuilder builder = new StringBuilder(categories.keySet().stream().map(c -> c + ":" + categories.get(c)).collect(joining(" "))); 81 | builder.append(System.lineSeparator()); 82 | features.values().stream().map(f -> f.getWord() + ":" + categories.keySet().stream().map(c -> "" + f.getPconditionByCategory(c)) 83 | .collect(joining(" ")) + System.lineSeparator()).forEach(builder::append); 84 | return builder.toString(); 85 | } 86 | 87 | FeatureSummary createFeatureSummary(final Feature feature, final Map Pconditions) { 88 | return new FeatureSummary(feature, Pconditions); 89 | } 90 | 91 | class FeatureSummary { 92 | 93 | @Getter(AccessLevel.PACKAGE) 94 | private final String word; 95 | 96 | private final Map Pconditions; 97 | 98 | private FeatureSummary(final Feature feature, final Map Pconditions) { 99 | this.word = feature.getTerm().getWord(); 100 | this.Pconditions = Pconditions; 101 | } 102 | 103 | private FeatureSummary(String str, Set categorySet) throws Exception { 104 | if (!str.contains(":")) 105 | throw new Exception("invalid format"); 106 | 107 | this.word = str.substring(0, str.indexOf(":")); 108 | String substring = str.substring(str.indexOf(":") + 1); 109 | if (!substring.contains(" ")) 110 | throw new Exception("this feature has no Pcondition"); 111 | 112 | String[] Pconditions = substring.split(" "); 113 | if (Pconditions.length != categorySet.size()) 114 | throw new Exception("Pcondition's size doesn't match the category size"); 115 | 116 | List list = Lists.newArrayList(categorySet); 117 | this.Pconditions = IntStream.range(0, list.size()).boxed() 118 | .collect(toMap(list::get, i -> Double.parseDouble(Pconditions[i]))); 119 | } 120 | 121 | 122 | double getPconditionByCategory(String category) { 123 | return Pconditions.getOrDefault(category, 0.0); 124 | } 125 | 126 | public String toString() { 127 | return this.word + ":" + Pconditions.values().stream().map(Object::toString).collect(joining(" ")); 128 | } 129 | 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/naivebayes/NaiveBayesLearner.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.naivebayes; 2 | 3 | import com.fullstackyang.nlp.classifier.feature.ChiSquaredStrategy; 4 | import com.fullstackyang.nlp.classifier.feature.Feature; 5 | import com.fullstackyang.nlp.classifier.feature.FeatureSelection; 6 | import com.fullstackyang.nlp.classifier.feature.IGStrategy; 7 | import com.fullstackyang.nlp.classifier.model.Category; 8 | import com.fullstackyang.nlp.classifier.model.Doc; 9 | import com.fullstackyang.nlp.classifier.model.Term; 10 | import com.fullstackyang.nlp.classifier.model.TrainSet; 11 | import com.google.common.collect.Sets; 12 | import lombok.extern.slf4j.Slf4j; 13 | 14 | import java.util.List; 15 | import java.util.Map; 16 | import java.util.Set; 17 | import java.util.function.Function; 18 | 19 | import static com.fullstackyang.nlp.classifier.naivebayes.NaiveBayesModels.Bernoulli; 20 | import static com.fullstackyang.nlp.classifier.naivebayes.NaiveBayesModels.Multinomial; 21 | import static java.util.stream.Collectors.*; 22 | 23 | @Slf4j 24 | public class NaiveBayesLearner { 25 | 26 | private int total; 27 | 28 | private NaiveBayesKnowledgeBase knowledgeBase; 29 | 30 | interface Model { 31 | /** 32 | * 计算类别C的先验概率 33 | * 34 | * @param total (多项式模型)总特征数/(伯努利模型)总文档数 35 | * @param category (多项式模型)该类别的特征总数/(伯努利模型)该类别的文档总数 36 | * @return 类别C的先验概率 37 | */ 38 | double Pprior(int total, final Category category); 39 | 40 | /** 41 | * 计算类别C的条件概率 42 | * 43 | * @param feature 获取文档中出现了feature且属于类别category的数量 44 | * @param category 获取类别category的文档数量 45 | * @param smoothing 平滑参数 46 | * @return 类别C的条件概率 47 | */ 48 | double Pcondition(final Feature feature, final Category category, double smoothing); 49 | } 50 | 51 | private Model model; 52 | 53 | private Set categorySet; 54 | private Set featureSet; 55 | 56 | private TrainSet trainSet; 57 | 58 | public NaiveBayesLearner(Model model, TrainSet trainSet, Set selectedFeatures) { 59 | this.model = model; 60 | this.trainSet = trainSet; 61 | this.featureSet = selectedFeatures; 62 | this.knowledgeBase = new NaiveBayesKnowledgeBase(); 63 | } 64 | 65 | public NaiveBayesLearner statistics() { 66 | log.info("开始统计..."); 67 | this.total = total(); 68 | log.info("total : " + total); 69 | this.categorySet = trainSet.getCategorySet(); 70 | featureSet.forEach(f -> f.getCategoryTermCounter().forEach((category, count) -> category.setTermCount(category.getTermCount() + count))); 71 | categorySet.stream().map(Category::toString).forEach(log::info); 72 | return this; 73 | } 74 | 75 | public NaiveBayesKnowledgeBase build() { 76 | this.knowledgeBase.setCategories(createCategorySummaries(categorySet)); 77 | this.knowledgeBase.setFeatures(createFeatureSummaries(featureSet, categorySet)); 78 | return knowledgeBase; 79 | } 80 | 81 | private Map createFeatureSummaries(final Set featureSet, final Set categorySet) { 82 | return featureSet.parallelStream() 83 | .map(f -> knowledgeBase.createFeatureSummary(f, getPconditions(f, categorySet))) 84 | .collect(toMap(NaiveBayesKnowledgeBase.FeatureSummary::getWord, Function.identity())); 85 | } 86 | 87 | private Map createCategorySummaries(final Set categorySet) { 88 | return categorySet.stream().collect(toMap(Category::getName, c -> model.Pprior(total, c))); 89 | } 90 | 91 | private Map getPconditions(final Feature feature, final Set categorySet) { 92 | final double smoothing = smoothing(); 93 | return categorySet.stream() 94 | .collect(toMap(Category::getName, c -> model.Pcondition(feature, c, smoothing))); 95 | } 96 | 97 | private int total() { 98 | if (model == Multinomial) 99 | return featureSet.parallelStream() 100 | .map(Feature::getTerm) 101 | .mapToInt(Term::getTf) 102 | .sum(); 103 | else if (model == Bernoulli) 104 | return trainSet.getTotalDoc(); 105 | return 0; 106 | } 107 | 108 | private double smoothing() { 109 | if (model == Multinomial) 110 | return this.featureSet.size(); 111 | else if (model == Bernoulli) 112 | return 2.0; 113 | return 0.0; 114 | } 115 | 116 | 117 | public static void main(String[] args) { 118 | TrainSet trainSet = new TrainSet(System.getProperty("user.dir") + "/trainset/"); 119 | 120 | log.info("特征选择开始..."); 121 | FeatureSelection featureSelection = new FeatureSelection(new ChiSquaredStrategy(trainSet.getCategorySet(), trainSet.getTotalDoc())); 122 | List features = featureSelection.select(trainSet.getDocs()); 123 | log.info("特征选择完成,特征数:[" + features.size() + "]"); 124 | features.forEach(System.out::println); 125 | 126 | NaiveBayesModels model = NaiveBayesModels.Multinomial; 127 | NaiveBayesLearner learner = new NaiveBayesLearner(model, trainSet, Sets.newHashSet(features)); 128 | learner.statistics().build().write(model.getModelPath()); 129 | log.info("模型文件写入完成,路径:" + model.getModelPath()); 130 | } 131 | 132 | } 133 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/naivebayes/NaiveBayesModels.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.naivebayes; 2 | 3 | import com.fullstackyang.nlp.classifier.feature.Feature; 4 | import com.fullstackyang.nlp.classifier.model.Category; 5 | import com.fullstackyang.nlp.classifier.model.Term; 6 | 7 | import java.util.List; 8 | 9 | import static java.util.stream.Collectors.toList; 10 | 11 | public enum NaiveBayesModels implements NaiveBayesClassifier.Model, NaiveBayesLearner.Model { 12 | 13 | Bernoulli { 14 | @Override 15 | public String getModelPath() { 16 | return "data/bernoulli_naive_bayes_model"; 17 | } 18 | 19 | @Override 20 | public double Pprior(int total, Category category) { 21 | int Nc = category.getDocCount(); 22 | return Math.log((double) Nc / total); 23 | } 24 | 25 | @Override 26 | public double Pcondition(Feature feature, Category category, double smoothing) { 27 | int Ncf = feature.getDocCountByCategory(category); 28 | int Nc = category.getDocCount(); 29 | return Math.log((double) (1 + Ncf) / (Nc + smoothing)); 30 | } 31 | 32 | @Override 33 | public List getConditionProbability(String category, List terms, final NaiveBayesKnowledgeBase knowledgeBase) { 34 | return terms.stream().map(term -> knowledgeBase.getPconditionByWord(category, term.getWord())).collect(toList()); 35 | } 36 | 37 | 38 | }, 39 | Multinomial { 40 | @Override 41 | public String getModelPath() { 42 | return "data/multinomial_naive_bayes_model"; 43 | } 44 | 45 | @Override 46 | public double Pprior(int total, Category category) { 47 | int Nt = category.getTermCount(); 48 | return Math.log((double) Nt / total); 49 | } 50 | 51 | @Override 52 | public double Pcondition(Feature feature, Category category, double smoothing) { 53 | int Ntf = feature.getTermCountByCategory(category); 54 | int Nt = category.getTermCount(); 55 | return Math.log((double) (1 + Ntf) / (Nt + smoothing)); 56 | } 57 | 58 | @Override 59 | public List getConditionProbability(String category, List terms, final NaiveBayesKnowledgeBase knowledgeBase) { 60 | return terms.stream().map(term -> term.getTf() * knowledgeBase.getPconditionByWord(category, term.getWord())).collect(toList()); 61 | } 62 | }; 63 | 64 | } 65 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/utils/Calculator.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.utils; 2 | 3 | import java.util.List; 4 | 5 | public class Calculator { 6 | 7 | /** 8 | * 信息熵 9 | * 10 | * @param probabilities 11 | * @return 12 | */ 13 | public static double entropy(List probabilities) { 14 | // H(X) = -∑P(x∈X)logP(x∈X), H(X|y) = -∑P(x∈X|y)logP(x∈X|y) 15 | return probabilities.stream().filter(p -> p > 0.0).mapToDouble(p -> -p * Math.log(p)).sum(); 16 | } 17 | 18 | /** 19 | * 条件信息熵 20 | * 21 | * @param probability 22 | * @param PconditionWithFeature 23 | * @param PconditionWithoutFeature 24 | * @return 25 | */ 26 | public static double conditionalEntrogy(double probability, List PconditionWithFeature, 27 | List PconditionWithoutFeature) { 28 | // H(X|Y) = P(y=1.txt)H(X|y) + P(y=0)H(X|y) 即该特征词出现和不出现两种情况 29 | return probability * entropy(PconditionWithFeature) + (1 - probability) * entropy(PconditionWithoutFeature); 30 | } 31 | 32 | /** 33 | * 卡方检验计算公式 34 | * @param A 35 | * @param B 36 | * @param C 37 | * @param D 38 | * @return 39 | */ 40 | public static double chisquare(int A, int B, int C, int D) { 41 | // chi = n*(ad-bc)^2/(a+c)*(b+d)*(a+b)*(c+d) 42 | double chi = Math.log(A + B + C + D) + 2 * Math.log(Math.abs(A * D - B * C)) 43 | - (Math.log(A + C) + Math.log(B + D) + Math.log(A + B) + Math.log(C + D)); 44 | return Math.exp(chi); 45 | } 46 | 47 | 48 | /** 49 | * 贝叶斯公式计算后验概率 Pc=Pprior*Pcondition
50 | * 类条件概率连乘之后过小,故在前面的计算中取对数
51 | * 最终结果为log(Pprior)+log(Pcondition) 52 | * 53 | * @param Pprior 54 | * @param Pconditions 55 | * @return 56 | */ 57 | public static double Ppost(double Pprior, final List Pconditions) { 58 | return Pprior + Pconditions.stream().mapToDouble(Double::valueOf).sum(); 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/utils/FileUtils.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.utils; 2 | 3 | import com.google.common.base.Charsets; 4 | import com.google.common.base.Strings; 5 | import com.google.common.collect.Lists; 6 | import com.google.common.io.Files; 7 | import com.google.common.primitives.Chars; 8 | import com.hankcs.hanlp.corpus.io.IOUtil; 9 | import lombok.extern.slf4j.Slf4j; 10 | 11 | import java.io.*; 12 | import java.nio.charset.Charset; 13 | import java.util.Arrays; 14 | import java.util.List; 15 | import java.util.StringTokenizer; 16 | import java.util.stream.Collectors; 17 | 18 | import static java.util.stream.Collectors.joining; 19 | 20 | @Slf4j 21 | public class FileUtils { 22 | 23 | public static String readAll(String path) { 24 | try { 25 | return IOUtil.readTxt(path, Charsets.UTF_8.displayName()); 26 | } catch (IOException e) { 27 | log.error(e.getMessage()); 28 | } 29 | return null; 30 | } 31 | 32 | public static List readLines(String path) { 33 | if (Strings.isNullOrEmpty(path)) 34 | return Lists.newArrayList(); 35 | 36 | return IOUtil.readLineList(path); 37 | 38 | } 39 | 40 | public static void write(File file, String content) { 41 | try { 42 | Files.write(content, file, Charsets.UTF_8); 43 | } catch (IOException e) { 44 | log.error(e.getMessage()); 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/utils/nlp/AnsjSegmentor.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.utils.nlp; 2 | 3 | import com.fullstackyang.nlp.classifier.model.Term; 4 | import com.fullstackyang.nlp.classifier.utils.nlp.NLPTools.Segmentor; 5 | import com.google.common.base.Strings; 6 | import com.google.common.collect.Lists; 7 | import lombok.NoArgsConstructor; 8 | import org.ansj.splitWord.analysis.ToAnalysis; 9 | 10 | import java.util.ArrayList; 11 | import java.util.LinkedHashMap; 12 | import java.util.List; 13 | import java.util.function.Function; 14 | import java.util.stream.Collectors; 15 | 16 | import static java.util.stream.Collectors.*; 17 | 18 | @NoArgsConstructor 19 | public class AnsjSegmentor implements Segmentor { 20 | 21 | @Override 22 | public List segment(String content) { 23 | return ToAnalysis.parse(content).getTerms().stream() 24 | .filter(t -> !t.getNatureStr().equals("null")) 25 | .map(t -> new Term(t.getName(), t.getNatureStr(), 1)) 26 | .collect(collectingAndThen(toList(), 27 | list -> { 28 | //词频统计 29 | list.parallelStream().collect(groupingBy(Function.identity(), counting())).forEach((term, count) -> { 30 | list.stream().filter(t -> t.getWord().equals(term.getWord())).forEach(t -> t.setTf(count.intValue())); 31 | }); 32 | return list; 33 | })); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/utils/nlp/HanLPSegmentor.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.utils.nlp; 2 | 3 | import com.fullstackyang.nlp.classifier.model.Term; 4 | import com.fullstackyang.nlp.classifier.utils.nlp.NLPTools.Segmentor; 5 | import com.hankcs.hanlp.HanLP; 6 | import com.hankcs.hanlp.seg.Dijkstra.DijkstraSegment; 7 | import com.hankcs.hanlp.seg.Segment; 8 | import lombok.NoArgsConstructor; 9 | 10 | import java.util.List; 11 | 12 | import static java.util.stream.Collectors.toList; 13 | 14 | @NoArgsConstructor 15 | public class HanLPSegmentor implements Segmentor { 16 | 17 | private Segment segmentor = HanLP.newSegment().enableCustomDictionary(false).enableOrganizationRecognize(true); 18 | 19 | @Override 20 | public List segment(String content) { 21 | return segmentor.seg(content).stream().map(t -> new Term(t.word, t.nature.name(), t.getFrequency())).collect(toList()); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/utils/nlp/MyStopWords.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.utils.nlp; 2 | 3 | import com.fullstackyang.nlp.classifier.utils.FileUtils; 4 | import com.google.common.base.Charsets; 5 | import com.google.common.collect.Sets; 6 | import com.google.common.io.Files; 7 | import com.hankcs.hanlp.corpus.io.IOUtil; 8 | import com.hankcs.hanlp.dictionary.stopword.StopWordDictionary; 9 | 10 | import java.io.File; 11 | import java.io.FileInputStream; 12 | import java.io.IOException; 13 | import java.nio.file.Path; 14 | import java.util.List; 15 | import java.util.Set; 16 | 17 | public class MyStopWords implements NLPTools.StopWords { 18 | 19 | private Set set; 20 | 21 | private final static String PATH = "data/stopwords.txt"; 22 | 23 | MyStopWords() { 24 | set = Sets.newHashSet(FileUtils.readLines(PATH)); 25 | } 26 | 27 | 28 | @Override 29 | public boolean isStopWord(String word) { 30 | return set.contains(word); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/utils/nlp/NLPTools.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.utils.nlp; 2 | 3 | import com.fullstackyang.nlp.classifier.model.Term; 4 | 5 | import java.util.List; 6 | 7 | public class NLPTools { 8 | 9 | interface Segmentor { 10 | List segment(String content); 11 | } 12 | 13 | interface StopWords { 14 | boolean isStopWord(String word); 15 | } 16 | 17 | private Segmentor segmentor; 18 | 19 | private StopWords stopWords; 20 | 21 | private NLPTools() { 22 | // this.segmentor = new JiebaSegmentor(); 23 | this.segmentor = new AnsjSegmentor(); 24 | this.stopWords = new MyStopWords(); 25 | } 26 | 27 | private static class Holder { 28 | private static NLPTools instance = new NLPTools(); 29 | } 30 | 31 | public static NLPTools instance() { 32 | return Holder.instance; 33 | } 34 | 35 | 36 | public List segment(String content) { 37 | return segmentor.segment(content); 38 | } 39 | 40 | public boolean isStopWord(String word) { 41 | return stopWords.isStopWord(word); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/com/fullstackyang/nlp/classifier/utils/nlp/TermFilter.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.utils.nlp; 2 | 3 | import com.fullstackyang.nlp.classifier.model.Term; 4 | import com.google.common.collect.Sets; 5 | 6 | import java.util.Set; 7 | 8 | public class TermFilter { 9 | private final static Set POSSet = Sets.newHashSet("w", "nx", "m", "t", "nt"); 10 | 11 | public static boolean filter(Term term) { 12 | return POSSet.stream().noneMatch(term.getPOS()::startsWith) 13 | && !NLPTools.instance().isStopWord(term.getWord()) && !term.getWord().matches("^\\d+(.*)"); 14 | } 15 | } -------------------------------------------------------------------------------- /src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | article classifier 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | debug 11 | 12 | 13 | 14 | %d [%thread] %level %c{30}@%M[%L] - %m%n 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | true 24 | 25 | ${LOG_HOME}/%d{yyyy-MM-dd}_%i.log 26 | 30 27 | 28 | 29 | 100KB 30 | 31 | 32 | 33 | 34 | 35 | %-4date [%thread] %-5level %logger{35} - %msg%n%ex{full, DISPLAY_EX_EVAL} 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /src/test/java/com/fullstackyang/nlp/classifier/model/TestClassifier.java: -------------------------------------------------------------------------------- 1 | package com.fullstackyang.nlp.classifier.model; 2 | 3 | import com.fullstackyang.nlp.classifier.naivebayes.NaiveBayesClassifier; 4 | import com.fullstackyang.nlp.classifier.naivebayes.NaiveBayesModels; 5 | import org.junit.Test; 6 | 7 | public class TestClassifier { 8 | 9 | @Test 10 | public void test() { 11 | NaiveBayesClassifier classifier = new NaiveBayesClassifier(NaiveBayesModels.Multinomial); 12 | String text = "明日赛事推荐:切尔西巴萨冤家路窄,恒大申花再战亚冠"; 13 | String category = classifier.predict(text); 14 | System.out.println(category); 15 | } 16 | 17 | 18 | } 19 | --------------------------------------------------------------------------------