├── README.md
├── demo_chinese_text_binary_classification_with_bert.ipynb
└── dianping_train_test.pickle

/README.md:
--------------------------------------------------------------------------------

# How to Do Chinese Text Binary Classification with Python and BERT?

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-38-00-008905.png)

# Excitement

When Google released the BERT model last year, I got very excited.

At the time I was using fast.ai's ULMFiT for natural language classification tasks (and I wrote "[How to Do Text Classification with Python and Deep Transfer Learning?](https://zhuanlan.zhihu.com/p/48182945)" to share the approach with you). ULMFiT and BERT both belong to the family of pre-trained language models and have a lot in common.

A language model, in this sense, is a deep neural network trained on a huge body of text in order to capture the **general features** of a language.

That kind of work can usually only be done by large organizations, because the **cost** is simply too high.

The cost includes, but is not limited to:

- storing the data
- buying (or even building) the computing hardware
- training the model (which takes days or even months)
- hiring specialists
- ……

**Pre-training** means that once they have finished training such a model, they release the result. Ordinary people and small organizations can then **borrow** that result and **fine-tune** it on text from their own domain, so that the model gains a very clear **understanding** of that domain's text.

By "understanding" I mainly mean that if you cover up certain words, the model can guess fairly accurately what you have hidden.

It can even take two sentences together and judge whether they form a closely connected context pair.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-38-34-983938.png)

Is this kind of "understanding" useful?

Of course it is.

BERT has been evaluated on a range of natural language tasks, and on quite a few of them its results already surpass those of human contestants.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-38-58-744810.png)

The tasks BERT can help with naturally include text classification, such as sentiment classification. That also happens to be the problem I am currently working on.

# Pain Points

However, I had to wait a long time before I could actually use BERT.

Google's official code has been open for a long time, and even the PyTorch implementation has gone through quite a few iterations.

But every time I opened the examples they provided, my head started to spin.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-05-16-49-01-928502.jpeg)

The sheer number of lines of code alone is intimidating.

On top of that, there is a pile of data processors, each named after a benchmark dataset. My data does not belong to any of them, so which one am I supposed to use?

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345044.jpeg)

And then there are countless baffling flags, which are just as headache-inducing.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345033.jpeg)

For comparison, here is what the same kind of classification task looks like in Scikit-learn.

```python
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
```

Even image classification, a task with heavy data throughput and many steps, takes only a few lines of code with fast.ai.

```python
!git clone https://github.com/wshuyi/demo-image-classification-fastai.git
from fastai.vision import *
path = Path("demo-image-classification-fastai/imgs/")
data = ImageDataBunch.from_folder(path, test='test', size=224)
learn = cnn_learner(data, models.resnet18, metrics=accuracy)
learn.fit_one_cycle(1)
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_top_losses(9, figsize=(8, 8))
```

Don't underestimate these few lines: they not only train an image classifier for you, they also show you, for the images with the highest classification loss, exactly where the model was looking.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-05-16-49-01-928501.png)

Compare the two. Where do you think the difference between the BERT example and the fast.ai example lies?

I think the latter is **made for humans**.

# Tutorial

I kept assuming that someone would refactor the code and write a concise tutorial.

After all, text classification is a common machine learning application, with plenty of use cases, and it is well suited for beginners.

But that tutorial never came.

Of course, in the meantime I read many applications and tutorials written by others.

Some went only as far as converting a piece of natural language text into BERT encodings, and then stopped abruptly.

Others did earnestly explain how to use BERT "with slight modifications" on the officially provided datasets. All of the modifications were made directly in the original Python scripts, and every function and parameter that was never used was kept. As for how someone else might reuse the code on their own dataset? Not a word.

It is not that I never considered chewing through the code from scratch. Back in graduate school I read through all of the C code for the TCP and IP layers of a simulation platform, and I am sure the task in front of me is easier than that.

But I really could not be bothered. I have been spoiled by Python machine learning frameworks, especially fast.ai and Scikit-learn.

Later, Google's developers put BERT on TensorFlow Hub, and even wrote a dedicated Google Colab Notebook example.

When I saw the news, I was overjoyed.

I had tried quite a few other models on TensorFlow Hub, and they were very convenient to use. As for Google Colab, I already introduced it to you in "[How to Practice Python with Google Colab?](https://zhuanlan.zhihu.com/p/57100935)"; it is an excellent environment for practicing and demonstrating deep learning with Python. I was convinced that with these two combined, I could finally get my own task done in a few lines of code.

Not so fast.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345041.png)

When I actually opened it, it was still organized entirely around the sample dataset.

What does an ordinary user need? An interface.

Tell me the expected input format, then tell me what results I can get. Plug and play, done, and I walk away.

Isn't a text classification task simply this: you are given a training set and a test set, you are told how many epochs to train for and at what learning rate, and then you report back the accuracy and other results?

Do you really have to make me read hundreds of lines of code, and hunt for the places to change, for a task this simple?

Fortunately, having this example as a starting point is still better than having nothing.

So I sat down patiently and tidied it up.

Let me state clearly that I did not make major modifications to the original code.

If I did not say so explicitly, I could be suspected of plagiarism, and deservedly frowned upon.

This kind of tidying poses no technical difficulty at all for anyone who knows Python.

And that is exactly why I am annoyed. Is this hard to do? Then why wouldn't the authors of Google's BERT example do it themselves?

Why did so much change from TensorFlow 1.0 to 2.0? Isn't it precisely because 2.0 is the version made for humans?

If you refuse to make the interface clean and simple, your competitors (TuriCreate and fast.ai) will, and they do it very well. Only then, when you can no longer sit still, do you deign to build a usable interface for ordinary users.

That is the lesson. Why not learn from it?

Here I provide you with a Google Colab notebook example in which you can easily swap in your own dataset and run it. The code you need to understand (and, where necessary, modify) is **no more than 10 lines**.

I first tested it on an English text classification task, and it worked well. So I wrote [a Medium post](https://towardsdatascience.com/how-to-do-text-binary-classification-with-bert-f1348a25d905) about it, which was promptly picked up by the Towards Data Science publication.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345050.jpeg)

An editor at Towards Data Science sent me a private message, saying:

> Very interesting, I like this considering the default implementation is not very developer friendly for sure.

One reader even gave the article 50 claps in a row; I was stunned.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345040.png)

Apparently this long-endured pain point is not mine alone.

In your own research you probably run into Chinese classification tasks even more often. So I went ahead and built a Chinese text classification example as well, and wrote this tutorial to share with you.

Let's get started.


# Code

Please click [this link](https://github.com/wshuyi/demo-chinese-text-binary-classification-with-bert/blob/master/demo_chinese_text_binary_classification_with_bert.ipynb) to view the Jupyter Notebook file I have prepared for you on GitHub.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345039.jpeg)


At the top of the Notebook there is a very prominent "Open in Colab" button. Click it, and Google Colab will open automatically and load this Notebook.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345030.png)

I suggest you click the "COPY TO DRIVE" button circled in red in the figure above. That saves a copy to your own Google Drive, so you can use it and come back to it later.

Once that is done, there are really only three steps left for you to carry out:

1. Organize your data as a Pandas data frame (a minimal sketch of what that can look like follows right after this list). If you are not familiar with Pandas, you can refer to [this article of mine](https://www.jianshu.com/p/a7a7db17e26d).
2. Adjust the training parameters if necessary. These are mainly the learning rate and the number of training epochs.
3. Run the Notebook's code and collect the results.
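To make step 1 concrete, here is a minimal sketch of what such a pair of data frames can look like. The rows below are made up, and the column names `comment` and `sentiment` are simply the ones used by the demo dataset and by the parameters we will set later; with your own data you may pick any names, as long as the training set, the test set, and those parameters all agree.

```python
import pandas as pd

# Toy data, only to illustrate the expected shape:
# one column with the raw text, one column with the label (1 = positive, 0 = negative).
train = pd.DataFrame({
    "comment": ["服务很周到,菜品也不错", "环境又吵又乱,不会再来了"],
    "sentiment": [1, 0],
})
test = pd.DataFrame({
    "comment": ["味道一般,价格偏贵"],
    "sentiment": [0],
})
```

The demo's pickled data already comes in exactly this shape, so you only need to build something like this when you bring your own corpus.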

Once you have saved the Notebook and take a closer look, you may feel you have been cheated.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345053.png)

> Teacher, you lied! You promised no more than 10 lines of code!

**Don't panic.**

You **do not need to modify anything** before the line circled in red in the figure below.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-05-16-49-01-928507.png)

Click the cell at that position, then choose `Run before` from the menu, as shown in the figure below.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-05-16-49-01-928498.png)

What follows is the part that really matters, so pay close attention.

The first step is to get the data ready.

```python
!wget https://github.com/wshuyi/demo-chinese-text-binary-classification-with-bert/raw/master/dianping_train_test.pickle

with open("dianping_train_test.pickle", 'rb') as f:
    train, test = pickle.load(f)
```

The data used here should look familiar to you. It is a sentiment-labeled set of restaurant reviews that I used in "[How to Train a Chinese Text Sentiment Classification Model with Python and Machine Learning?](https://zhuanlan.zhihu.com/p/34482959)" and "[How to Do Chinese Text Classification with Python and Recurrent Neural Networks?](https://zhuanlan.zhihu.com/p/50488163)". The only difference is that, for convenience, this time I exported it in pickle format and put it in the demo GitHub repo, so it is easy for you to download and use.

The training set contains 1,600 samples and the test set contains 400. In the labels, 1 stands for positive sentiment and 0 for negative sentiment.

With the following statement we reshuffle the training set to randomize the order of the samples, so that the original ordering does not bias training.

```python
train = train.sample(len(train))
```

Now let's look at the first few rows of the training set.

```python
train.head()
```

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345042.png)

If you later substitute your own dataset, please pay attention to the format: the training set and the test set should use exactly the same column names.

The second step is to set the parameters.

```python
myparam = {
    "DATA_COLUMN": "comment",
    "LABEL_COLUMN": "sentiment",
    "LEARNING_RATE": 2e-5,
    "NUM_TRAIN_EPOCHS": 3,
    "bert_model_hub": "https://tfhub.dev/google/bert_chinese_L-12_H-768_A-12/1"
}
```

The first two entries specify which columns hold the text and the labels.

The third entry sets the learning rate. You can read the original paper and experiment with hyperparameter tuning, or you can simply keep the default value.

The fourth entry sets the number of training epochs. One full pass over all the training data counts as one epoch; here we use 3.

The last entry specifies which pre-trained BERT model to use. Since we are doing Chinese text classification, we use the address of this Chinese pre-trained model. If you want an English one, please refer to [my Medium post](https://towardsdatascience.com/how-to-do-text-binary-classification-with-bert-f1348a25d905) and the corresponding [English example code](https://github.com/wshuyi/demo_text_binary_classification_bert/blob/master/demo_text_binary_classification_with_bert.ipynb).

For the last step, we just run the code in order.

```python
result, estimator = run_on_dfs(train, test, **myparam)
```

Note that running this statement may **take quite a while**, so be prepared. How long depends on the amount of data and the number of training epochs you set.

During this process you can see that the program first converts your original Chinese text into an input format that BERT can understand.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345037.png)

When you see the text circled in red in the figure below, it means the training process has finally finished.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345057.png)

Then you can print out the evaluation results on the test set.

```python
pretty_print(result)
```

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-07-16-37-22-345046.png)
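A side note in case you prefer the raw numbers to the table above: `result` is simply the dict of evaluation metrics computed by the notebook, and its keys (such as `eval_accuracy`, `f1_score`, `auc`, `precision`, and `recall`) come from the `metric_fn` defined in the code you ran earlier, so you can read individual values directly. A minimal sketch:

```python
# `result` is a plain Python dict of evaluation metrics; the key names below
# are the ones defined in the notebook's metric_fn.
print(result["eval_accuracy"])
print(result["f1_score"], result["auc"])
```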

Compare this with our previous [tutorial](https://zhuanlan.zhihu.com/p/50488163), which used the same dataset.

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2018-11-21-08-47-47-758482.png)

Back then we had to write many more lines of code ourselves and train for 10 epochs, and the result still did not exceed 80%. This time, with only 3 epochs of training, accuracy already exceeds 88%.

Reaching that level of accuracy on such a small dataset is not easy.

It gives you a glimpse of just how **powerful** BERT is.

# Summary

By this point you have learned how to use BERT for Chinese text binary classification. I hope you are as happy about it as I am.

If you are a seasoned Python enthusiast, please do me a favor.

Remember the code that comes before this line?

![](https://github.com/wshuyi/github_pub_img/raw/master/assets/2019-04-05-16-49-01-928506.png)

Could you help me package it up? That would make our demo code even more compact, clear, and easy to use.

You are welcome to submit your code to [our GitHub project](https://github.com/wshuyi/demo-chinese-text-binary-classification-with-bert). If you find this tutorial helpful, please give [the GitHub project](https://github.com/wshuyi/demo-chinese-text-binary-classification-with-bert) a star. Thank you!

Happy deep learning!

# Further Reading

You may also be interested in the following topics. Click the links to read them.

- [How to Learn Python Efficiently?](https://zhuanlan.zhihu.com/p/29631043)
- [Can Learning Python Make You More Competitive?](https://zhuanlan.zhihu.com/p/53011746)
- [How Can Liberal Arts Students Understand Convolutional Neural Networks?](https://zhuanlan.zhihu.com/p/36416075)
- [How Can Liberal Arts Students Understand Recurrent Neural Networks (RNN)?](https://zhuanlan.zhihu.com/p/49988171)
- [Sharing "A Beginner's Guide to Data Science for Liberal Arts Students"](https://zhuanlan.zhihu.com/p/44653452)

If you like this article, please give it a thumbs-up, and consider leaving a tip. You can also follow and pin my WeChat public account ["玉树芝兰" (nkwangshuyi)](https://i.loli.net/2019/03/05/5c7dd41f11372.png).

If you are interested in Python and data science, you may want to read my tutorial series index post, "[How to Get Started with Data Science Efficiently?](https://zhuanlan.zhihu.com/p/35563090)", which collects more interesting problems and their solutions.

The entrance to my Knowledge Planet (知识星球) community is here:

![](https://i.loli.net/2019/03/05/5c7dd41f11372.png)

--------------------------------------------------------------------------------
/demo_chinese_text_binary_classification_with_bert.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "demo-chinese-text-binary-classification-with-bert.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": [],
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "\"Open"
27 | ]
28 | },
29 | {
30 | "metadata": {
31 | "id": "GBbAMqZ6IZuw",
32 | "colab_type": "text"
33 | },
34 | "cell_type": "markdown",
35 | "source": [
36 | "base code borrowed from [this Google Colab Notebook](https://colab.research.google.com/github/google-research/bert/blob/master/predicting_movie_reviews_with_bert_on_tf_hub.ipynb).\n",
37 | "\n",
38 | "Refactored by [Shuyi Wang](https://www.linkedin.com/in/shuyi-wang-b3955026/)\n",
39 | "\n",
40 | "Please refer to [this Medium Article](https://medium.com/@wshuyi/how-to-do-text-binary-classification-with-bert-f1348a25d905) for the tutorial on how to classify English text data.\n",
41 | "\n"
42 | ]
43 | },
44 | {
45 | "metadata": {
46 | "id": "jviywGyWyKsA",
47 | "colab_type": "code",
48 | "trusted": true,
49 | "colab": {}
50 | },
51 | "cell_type": "code",
52 | "source": [
53 | "!pip install bert-tensorflow"
54 | ],
55 | "execution_count": 0,
56 | "outputs": []
57 | },
58 | {
59 | "metadata": {
60 | "id": "hsZvic2YxnTz",
61 | "colab_type": "code",
62 | "trusted": true,
63 | "colab": {}
64 | },
65 | "cell_type": "code",
66 | "source": [
67 | "import pandas as pd\n",
68 | "import tensorflow as tf\n",
69 | "import tensorflow_hub as hub\n",
70 | "import pickle\n",
71 | "import bert\n",
72 | "from bert import run_classifier\n",
73 | "from bert import optimization\n",
74 | "from bert import tokenization"
75 | ],
76 | "execution_count": 0,
77 | "outputs": []
78 | },
79 | {
80 | "metadata": {
81 | "colab_type": "code",
82 | "id": "NZM71PjIOF_I",
83 | "colab": {}
84 | },
85 | "cell_type": "code",
86 | "source": [
87 | "def pretty_print(result):\n",
88 | " df = pd.DataFrame([result]).T\n",
89 | " df.columns = [\"values\"]\n", 90 | " return df" 91 | ], 92 | "execution_count": 0, 93 | "outputs": [] 94 | }, 95 | { 96 | "metadata": { 97 | "trusted": true, 98 | "id": "ZCYOL8HbIZu2", 99 | "colab_type": "code", 100 | "colab": {} 101 | }, 102 | "cell_type": "code", 103 | "source": [ 104 | "def create_tokenizer_from_hub_module(bert_model_hub):\n", 105 | " \"\"\"Get the vocab file and casing info from the Hub module.\"\"\"\n", 106 | " with tf.Graph().as_default():\n", 107 | " bert_module = hub.Module(bert_model_hub)\n", 108 | " tokenization_info = bert_module(signature=\"tokenization_info\", as_dict=True)\n", 109 | " with tf.Session() as sess:\n", 110 | " vocab_file, do_lower_case = sess.run([tokenization_info[\"vocab_file\"],\n", 111 | " tokenization_info[\"do_lower_case\"]])\n", 112 | " \n", 113 | " return bert.tokenization.FullTokenizer(\n", 114 | " vocab_file=vocab_file, do_lower_case=do_lower_case)\n", 115 | "\n", 116 | "def make_features(dataset, label_list, MAX_SEQ_LENGTH, tokenizer, DATA_COLUMN, LABEL_COLUMN):\n", 117 | " input_example = dataset.apply(lambda x: bert.run_classifier.InputExample(guid=None, \n", 118 | " text_a = x[DATA_COLUMN], \n", 119 | " text_b = None, \n", 120 | " label = x[LABEL_COLUMN]), axis = 1)\n", 121 | " features = bert.run_classifier.convert_examples_to_features(input_example, label_list, MAX_SEQ_LENGTH, tokenizer)\n", 122 | " return features\n", 123 | "\n", 124 | "def create_model(bert_model_hub, is_predicting, input_ids, input_mask, segment_ids, labels,\n", 125 | " num_labels):\n", 126 | " \"\"\"Creates a classification model.\"\"\"\n", 127 | "\n", 128 | " bert_module = hub.Module(\n", 129 | " bert_model_hub,\n", 130 | " trainable=True)\n", 131 | " bert_inputs = dict(\n", 132 | " input_ids=input_ids,\n", 133 | " input_mask=input_mask,\n", 134 | " segment_ids=segment_ids)\n", 135 | " bert_outputs = bert_module(\n", 136 | " inputs=bert_inputs,\n", 137 | " signature=\"tokens\",\n", 138 | " as_dict=True)\n", 139 | "\n", 140 | " # Use \"pooled_output\" for classification tasks on an entire sentence.\n", 141 | " # Use \"sequence_outputs\" for token-level output.\n", 142 | " output_layer = bert_outputs[\"pooled_output\"]\n", 143 | "\n", 144 | " hidden_size = output_layer.shape[-1].value\n", 145 | "\n", 146 | " # Create our own layer to tune for politeness data.\n", 147 | " output_weights = tf.get_variable(\n", 148 | " \"output_weights\", [num_labels, hidden_size],\n", 149 | " initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 150 | "\n", 151 | " output_bias = tf.get_variable(\n", 152 | " \"output_bias\", [num_labels], initializer=tf.zeros_initializer())\n", 153 | "\n", 154 | " with tf.variable_scope(\"loss\"):\n", 155 | "\n", 156 | " # Dropout helps prevent overfitting\n", 157 | " output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)\n", 158 | "\n", 159 | " logits = tf.matmul(output_layer, output_weights, transpose_b=True)\n", 160 | " logits = tf.nn.bias_add(logits, output_bias)\n", 161 | " log_probs = tf.nn.log_softmax(logits, axis=-1)\n", 162 | "\n", 163 | " # Convert labels into one-hot encoding\n", 164 | " one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)\n", 165 | "\n", 166 | " predicted_labels = tf.squeeze(tf.argmax(log_probs, axis=-1, output_type=tf.int32))\n", 167 | " # If we're predicting, we want predicted labels and the probabiltiies.\n", 168 | " if is_predicting:\n", 169 | " return (predicted_labels, log_probs)\n", 170 | "\n", 171 | " # If we're train/eval, compute loss between predicted and actual label\n", 
172 | " per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)\n", 173 | " loss = tf.reduce_mean(per_example_loss)\n", 174 | " return (loss, predicted_labels, log_probs)\n", 175 | "\n", 176 | "# model_fn_builder actually creates our model function\n", 177 | "# using the passed parameters for num_labels, learning_rate, etc.\n", 178 | "def model_fn_builder(bert_model_hub, num_labels, learning_rate, num_train_steps,\n", 179 | " num_warmup_steps):\n", 180 | " \"\"\"Returns `model_fn` closure for TPUEstimator.\"\"\"\n", 181 | " def model_fn(features, labels, mode, params): # pylint: disable=unused-argument\n", 182 | " \"\"\"The `model_fn` for TPUEstimator.\"\"\"\n", 183 | "\n", 184 | " input_ids = features[\"input_ids\"]\n", 185 | " input_mask = features[\"input_mask\"]\n", 186 | " segment_ids = features[\"segment_ids\"]\n", 187 | " label_ids = features[\"label_ids\"]\n", 188 | "\n", 189 | " is_predicting = (mode == tf.estimator.ModeKeys.PREDICT)\n", 190 | " \n", 191 | " # TRAIN and EVAL\n", 192 | " if not is_predicting:\n", 193 | "\n", 194 | " (loss, predicted_labels, log_probs) = create_model(\n", 195 | " bert_model_hub, is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)\n", 196 | "\n", 197 | " train_op = bert.optimization.create_optimizer(\n", 198 | " loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)\n", 199 | "\n", 200 | " # Calculate evaluation metrics. \n", 201 | " def metric_fn(label_ids, predicted_labels):\n", 202 | " accuracy = tf.metrics.accuracy(label_ids, predicted_labels)\n", 203 | " f1_score = tf.contrib.metrics.f1_score(\n", 204 | " label_ids,\n", 205 | " predicted_labels)\n", 206 | " auc = tf.metrics.auc(\n", 207 | " label_ids,\n", 208 | " predicted_labels)\n", 209 | " recall = tf.metrics.recall(\n", 210 | " label_ids,\n", 211 | " predicted_labels)\n", 212 | " precision = tf.metrics.precision(\n", 213 | " label_ids,\n", 214 | " predicted_labels) \n", 215 | " true_pos = tf.metrics.true_positives(\n", 216 | " label_ids,\n", 217 | " predicted_labels)\n", 218 | " true_neg = tf.metrics.true_negatives(\n", 219 | " label_ids,\n", 220 | " predicted_labels) \n", 221 | " false_pos = tf.metrics.false_positives(\n", 222 | " label_ids,\n", 223 | " predicted_labels) \n", 224 | " false_neg = tf.metrics.false_negatives(\n", 225 | " label_ids,\n", 226 | " predicted_labels)\n", 227 | " return {\n", 228 | " \"eval_accuracy\": accuracy,\n", 229 | " \"f1_score\": f1_score,\n", 230 | " \"auc\": auc,\n", 231 | " \"precision\": precision,\n", 232 | " \"recall\": recall,\n", 233 | " \"true_positives\": true_pos,\n", 234 | " \"true_negatives\": true_neg,\n", 235 | " \"false_positives\": false_pos,\n", 236 | " \"false_negatives\": false_neg\n", 237 | " }\n", 238 | "\n", 239 | " eval_metrics = metric_fn(label_ids, predicted_labels)\n", 240 | "\n", 241 | " if mode == tf.estimator.ModeKeys.TRAIN:\n", 242 | " return tf.estimator.EstimatorSpec(mode=mode,\n", 243 | " loss=loss,\n", 244 | " train_op=train_op)\n", 245 | " else:\n", 246 | " return tf.estimator.EstimatorSpec(mode=mode,\n", 247 | " loss=loss,\n", 248 | " eval_metric_ops=eval_metrics)\n", 249 | " else:\n", 250 | " (predicted_labels, log_probs) = create_model(\n", 251 | " bert_model_hub, is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)\n", 252 | "\n", 253 | " predictions = {\n", 254 | " 'probabilities': log_probs,\n", 255 | " 'labels': predicted_labels\n", 256 | " }\n", 257 | " return tf.estimator.EstimatorSpec(mode, predictions=predictions)\n", 258 | "\n", 259 | " # 
Return the actual model function in the closure\n", 260 | " return model_fn\n", 261 | "\n", 262 | "def estimator_builder(bert_model_hub, OUTPUT_DIR, SAVE_SUMMARY_STEPS, SAVE_CHECKPOINTS_STEPS, label_list, LEARNING_RATE, num_train_steps, num_warmup_steps, BATCH_SIZE):\n", 263 | "\n", 264 | " # Specify outpit directory and number of checkpoint steps to save\n", 265 | " run_config = tf.estimator.RunConfig(\n", 266 | " model_dir=OUTPUT_DIR,\n", 267 | " save_summary_steps=SAVE_SUMMARY_STEPS,\n", 268 | " save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS)\n", 269 | "\n", 270 | " model_fn = model_fn_builder(\n", 271 | " bert_model_hub = bert_model_hub,\n", 272 | " num_labels=len(label_list),\n", 273 | " learning_rate=LEARNING_RATE,\n", 274 | " num_train_steps=num_train_steps,\n", 275 | " num_warmup_steps=num_warmup_steps)\n", 276 | "\n", 277 | " estimator = tf.estimator.Estimator(\n", 278 | " model_fn=model_fn,\n", 279 | " config=run_config,\n", 280 | " params={\"batch_size\": BATCH_SIZE})\n", 281 | " return estimator, model_fn, run_config\n" 282 | ], 283 | "execution_count": 0, 284 | "outputs": [] 285 | }, 286 | { 287 | "metadata": { 288 | "id": "IuMOGwFui4it", 289 | "colab_type": "code", 290 | "trusted": true, 291 | "colab": {} 292 | }, 293 | "cell_type": "code", 294 | "source": [ 295 | "def run_on_dfs(train, test, DATA_COLUMN, LABEL_COLUMN, \n", 296 | " MAX_SEQ_LENGTH = 128,\n", 297 | " BATCH_SIZE = 32,\n", 298 | " LEARNING_RATE = 2e-5,\n", 299 | " NUM_TRAIN_EPOCHS = 3.0,\n", 300 | " WARMUP_PROPORTION = 0.1,\n", 301 | " SAVE_SUMMARY_STEPS = 100,\n", 302 | " SAVE_CHECKPOINTS_STEPS = 10000,\n", 303 | " bert_model_hub = \"https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1\"):\n", 304 | "\n", 305 | " label_list = train[LABEL_COLUMN].unique().tolist()\n", 306 | " \n", 307 | " tokenizer = create_tokenizer_from_hub_module(bert_model_hub)\n", 308 | "\n", 309 | " train_features = make_features(train, label_list, MAX_SEQ_LENGTH, tokenizer, DATA_COLUMN, LABEL_COLUMN)\n", 310 | " test_features = make_features(test, label_list, MAX_SEQ_LENGTH, tokenizer, DATA_COLUMN, LABEL_COLUMN)\n", 311 | "\n", 312 | " num_train_steps = int(len(train_features) / BATCH_SIZE * NUM_TRAIN_EPOCHS)\n", 313 | " num_warmup_steps = int(num_train_steps * WARMUP_PROPORTION)\n", 314 | "\n", 315 | " estimator, model_fn, run_config = estimator_builder(\n", 316 | " bert_model_hub, \n", 317 | " OUTPUT_DIR, \n", 318 | " SAVE_SUMMARY_STEPS, \n", 319 | " SAVE_CHECKPOINTS_STEPS, \n", 320 | " label_list, \n", 321 | " LEARNING_RATE, \n", 322 | " num_train_steps, \n", 323 | " num_warmup_steps, \n", 324 | " BATCH_SIZE)\n", 325 | "\n", 326 | " train_input_fn = bert.run_classifier.input_fn_builder(\n", 327 | " features=train_features,\n", 328 | " seq_length=MAX_SEQ_LENGTH,\n", 329 | " is_training=True,\n", 330 | " drop_remainder=False)\n", 331 | "\n", 332 | " estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)\n", 333 | "\n", 334 | " test_input_fn = run_classifier.input_fn_builder(\n", 335 | " features=test_features,\n", 336 | " seq_length=MAX_SEQ_LENGTH,\n", 337 | " is_training=False,\n", 338 | " drop_remainder=False)\n", 339 | "\n", 340 | " result_dict = estimator.evaluate(input_fn=test_input_fn, steps=None)\n", 341 | " return result_dict, estimator\n", 342 | " " 343 | ], 344 | "execution_count": 0, 345 | "outputs": [] 346 | }, 347 | { 348 | "metadata": { 349 | "trusted": true, 350 | "id": "wpUHTA8LIZu_", 351 | "colab_type": "code", 352 | "colab": {} 353 | }, 354 | "cell_type": "code", 355 | "source": [ 356 | "import random\n", 
357 | "random.seed(10)" 358 | ], 359 | "execution_count": 0, 360 | "outputs": [] 361 | }, 362 | { 363 | "metadata": { 364 | "trusted": true, 365 | "id": "qmP8MBvjIZvB", 366 | "colab_type": "code", 367 | "colab": {} 368 | }, 369 | "cell_type": "code", 370 | "source": [ 371 | "OUTPUT_DIR = 'output'" 372 | ], 373 | "execution_count": 0, 374 | "outputs": [] 375 | }, 376 | { 377 | "metadata": { 378 | "id": "UV3ExI2JJMhd", 379 | "colab_type": "text" 380 | }, 381 | "cell_type": "markdown", 382 | "source": [ 383 | "----- you just need to focus from here ------" 384 | ] 385 | }, 386 | { 387 | "metadata": { 388 | "id": "XSUjVoFtJVEO", 389 | "colab_type": "text" 390 | }, 391 | "cell_type": "markdown", 392 | "source": [ 393 | "## Get your data" 394 | ] 395 | }, 396 | { 397 | "metadata": { 398 | "id": "pe6HHONmOwEs", 399 | "colab_type": "code", 400 | "colab": {} 401 | }, 402 | "cell_type": "code", 403 | "source": [ 404 | "!wget https://github.com/wshuyi/demo-chinese-text-binary-classification-with-bert/raw/master/dianping_train_test.pickle" 405 | ], 406 | "execution_count": 0, 407 | "outputs": [] 408 | }, 409 | { 410 | "metadata": { 411 | "trusted": true, 412 | "id": "XFBwDlmnIZvD", 413 | "colab_type": "code", 414 | "colab": {} 415 | }, 416 | "cell_type": "code", 417 | "source": [ 418 | "with open(\"dianping_train_test.pickle\", 'rb') as f:\n", 419 | " train, test = pickle.load(f)" 420 | ], 421 | "execution_count": 0, 422 | "outputs": [] 423 | }, 424 | { 425 | "metadata": { 426 | "trusted": true, 427 | "id": "Tju0c4dqIZvK", 428 | "colab_type": "code", 429 | "colab": {} 430 | }, 431 | "cell_type": "code", 432 | "source": [ 433 | "train = train.sample(len(train))" 434 | ], 435 | "execution_count": 0, 436 | "outputs": [] 437 | }, 438 | { 439 | "metadata": { 440 | "id": "p9BAxyKhKirc", 441 | "colab_type": "code", 442 | "colab": {} 443 | }, 444 | "cell_type": "code", 445 | "source": [ 446 | "train.head()" 447 | ], 448 | "execution_count": 0, 449 | "outputs": [] 450 | }, 451 | { 452 | "metadata": { 453 | "id": "hM8M7k0gKk5H", 454 | "colab_type": "code", 455 | "colab": {} 456 | }, 457 | "cell_type": "code", 458 | "source": [ 459 | "myparam = {\n", 460 | " \"DATA_COLUMN\": \"comment\",\n", 461 | " \"LABEL_COLUMN\": \"sentiment\",\n", 462 | " \"LEARNING_RATE\": 2e-5,\n", 463 | " \"NUM_TRAIN_EPOCHS\":3,\n", 464 | " \"bert_model_hub\":\"https://tfhub.dev/google/bert_chinese_L-12_H-768_A-12/1\"\n", 465 | " }" 466 | ], 467 | "execution_count": 0, 468 | "outputs": [] 469 | }, 470 | { 471 | "metadata": { 472 | "colab_type": "code", 473 | "id": "Dg2apeXpMqG-", 474 | "colab": {} 475 | }, 476 | "cell_type": "code", 477 | "source": [ 478 | "result, estimator = run_on_dfs(train, test, **myparam)" 479 | ], 480 | "execution_count": 0, 481 | "outputs": [] 482 | }, 483 | { 484 | "metadata": { 485 | "id": "YLOvtveTMqwp", 486 | "colab_type": "code", 487 | "colab": {} 488 | }, 489 | "cell_type": "code", 490 | "source": [ 491 | "pretty_print(result)" 492 | ], 493 | "execution_count": 0, 494 | "outputs": [] 495 | }, 496 | { 497 | "metadata": { 498 | "id": "D0OPTDcGMsJw", 499 | "colab_type": "code", 500 | "colab": {} 501 | }, 502 | "cell_type": "code", 503 | "source": [ 504 | "" 505 | ], 506 | "execution_count": 0, 507 | "outputs": [] 508 | } 509 | ] 510 | } -------------------------------------------------------------------------------- /dianping_train_test.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wshuyi/demo-chinese-text-binary-classification-with-bert/62d4a74ed4c91ea6196d55d0633e41bc7785fd84/dianping_train_test.pickle --------------------------------------------------------------------------------