├── README.md ├── data ├── stopword.txt └── train_label.rar ├── data_process.py ├── model ├── checkpoint ├── model.ckpt-1090.index ├── model.ckpt-1635.index ├── model.ckpt-3270.index ├── model.ckpt-3815.index └── model.ckpt-8175.index ├── textcnn_model.py ├── train.py └── word2vec ├── readme.txt └── word2vec.PNG /README.md: -------------------------------------------------------------------------------- 1 | # Textcnn_tensorflow 2 | 3 | 基于搜狐2019算法赛数据的细粒度情感分析 4 | 几乎没有对数据做什么处理,只是简单的把有实体的句子拿出来,但是发现这样做破坏了一些有用的句子,这样肯定也是没有办法解决一个句子中有多个实体情感判断的问题,这里只是一个小尝试,重在自己尝试搭模型。当然在采取一些小措施之后这单一模型也是可以到0.6的效果的 5 | 6 | data文件夹存放的是数据以及停用词表 7 | 8 | model文件夹存放训练好的模型文件 9 | 10 | word2vec文件夹下是训练好的词向量模型,用的是这次比赛的训练集和测试集共同训练。 11 | 12 | data_process.py用于数据预处理 13 | 14 | textcnn_model.py是textcnn的模型代码 15 | 16 | train.py是训练代码 17 | 18 | 下载下来直接运行train.py文件,就可以训练模型 19 | 20 | 如果需要改动以适应于自己的数据,主要需要修改train.py里面的参数设置部分 21 | 22 | 此次代码中还有很多需要完善的地方,将在今后的学习中继续完善 23 | 24 | 关于模型更详细的介绍可以查看本人的csdn博客:https://blog.csdn.net/weixin_43256799/article/details/90647107 25 | 26 | 本人初学者一枚,所以可能有写错的地方还望指出 27 | 28 | 模型太大上传不上去,还请见谅。 29 | 30 | word2vec模型放在了我的百度网盘里,拷贝下来放在word2vec文件夹下即可。百度云链接:链接:https://pan.baidu.com/s/1X5FYj7fubF8KZVhLZdKtwg 提取码:q4k9 31 | -------------------------------------------------------------------------------- /data/stopword.txt: -------------------------------------------------------------------------------- 1 | ——— 2 | 》), 3 | )÷(1- 4 | ”, 5 | )、 6 | =( 7 | : 8 | → 9 | ℃ 10 | & 11 | * 12 | 一一 13 | ~~~~ 14 | ’ 15 | . 16 | 『 17 | .一 18 | ./ 19 | -- 20 | 』 21 | =″ 22 | 【 23 | [*] 24 | }> 25 | [⑤]] 26 | [①D] 27 | c] 28 | ng昉 29 | * 30 | // 31 | [ 32 | ] 33 | [②e] 34 | [②g] 35 | ={ 36 | } 37 | ,也 38 | ‘ 39 | A 40 | [①⑥] 41 | [②B] 42 | [①a] 43 | [④a] 44 | [①③] 45 | [③h] 46 | ③] 47 | 1. 48 | -- 49 | [②b] 50 | ’‘ 51 | ××× 52 | [①⑧] 53 | 0:2 54 | =[ 55 | [⑤b] 56 | [②c] 57 | [④b] 58 | [②③] 59 | [③a] 60 | [④c] 61 | [①⑤] 62 | [①⑦] 63 | [①g] 64 | ∈[ 65 | [①⑨] 66 | [①④] 67 | [①c] 68 | [②f] 69 | [②⑧] 70 | [②①] 71 | [①C] 72 | [③c] 73 | [③g] 74 | [②⑤] 75 | [②②] 76 | 一. 77 | [①h] 78 | .数 79 | [] 80 | [①B] 81 | 数/ 82 | [①i] 83 | [③e] 84 | [①①] 85 | [④d] 86 | [④e] 87 | [③b] 88 | [⑤a] 89 | [①A] 90 | [②⑧] 91 | [②⑦] 92 | [①d] 93 | [②j] 94 | 〕〔 95 | ][ 96 | :// 97 | ′∈ 98 | [②④ 99 | [⑤e] 100 | 12% 101 | b] 102 | ... 103 | ................... 104 | …………………………………………………③ 105 | ZXFITL 106 | [③F] 107 | 」 108 | [①o] 109 | ]∧′=[ 110 | ∪φ∈ 111 | ′| 112 | {- 113 | ②c 114 | } 115 | [③①] 116 | R.L. 117 | [①E] 118 | Ψ 119 | -[*]- 120 | ↑ 121 | .日 122 | [②d] 123 | [② 124 | [②⑦] 125 | [②②] 126 | [③e] 127 | [①i] 128 | [①B] 129 | [①h] 130 | [①d] 131 | [①g] 132 | [①②] 133 | [②a] 134 | f] 135 | [⑩] 136 | a] 137 | [①e] 138 | [②h] 139 | [②⑥] 140 | [③d] 141 | [②⑩] 142 | e] 143 | 〉 144 | 】 145 | 元/吨 146 | [②⑩] 147 | 2.3% 148 | 5:0 149 | [①] 150 | :: 151 | [②] 152 | [③] 153 | [④] 154 | [⑤] 155 | [⑥] 156 | [⑦] 157 | [⑧] 158 | [⑨] 159 | …… 160 | —— 161 | ? 162 | 、 163 | 。 164 | “ 165 | ” 166 | 《 167 | 》 168 | ! 169 | , 170 | : 171 | ; 172 | ? 173 | . 174 | , 175 | . 176 | ' 177 | ? 178 | · 179 | ——— 180 | ── 181 | ? 182 | — 183 | < 184 | > 185 | ( 186 | ) 187 | 〔 188 | 〕 189 | [ 190 | ] 191 | ( 192 | ) 193 | - 194 | + 195 | ~ 196 | × 197 | / 198 | / 199 | ① 200 | ② 201 | ③ 202 | ④ 203 | ⑤ 204 | ⑥ 205 | ⑦ 206 | ⑧ 207 | ⑨ 208 | ⑩ 209 | Ⅲ 210 | В 211 | " 212 | ; 213 | # 214 | @ 215 | γ 216 | μ 217 | φ 218 | φ. 219 | × 220 | Δ 221 | ■ 222 | ▲ 223 | sub 224 | exp 225 | sup 226 | sub 227 | Lex 228 | # 229 | % 230 | & 231 | ' 232 | + 233 | +ξ 234 | ++ 235 | - 236 | -β 237 | < 238 | <± 239 | <Δ 240 | <λ 241 | <φ 242 | << 243 | = 244 | = 245 | =☆ 246 | =- 247 | > 248 | >λ 249 | _ 250 | ~± 251 | ~+ 252 | [⑤f] 253 | [⑤d] 254 | [②i] 255 | ≈ 256 | [②G] 257 | [①f] 258 | LI 259 | ㈧ 260 | [- 261 | ...... 262 | 〉 263 | [③⑩] 264 | 第二 265 | 一番 266 | 一直 267 | 一个 268 | 一些 269 | 许多 270 | 种 271 | 有的是 272 | 也就是说 273 | 末##末 274 | 啊 275 | 阿 276 | 哎 277 | 哎呀 278 | 哎哟 279 | 唉 280 | 俺 281 | 俺们 282 | 按 283 | 按照 284 | 吧 285 | 吧哒 286 | 把 287 | 罢了 288 | 被 289 | 本 290 | 本着 291 | 比 292 | 比方 293 | 比如 294 | 鄙人 295 | 彼 296 | 彼此 297 | 边 298 | 别 299 | 别的 300 | 别说 301 | 并 302 | 并且 303 | 不比 304 | 不成 305 | 不单 306 | 不但 307 | 不独 308 | 不管 309 | 不光 310 | 不过 311 | 不仅 312 | 不拘 313 | 不论 314 | 不怕 315 | 不然 316 | 不如 317 | 不特 318 | 不惟 319 | 不问 320 | 不只 321 | 朝 322 | 朝着 323 | 趁 324 | 趁着 325 | 乘 326 | 冲 327 | 除 328 | 除此之外 329 | 除非 330 | 除了 331 | 此 332 | 此间 333 | 此外 334 | 从 335 | 从而 336 | 打 337 | 待 338 | 但 339 | 但是 340 | 当 341 | 当着 342 | 到 343 | 得 344 | 的 345 | 的话 346 | 等 347 | 等等 348 | 地 349 | 第 350 | 叮咚 351 | 对 352 | 对于 353 | 多 354 | 多少 355 | 而 356 | 而况 357 | 而且 358 | 而是 359 | 而外 360 | 而言 361 | 而已 362 | 尔后 363 | 反过来 364 | 反过来说 365 | 反之 366 | 非但 367 | 非徒 368 | 否则 369 | 嘎 370 | 嘎登 371 | 该 372 | 赶 373 | 个 374 | 各 375 | 各个 376 | 各位 377 | 各种 378 | 各自 379 | 给 380 | 根据 381 | 跟 382 | 故 383 | 故此 384 | 固然 385 | 关于 386 | 管 387 | 归 388 | 果然 389 | 果真 390 | 过 391 | 哈 392 | 哈哈 393 | 呵 394 | 和 395 | 何 396 | 何处 397 | 何况 398 | 何时 399 | 嘿 400 | 哼 401 | 哼唷 402 | 呼哧 403 | 乎 404 | 哗 405 | 还是 406 | 还有 407 | 换句话说 408 | 换言之 409 | 或 410 | 或是 411 | 或者 412 | 极了 413 | 及 414 | 及其 415 | 及至 416 | 即 417 | 即便 418 | 即或 419 | 即令 420 | 即若 421 | 即使 422 | 几 423 | 几时 424 | 己 425 | 既 426 | 既然 427 | 既是 428 | 继而 429 | 加之 430 | 假如 431 | 假若 432 | 假使 433 | 鉴于 434 | 将 435 | 较 436 | 较之 437 | 叫 438 | 接着 439 | 结果 440 | 借 441 | 紧接着 442 | 进而 443 | 尽 444 | 尽管 445 | 经 446 | 经过 447 | 就 448 | 就是 449 | 就是说 450 | 据 451 | 具体地说 452 | 具体说来 453 | 开始 454 | 开外 455 | 靠 456 | 咳 457 | 可 458 | 可见 459 | 可是 460 | 可以 461 | 况且 462 | 啦 463 | 来 464 | 来着 465 | 离 466 | 例如 467 | 哩 468 | 连 469 | 连同 470 | 两者 471 | 了 472 | 临 473 | 另 474 | 另外 475 | 另一方面 476 | 论 477 | 嘛 478 | 吗 479 | 慢说 480 | 漫说 481 | 冒 482 | 么 483 | 每 484 | 每当 485 | 们 486 | 莫若 487 | 某 488 | 某个 489 | 某些 490 | 拿 491 | 哪 492 | 哪边 493 | 哪儿 494 | 哪个 495 | 哪里 496 | 哪年 497 | 哪怕 498 | 哪天 499 | 哪些 500 | 哪样 501 | 那 502 | 那边 503 | 那儿 504 | 那个 505 | 那会儿 506 | 那里 507 | 那么 508 | 那么些 509 | 那么样 510 | 那时 511 | 那些 512 | 那样 513 | 乃 514 | 乃至 515 | 呢 516 | 能 517 | 你 518 | 你们 519 | 您 520 | 宁 521 | 宁可 522 | 宁肯 523 | 宁愿 524 | 哦 525 | 呕 526 | 啪达 527 | 旁人 528 | 呸 529 | 凭 530 | 凭借 531 | 其 532 | 其次 533 | 其二 534 | 其他 535 | 其它 536 | 其一 537 | 其余 538 | 其中 539 | 起 540 | 起见 541 | 起见 542 | 岂但 543 | 恰恰相反 544 | 前后 545 | 前者 546 | 且 547 | 然而 548 | 然后 549 | 然则 550 | 让 551 | 人家 552 | 任 553 | 任何 554 | 任凭 555 | 如 556 | 如此 557 | 如果 558 | 如何 559 | 如其 560 | 如若 561 | 如上所述 562 | 若 563 | 若非 564 | 若是 565 | 啥 566 | 上下 567 | 尚且 568 | 设若 569 | 设使 570 | 甚而 571 | 甚么 572 | 甚至 573 | 省得 574 | 时候 575 | 什么 576 | 什么样 577 | 使得 578 | 是 579 | 是的 580 | 首先 581 | 谁 582 | 谁知 583 | 顺 584 | 顺着 585 | 似的 586 | 虽 587 | 虽然 588 | 虽说 589 | 虽则 590 | 随 591 | 随着 592 | 所 593 | 所以 594 | 他 595 | 他们 596 | 他人 597 | 它 598 | 它们 599 | 她 600 | 她们 601 | 倘 602 | 倘或 603 | 倘然 604 | 倘若 605 | 倘使 606 | 腾 607 | 替 608 | 通过 609 | 同 610 | 同时 611 | 哇 612 | 万一 613 | 往 614 | 望 615 | 为 616 | 为何 617 | 为了 618 | 为什么 619 | 为着 620 | 喂 621 | 嗡嗡 622 | 我 623 | 我们 624 | 呜 625 | 呜呼 626 | 乌乎 627 | 无论 628 | 无宁 629 | 毋宁 630 | 嘻 631 | 吓 632 | 相对而言 633 | 像 634 | 向 635 | 向着 636 | 嘘 637 | 呀 638 | 焉 639 | 沿 640 | 沿着 641 | 要 642 | 要不 643 | 要不然 644 | 要不是 645 | 要么 646 | 要是 647 | 也 648 | 也罢 649 | 也好 650 | 一 651 | 一般 652 | 一旦 653 | 一方面 654 | 一来 655 | 一切 656 | 一样 657 | 一则 658 | 依 659 | 依照 660 | 矣 661 | 以 662 | 以便 663 | 以及 664 | 以免 665 | 以至 666 | 以至于 667 | 以致 668 | 抑或 669 | 因 670 | 因此 671 | 因而 672 | 因为 673 | 哟 674 | 用 675 | 由 676 | 由此可见 677 | 由于 678 | 有 679 | 有的 680 | 有关 681 | 有些 682 | 又 683 | 于 684 | 于是 685 | 于是乎 686 | 与 687 | 与此同时 688 | 与否 689 | 与其 690 | 越是 691 | 云云 692 | 哉 693 | 再说 694 | 再者 695 | 在 696 | 在下 697 | 咱 698 | 咱们 699 | 则 700 | 怎 701 | 怎么 702 | 怎么办 703 | 怎么样 704 | 怎样 705 | 咋 706 | 照 707 | 照着 708 | 者 709 | 这 710 | 这边 711 | 这儿 712 | 这个 713 | 这会儿 714 | 这就是说 715 | 这里 716 | 这么 717 | 这么点儿 718 | 这么些 719 | 这么样 720 | 这时 721 | 这些 722 | 这样 723 | 正如 724 | 吱 725 | 之 726 | 之类 727 | 之所以 728 | 之一 729 | 只是 730 | 只限 731 | 只要 732 | 只有 733 | 至 734 | 至于 735 | 诸位 736 | 着 737 | 着呢 738 | 自 739 | 自从 740 | 自个儿 741 | 自各儿 742 | 自己 743 | 自家 744 | 自身 745 | 综上所述 746 | 总的来看 747 | 总的来说 748 | 总的说来 749 | 总而言之 750 | 总之 751 | 纵 752 | 纵令 753 | 纵然 754 | 纵使 755 | 遵照 756 | 作为 757 | 兮 758 | 呃 759 | 呗 760 | 咚 761 | 咦 762 | 喏 763 | 啐 764 | 喔唷 765 | 嗬 766 | 嗯 767 | 嗳 -------------------------------------------------------------------------------- /data/train_label.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/422xiaomage/Textcnn_tensorflow/a0dcda3b55261ef7b7f472e8aa76510d19456100/data/train_label.rar -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import jieba 3 | from gensim import models 4 | import numpy as np 5 | from collections import Counter 6 | 7 | class Data_process(object): 8 | def __init__(self, path_input, path_stopword, path_word2vec_model, embedding_size=300, 9 | max_length=300, min_counter=10, rate=0.8): 10 | self.path_input = path_input 11 | self.path_stopword = path_stopword 12 | self.max_length = max_length 13 | self.path_word2vec_model = path_word2vec_model 14 | self.min_counter = min_counter 15 | self.embedding_size = embedding_size 16 | self.rate = rate 17 | # 读取csv文件中的数据,并做好分词 18 | def read_data(self,path): 19 | fb = pd.read_csv(path) 20 | reviews = fb["Sentence"].tolist() 21 | labels = fb["Emotion"].tolist() 22 | reviews_word = [jieba.lcut(review.strip()) for review in reviews] 23 | return reviews_word, labels 24 | 25 | def read_stopword(self): 26 | with open(self.path_stopword, "r", encoding="utf-8") as f: 27 | lines = f.read() 28 | stop_word = lines.splitlines() 29 | self.stopWordDict = dict(zip(stop_word, list(range(len(stop_word))))) 30 | 31 | def get_vocabulary_embedding(self,reviews_word): 32 | allwords = [word for review in reviews_word for word in review] 33 | subwords = [word.strip() for word in allwords if word.strip() not in self.stopWordDict] 34 | wordcounter = Counter(subwords) 35 | sort_wordcounter = sorted(wordcounter.items(), key=lambda x: x[1], reverse=True) 36 | words = [item[0] for item in sort_wordcounter if item[1] >= self.min_counter] 37 | vocab = [] 38 | wordembedding = [] 39 | vocab.append("PAD") 40 | vocab.append("UNK") 41 | wordembedding.append(np.zeros(self.embedding_size)) 42 | wordembedding.append(np.random.randn(self.embedding_size)) 43 | word2vec = models.Word2Vec.load(self.path_word2vec_model) 44 | for word in words: 45 | try: 46 | vector = word2vec.wv[word] 47 | vocab.append(word) 48 | wordembedding.append(vector) 49 | except: 50 | print(word+"\t"+"不在训练的词向量中") 51 | embedding = wordembedding 52 | self.wordToindex = dict(zip(vocab, list(range(len(vocab))))) 53 | return np.array(embedding) 54 | 55 | def data_process(self, review,wordToindex): 56 | reviewVec = np.zeros((self.max_length)) 57 | sequenceLen = self.max_length 58 | if len(review) < self.max_length: 59 | sequenceLen = len(review) 60 | for i in range(sequenceLen): 61 | if review[i] in wordToindex: 62 | reviewVec[i] = wordToindex[review[i]] 63 | else: 64 | reviewVec[i] = wordToindex["UNK"] 65 | return reviewVec 66 | def get_train_evadata(self,x,y,rate): 67 | reviews_vector = [] 68 | labels = [] 69 | for i in range(len(x)): 70 | reviewvec = self.data_process(x[i],self.wordToindex) 71 | reviews_vector.append(reviewvec) 72 | labels.append([y[i]]) 73 | trainIndex = int(len(x) * rate) 74 | trainReviews = np.asarray(reviews_vector[:trainIndex], dtype="int64") 75 | trainLabels = np.array(labels[:trainIndex], dtype="float32") 76 | 77 | evalReviews = np.asarray(reviews_vector[trainIndex:], dtype="int64") 78 | evalLabels = np.array(labels[trainIndex:], dtype="float32") 79 | 80 | return trainReviews, trainLabels, evalReviews, evalLabels 81 | 82 | def dataGen(self): 83 | """ 84 | 初始化训练集和验证集 85 | """ 86 | 87 | # 初始化停用词 88 | self.read_stopword() 89 | 90 | # 初始化数据集 91 | reviews, labels = self.read_data(self.path_input) 92 | 93 | # 初始化词汇-索引映射表和词向量矩阵 94 | embedding = self.get_vocabulary_embedding(reviews) 95 | 96 | # 初始化训练集和测试集 97 | trainReviews, trainLabels, evalReviews, evalLabels = self.get_train_evadata(reviews, labels, self.rate) 98 | return trainReviews, trainLabels, evalReviews, evalLabels, embedding 99 | 100 | def nextBatch(self, x, y, batchSize): 101 | """ 102 | 生成batch数据集,用生成器的方式输出 103 | """ 104 | perm = np.arange(len(x)) 105 | np.random.shuffle(perm) 106 | x = x[perm] 107 | y = y[perm] 108 | 109 | numBatches = len(x) // batchSize 110 | 111 | for i in range(numBatches): 112 | start = i * batchSize 113 | end = start + batchSize 114 | batchX = np.array(x[start: end], dtype="int64") 115 | batchY = np.array(y[start: end], dtype="int64") 116 | 117 | yield batchX, batchY -------------------------------------------------------------------------------- /model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt-3815" 2 | all_model_checkpoint_paths: "model.ckpt-545" 3 | all_model_checkpoint_paths: "model.ckpt-1090" 4 | all_model_checkpoint_paths: "model.ckpt-1635" 5 | all_model_checkpoint_paths: "model.ckpt-3815" 6 | -------------------------------------------------------------------------------- /model/model.ckpt-1090.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/422xiaomage/Textcnn_tensorflow/a0dcda3b55261ef7b7f472e8aa76510d19456100/model/model.ckpt-1090.index -------------------------------------------------------------------------------- /model/model.ckpt-1635.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/422xiaomage/Textcnn_tensorflow/a0dcda3b55261ef7b7f472e8aa76510d19456100/model/model.ckpt-1635.index -------------------------------------------------------------------------------- /model/model.ckpt-3270.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/422xiaomage/Textcnn_tensorflow/a0dcda3b55261ef7b7f472e8aa76510d19456100/model/model.ckpt-3270.index -------------------------------------------------------------------------------- /model/model.ckpt-3815.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/422xiaomage/Textcnn_tensorflow/a0dcda3b55261ef7b7f472e8aa76510d19456100/model/model.ckpt-3815.index -------------------------------------------------------------------------------- /model/model.ckpt-8175.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/422xiaomage/Textcnn_tensorflow/a0dcda3b55261ef7b7f472e8aa76510d19456100/model/model.ckpt-8175.index -------------------------------------------------------------------------------- /textcnn_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class TextCNN(object): 4 | def __init__(self, vocab_size,Filter_size,embedding=None, embedding_size=300, numFilters=128, max_length=300, 5 | dropoutKeepProb=0.5,numClass=3): 6 | self.embedding = embedding 7 | self.embedding_size = embedding_size 8 | self.Filter_size = Filter_size 9 | self.numFilters = numFilters 10 | self.max_length = max_length 11 | self.dropoutKeepProb = dropoutKeepProb 12 | self.vocab_size = vocab_size 13 | self.numClass = numClass 14 | 15 | 16 | # 词嵌入层 17 | self.input_X = tf.placeholder(tf.int32, [None, self.max_length], name="input_X") 18 | self.input_Y = tf.placeholder(tf.int32, [None, 1], name="input_Y") 19 | with tf.name_scope("embedding"): 20 | if self.embedding is not None: 21 | # 这里将词向量设置成了变量,所以在训练过程中,也会对原始词向量进行微调,如果不需要, 22 | # 可以将这一句修改为注释掉的代码段 23 | self.W = tf.Variable(tf.cast(self.embedding, dtype=tf.float32,name="word2vec"), name="W") 24 | # self.W = tf.constant(tf.cast(self.embedding, dtype=tf.float32,name="word2vec"), name="W") 25 | else: 26 | self.W = tf.Variable(tf.truncated_normal([self.vocab_size,self.embedding_size],stddev=1, 27 | dtype=tf.float32), name="W") 28 | # self.W = tf.constant(tf.truncated_normal([self.vocab_size,self.embedding_size],stddev=1, 29 | # dtype=tf.float32), name="W") 30 | # 词序号的向量化操作 31 | self.embeddingwords = tf.nn.embedding_lookup(self.W, self.input_X) 32 | # 这里由于卷积层输入的是一个四维的向量,第四维时通道,所以这里要扩展一维, 33 | # 扩展成[batch_size, width, height, channel] 34 | self.embeddingwords_expand = tf.expand_dims(self.embeddingwords, -1) 35 | # 对标签进行onne_hot编码,这里要注意self.input_Y的数据类型必须时int型 36 | self.input_Y_one_hot = tf.cast(tf.one_hot(self.input_Y, self.numClass, name="Y_onehot"), dtype=tf.float32) 37 | # 卷积层和池化层 38 | pooledOutputs = [] 39 | for i, filtersize in enumerate(self.Filter_size): 40 | 41 | with tf.name_scope("conv-maxpool-%s" % filtersize): 42 | # 卷积层,卷积核的尺寸[filtersize,self.embedding_size],卷积核的个数是self.numFilters,这个是超参 43 | # 初始化权重矩阵和偏置 44 | # 第三维1是通道数量,对于文本来说通道一定是1,所以不可以更改 45 | W = tf.Variable(tf.truncated_normal([filtersize, self.embedding_size, 1, self.numFilters], 46 | stddev=0.1),name="W") 47 | b = tf.Variable(tf.constant(0.1,shape=[self.numFilters]), name="b") 48 | conv = tf.nn.conv2d( 49 | self.embeddingwords_expand, 50 | W, 51 | strides=[1, 1, 1, 1], 52 | padding="VALID", 53 | name="conv") 54 | covn_plus_b = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu") 55 | # 池化层,最大池化,池化是对卷积后的序列取一个最大值 56 | pooled = tf.nn.max_pool( 57 | covn_plus_b, 58 | ksize=[1, self.max_length-filtersize+1, 1, 1], 59 | strides=[1, 1, 1, 1], 60 | padding="VALID", 61 | name="pool") 62 | pooledOutputs.append(pooled) 63 | # 池化后的维度不变,按照最后的维度channel来concat 64 | self.hPool = tf.concat(pooledOutputs, 3) 65 | # cnn输出的长度为卷积核的种类数*每种卷积核的个数 66 | flat_length = len(self.Filter_size) * self.numFilters 67 | self.hPoolFlat = tf.reshape(self.hPool, [-1, flat_length]) 68 | 69 | # dropout层 70 | with tf.name_scope("drop_out"): 71 | self.hPoolFlat_dropout = tf.nn.dropout(self.hPoolFlat, self.dropoutKeepProb) 72 | 73 | # 定义全连接层 74 | with tf.name_scope("output"): 75 | output_W = tf.get_variable( 76 | "output_W", 77 | shape=[flat_length, self.numClass], 78 | initializer=tf.contrib.layers.xavier_initializer()) 79 | output_b = tf.Variable(tf.constant(0.1, shape=[self.numClass]), name="output_b") 80 | 81 | self.predictions = tf.nn.xw_plus_b(self.hPoolFlat_dropout, output_W, output_b, name="predictions") 82 | 83 | self.output = tf.cast(tf.arg_max(self.predictions, 1), tf.float32, name="category") 84 | # 计算三元交叉熵损失 85 | with tf.name_scope("loss"): 86 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.predictions, 87 | labels=self.input_Y_one_hot)) 88 | # 优化器 89 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from data_process import Data_process 3 | from textcnn_model import TextCNN 4 | import datetime 5 | from sklearn.metrics import accuracy_score, precision_score, recall_score 6 | 7 | flags = tf.app.flags 8 | flags.DEFINE_boolean("is_train", True, "clean train folder") 9 | flags.DEFINE_integer("batch_size", 128, "batch_size") 10 | flags.DEFINE_integer("epoch", 20, "epoch") 11 | flags.DEFINE_integer("word2vev_size", 300, "length of word2vec") 12 | flags.DEFINE_integer("numfilter", 128, "numfilter") 13 | 14 | # 若文本长度大于300则截断 15 | flags.DEFINE_integer("max_length", 300, "text max_length") 16 | # 取词频大于10的 17 | flags.DEFINE_integer("min_counter", 10, "min_counter") 18 | flags.DEFINE_integer("numclass", 3, "numclass") 19 | flags.DEFINE_string("path_input", "./data/train_label.csv", "input file path") 20 | flags.DEFINE_string("path_stopword", "./data/stopword.txt", "stopword file path") 21 | flags.DEFINE_string("path_model", "./model/model.ckpt", "model file path") 22 | flags.DEFINE_string("path_word2vec_model", "./word2vec/word2vec_words_final.model", "word2vec_model file path") 23 | flags.DEFINE_float("learnrate", 0.0001, "rate") 24 | # 将文本分为训练集核验证集的比例 25 | flags.DEFINE_float("rate", 0.8, "rate") 26 | flags.DEFINE_float("dropoutrate", 0.5, "dropout") 27 | FLAGS = tf.app.flags.FLAGS 28 | 29 | def train(): 30 | dataprocess = Data_process(FLAGS.path_input,FLAGS.path_stopword,FLAGS.path_word2vec_model, 31 | FLAGS.word2vev_size,FLAGS.max_length,FLAGS.min_counter,FLAGS.rate) 32 | trainReviews, trainLabels, evalReviews, evalLabels, wordembedding = dataprocess.dataGen() 33 | with tf.Graph().as_default(): 34 | cnn = TextCNN(vocab_size=len(wordembedding),Filter_size=[2,3,4],embedding=wordembedding, 35 | numFilters=FLAGS.numfilter,max_length=FLAGS.max_length,dropoutKeepProb=FLAGS.dropoutrate, 36 | numClass=FLAGS.numclass) 37 | globalStep = tf.Variable(0, name="globalStep", trainable=False) 38 | optimizer = tf.train.AdamOptimizer(FLAGS.learnrate) 39 | # 计算梯度,得到梯度和变量 40 | gradsAndVars = optimizer.compute_gradients(cnn.loss) 41 | # 将梯度应用到变量下,生成训练器,对参数进行更新 42 | saver = tf.train.Saver() 43 | trainOp = optimizer.apply_gradients(gradsAndVars, global_step=globalStep) 44 | tf_config = tf.ConfigProto() 45 | tf_config.gpu_options.allow_growth = True 46 | with tf.Session(config=tf_config) as sess: 47 | sess.run(tf.global_variables_initializer()) 48 | recall_max = 0 49 | for i in range(FLAGS.epoch): 50 | for batch in dataprocess.nextBatch(trainReviews,trainLabels,FLAGS.batch_size): 51 | 52 | feed_dict = { 53 | cnn.input_X: batch[0], 54 | cnn.input_Y: batch[1] 55 | } 56 | predictions,loss,_,ouput,step = sess.run([cnn.predictions,cnn.loss,trainOp,cnn.output,globalStep], 57 | feed_dict) 58 | acc = accuracy_score(batch[1], ouput) 59 | precision = precision_score(batch[1], ouput, average='weighted') 60 | recall = recall_score(batch[1], ouput, average='micro') 61 | timeStr = datetime.datetime.now().isoformat() 62 | print("{}, iter: {}, step: {}, loss: {},acc: {}, precision: {}, recall: {}" 63 | .format(timeStr, i, step, loss, acc, precision, recall)) 64 | acces = [] 65 | precisiones = [] 66 | recalles = [] 67 | for batch_eva in dataprocess.nextBatch(evalReviews, evalLabels, FLAGS.batch_size): 68 | 69 | loss, output = sess.run([cnn.loss, cnn.output], feed_dict={ 70 | cnn.input_X: batch_eva[0], 71 | cnn.input_Y: batch_eva[1] 72 | }) 73 | acc = accuracy_score(batch_eva[1], ouput) 74 | precision = precision_score(batch_eva[1], ouput, average='weighted') 75 | recall = recall_score(batch_eva[1], ouput, average='micro') 76 | acces.append(acc) 77 | precisiones.append(precision) 78 | recalles.append(recall) 79 | acc = sum(acces)/len(acces) 80 | precision = sum(precisiones)/len(precisiones) 81 | recall = sum(recalles)/len(recalles) 82 | print("验证集结果:") 83 | print("{}, iter: {}, loss: {},acc: {}, precision: {}, recall: {}" 84 | .format(timeStr, i, loss, acc, precision, recall)) 85 | if recall > recall_max: 86 | recall_max = recall 87 | print("正在保存模型") 88 | saver.save(sess, FLAGS.path_model, global_step=step) 89 | 90 | if __name__ == "__main__": 91 | train() -------------------------------------------------------------------------------- /word2vec/readme.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/422xiaomage/Textcnn_tensorflow/a0dcda3b55261ef7b7f472e8aa76510d19456100/word2vec/readme.txt -------------------------------------------------------------------------------- /word2vec/word2vec.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/422xiaomage/Textcnn_tensorflow/a0dcda3b55261ef7b7f472e8aa76510d19456100/word2vec/word2vec.PNG --------------------------------------------------------------------------------