├── .gitignore
├── LICENSE
├── README.md
├── data_demo
│   ├── pre_train_data
│   ├── std_data
│   ├── test_data
│   ├── train_data
│   ├── valid_data
│   └── vocab
├── dec_mining
│   ├── README.md
│   ├── data
│   │   ├── topk.std.text.avg
│   │   ├── trainset
│   │   └── vocab
│   ├── data_utils.py
│   ├── dataset.py
│   ├── dec_model.py
│   ├── images
│   │   ├── modify_1.gif
│   │   ├── modify_2.gif
│   │   ├── 冷启动流程图.png
│   │   ├── 轮廓系数公式.png
│   │   └── 迭代挖掘流程图.png
│   ├── inference.py
│   ├── print_sen_embedding.py
│   └── train.py
├── docs
│   ├── RUNDEMO.md
│   ├── dssm.png
│   ├── kg_demo.png
│   ├── lstm_dssm_bagging.png
│   ├── measurement.png
│   ├── pretrain.png
│   └── sptm.png
├── dssm_predict.py
├── lstm_predict.py
├── merge_classifier_match_label.py
├── models
│   ├── __init__.py
│   ├── bilstm.py
│   └── dssm.py
├── run_bi_lstm.py
├── run_dssm.py
├── sptm
│   ├── format_result.py
│   ├── models.py
│   ├── run_classifier.py
│   ├── run_prediction.py
│   ├── run_pretraining.py
│   └── utils.py
└── utils
    ├── __init__.py
    ├── classifier_utils.py
    └── match_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/*
2 | dec_mining/.idea/*
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2005-present, 58.com. All rights reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Project Introduction
2 | qa_match is a deep-learning-based question-answer matching tool that supports QA over one-level and two-level knowledge bases. For a one-level knowledge base, qa_match answers questions with an intent matching model; for a two-level knowledge base, it fuses the results of a domain classification model and an intent matching model. qa_match also supports unsupervised pre-training: a lightweight pre-trained language model (SPTM, Simple Pre-trained Model) can improve downstream tasks such as knowledge-base QA.
3 |
4 | ## Knowledge-base QA
5 | In practice, a knowledge base is usually built by manual summarization, labeling, and machine mining. It contains many standard questions; each standard question has one standard answer and several alternative phrasings, which we call extended questions. A one-level knowledge base contains only standard questions and their extended questions, and we call each standard question an intent. In a two-level knowledge base, every standard question (with its extended questions) additionally belongs to a category, which we call a domain; one domain contains multiple intents.
6 |
7 | qa_match supports the following knowledge-base structure:
8 |
9 | ![](docs/kg_demo.png)
10 |
11 | Given an input question, qa_match can produce three kinds of answers based on the knowledge base:
12 | 1. Unique answer (a specific user intent is recognized)
13 | 2. List answer (several possible user intents are recognized)
14 | 3. Rejection (no specific user intent is recognized)
15 |
16 | qa_match is used differently for the two knowledge-base structures, as described below:
17 |
18 | ### QA over a two-level knowledge base
19 |
20 | ![](docs/lstm_dssm_bagging.png)
21 |
22 | For a two-level knowledge base, qa_match first performs domain classification and intent recognition on the user question, then fuses the two results to obtain the user's real intent and answer accordingly (unique answer, list answer, or rejection).
23 | For example, as shown in the [knowledge-base structure figure](#knowledge-base-qa) above, suppose we have a two-level knowledge base with two domains, "信息" (information) and "账号" (account). The "信息" domain contains two intents, "如何发布信息" (how to post information) and "如何删除信息" (how to delete information); the "账号" domain contains the intent "如何注销账号" (how to cancel an account). When the user asks "我怎么发布帖子?" (how do I publish a post?), qa_match proceeds as follows:
24 |
25 | 1. Score the input question with the LSTM domain classification model and the DSSM intent matching model separately. Say the domain classifier's top score is 0.99 for the "信息" domain and the intent matcher's top score is 0.98 for the intent "如何发布信息". Since the domain classifier's top label is a concrete domain, the flow enters the "classified as some domain" branch.
26 | 2. In that branch, compare the domain classifier's top score 0.99 with the threshold b1 in the two-level KB QA figure (e.g. b1 = 0.9). Because 0.99 >= b1, the flow takes the "strict DSSM intent matching" sub-branch.
27 | 3. In the "strict DSSM intent matching" branch, compare the intent matcher's top score 0.98 with the thresholds x1 (e.g. x1 = 0.8) and x2 (e.g. x2 = 0.95). Since 0.98 > x2, the answer of "如何发布信息" is returned as a unique answer (the other branches behave analogously).
28 |
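To make the branch walked through above concrete, here is a minimal sketch of the threshold logic. The function name and default values for b1/x1/x2 are illustrative assumptions taken from the example; this is not the actual merge_classifier_match_label.py implementation.

```python
# Illustrative sketch of the fusion logic described above; thresholds and names
# are assumptions, not the real merge_classifier_match_label.py code.
def fuse_two_level(domain_score, intent_label, intent_score, b1=0.9, x1=0.8, x2=0.95):
    """Combine the domain classifier's top score with the intent matcher's top score."""
    if domain_score >= b1:
        # confident domain -> strict DSSM intent matching
        if intent_score > x2:
            return "unique", intent_label   # single answer
        if intent_score >= x1:
            return "list", intent_label     # list answer (candidate intents)
        return "reject", None
    # For a less confident domain the real flow uses further thresholds (e.g. a1)
    # and the candidate-intent lists; this sketch only covers the branch above.
    return "reject", None

# Example from the text: domain score 0.99 (>= b1), intent score 0.98 (> x2)
print(fuse_two_level(0.99, "如何发布信息", 0.98))  # -> ('unique', '如何发布信息')
```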
29 | ### QA over a one-level knowledge base
30 |
31 | In practice we also encounter QA over one-level knowledge bases. Both the DSSM intent matching model and the SPTM lightweight pre-trained language model can handle this case; they compare as follows:
32 |
33 | | Model | How it is used | Pros | Cons |
34 | | ------------------------ | ---------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
35 | | DSSM intent matching model | Direct matching with the DSSM model | ① Easy to use, small model footprint<br>② Fast training/prediction | Cannot exploit contextual information |
36 | | SPTM lightweight pre-trained language model | Pre-train an LSTM/Transformer language model<br>+ fine-tune an LSTM/Transformer matching model | ① Makes full use of unsupervised pre-training data to improve results<br>② The language model can be reused for other downstream tasks | ① Pre-training needs a large amount of unlabeled data<br>② More involved workflow (two steps are needed to obtain the matching model) |
37 |
38 | #### QA with the DSSM model
39 | ![](docs/dssm.png)
40 |
41 | For a one-level knowledge base, only the DSSM intent matching model scores the input question; its top matching score is compared with x1 and x2 in the figure above to decide the answer type (unique answer, list answer, or rejection).
42 |
43 | #### QA with the SPTM model
44 |
45 | ##### The lightweight pre-trained language model (SPTM, Simple Pre-trained Model)
46 |
47 | Real applications usually have plenty of unlabeled data, so when labeled knowledge-base data is limited, an unsupervised pre-trained language model can improve the matching model. Following the pre-training procedure of [BERT](https://github.com/google-research/bert), we developed SPTM in May 2019. Compared with BERT it makes three main changes: it drops NSP (Next Sentence Prediction), whose benefit was not evident; it replaces the Transformer with an LSTM to speed up online inference; and, to keep model quality while reducing the parameter count, it also provides a Transformer whose parameters are shared across blocks. The model works as follows:
48 |
49 | ###### Data preprocessing
50 |
51 | For pre-training, the training data is built from unlabeled single sentences, with samples constructed in the BERT style: each sentence is one sample, 15% of its characters are selected for prediction, and of those, 80% are masked, 10% are replaced with a random character from the vocabulary, and 10% are left unchanged.
52 |
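A minimal sketch of the masking rule above; the mask token name `[MASK]` and the helper name are assumptions, and the real sample construction lives in sptm/run_pretraining.py.

```python
import random

# Sketch of the 15% / 80% / 10% / 10% rule described above (assumed names).
def mask_sentence(chars, vocab, select_rate=0.15, mask_token="[MASK]"):
    masked, targets = list(chars), {}
    n_pred = max(1, int(round(len(chars) * select_rate)))
    for pos in random.sample(range(len(chars)), n_pred):
        targets[pos] = chars[pos]               # the model must predict the original char
        dice = random.random()
        if dice < 0.8:
            masked[pos] = mask_token            # 80%: replace with the mask token
        elif dice < 0.9:
            masked[pos] = random.choice(vocab)  # 10%: replace with a random vocab char
        # remaining 10%: keep the original character unchanged
    return masked, targets

print(mask_sentence(list("如何删除信息"), vocab=["发", "布", "帖", "子"]))
```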
53 | ###### Pre-training
54 |
55 | The model structure used in the pre-training stage is shown below:
56 |
57 |
58 |
59 | ![](docs/pretrain.png)
60 |
61 | To strengthen the model's expressiveness and preserve more shallow-layer information, a residual Bi-LSTM network (Residual LSTM) is used as the model body. Each Bi-LSTM layer's input and output are summed and normalized, and the result is fed into the next layer. In addition, the output of the last Bi-LSTM layer is fed into a fully connected layer; the sum of that layer's input and output is normalized and used as the output of the whole network.
62 |
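A minimal sketch of the residual pattern described above, assuming the TF 1.x API used elsewhere in this repo; function and scope names are illustrative, and the actual model code is in sptm/models.py.

```python
import tensorflow as tf  # TF 1.x APIs, matching the rest of this repo

def residual_bilstm(inputs, lstm_dim, layer_num):
    """Sketch only: each layer's input and output are summed and layer-normalized,
    and the last layer gets a dense residual block. `inputs` is assumed to be
    [batch, seq_len, 2 * lstm_dim] so shapes match for the residual sum."""
    x = inputs
    for i in range(layer_num):
        with tf.variable_scope("res_bilstm_%d" % i):
            fw = tf.nn.rnn_cell.LSTMCell(lstm_dim)
            bw = tf.nn.rnn_cell.LSTMCell(lstm_dim)
            (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(fw, bw, x, dtype=tf.float32)
            out = tf.concat([out_fw, out_bw], axis=-1)   # [batch, seq_len, 2 * lstm_dim]
            x = tf.contrib.layers.layer_norm(x + out)    # residual sum + normalization
    dense = tf.layers.dense(x, x.shape[-1].value)        # final fully connected layer
    return tf.contrib.layers.layer_norm(x + dense)       # residual over the dense output
```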
63 | Sample pre-training costs are shown below:
64 |
65 | | **Metric** | **Value** | **Value** | **Value** |
66 | | ---------------- | -------------------------------- | ----------------------------------------------- | ----------------------------------------------- |
67 | | Model structure | LSTM | Transformer with shared parameters | Transformer with shared parameters |
68 | | Pre-training data size | 10 million | 10 million | 10 million |
69 | | Pre-training resources | 10 Nvidia K40 GPUs / 12G memory | 10 Nvidia K40 GPUs / 12G memory | 10 Nvidia K40 GPUs / 12G memory |
70 | | Pre-training parameters | steps = 100000 / batch size = 128 | steps = 100000 / batch size = 128 / 1 layer / 12 heads | steps = 100000 / batch size = 128 / 12 layers / 12 heads |
71 | | Pre-training time | 8.9 hours | 13.5 hours | 32.9 hours |
72 | | Pre-trained model size | 81M | 80.6M | 121M |
73 |
74 | ##### SPTM-based QA flow
75 |
76 | ![](docs/sptm.png)
77 |
78 | With SPTM, QA over a one-level knowledge base first scores the input question with an intent matching model fine-tuned from the language model, and then decides the answer type (unique answer, list answer, or rejection) with the same strategy as the DSSM intent matching model.
79 |
80 | ## How to use
81 | ### Data
82 | The required data files (under the data_demo folder) are described below. To avoid leaking data, the original texts of standard and extended questions have been encoded here; in a real application just prepare data in the following formats.
83 | - std_data: mapping between categories and standard questions, with three columns: category ID, standard question ID, standard question text
84 | - pre_train_data: unlabeled pre-training data, one piece of text per line
85 | - vocab: vocabulary for the pre-training data, one token per line (it must include the required special tokens)
86 | - train_data: training set, with three columns: standard question ID, extended question ID, extended question text
87 | - valid_data: validation set, with three columns: standard question ID, extended question ID, extended question text
88 | - test_data: test set, with three columns: standard question ID, extended question ID, extended question text
89 |
90 | Columns are separated by \t; encoded questions are space-separated, i.e. characters are separated by spaces. Note that in this project's sample data the original text has been encoded by replacing each character with a number, e.g. `205 19 90 417 41 44` corresponds to the actual text `如何删除信息`; **this encoding step is not needed in real use**. If the knowledge base has a one-level structure, set all category IDs in the std_data file to `__label__0`.
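For illustration, a small sketch (not part of the repo) of reading one of the tab-separated files described above; field names follow the description in this section.

```python
# Sketch: read train_data / valid_data / test_data in the format described above.
def read_train_data(path):
    samples = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            std_id, ext_id, text = line.rstrip("\n").split("\t")
            tokens = text.split()   # characters (or encoded characters) are space-separated
            samples.append((std_id, ext_id, tokens))
    return samples
```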
91 |
92 | ### Semi-automatic knowledge-base mining
93 |
94 | The semi-automatic knowledge-base mining flow is built on top of qa_match's automatic QA flow (see "QA over a one-level knowledge base" above). It helps grow the knowledge base and improve its quality, which both strengthens online matching and improves the quality of the offline training data, and therefore model performance. It can be used in two scenarios: cold-start mining and iterative mining after the model is online. See the [knowledge-base mining documentation](./dec_mining/README.md) for details.
95 |
96 | ### How to run
97 | See the [run guide](docs/RUNDEMO.md).
98 |
99 | ### Tips
100 | 1. When the DSSM model picks negative samples during training, it randomly shuffles the labels of the original samples, so the parameters must satisfy `batch_size >= negative_size`; otherwise the model cannot be trained effectively.
101 | 2. How the fusion parameters are chosen: the current values are chosen statistically. For a given parameter (e.g. a1 in the two-level KB QA figure), compute on the test set the F1 of the corresponding model label (e.g. rejection) for each candidate value, then pick the candidate with the better F1. For example, to choose the final value of a1, first obtain the model labels (reject / non-reject) for each candidate value of a1 on the test set, compute F1 against the true labels, and pick the candidate whose F1 best fits the project's needs (favoring precision or recall as required).
102 |
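A small sketch of the selection procedure in tip 2, assuming binary reject / non-reject labels and scikit-learn's f1_score; the function name and candidate grid are illustrative.

```python
from sklearn.metrics import f1_score

# Sweep candidate values of a threshold (e.g. a1) and keep the one with the best F1
# on a held-out set; `scores` are model scores, `gold` the true answer/reject labels.
def pick_threshold(scores, gold, candidates):
    best_t, best_f1 = None, -1.0
    for t in candidates:
        pred = ["reject" if s < t else "answer" for s in scores]
        f1 = f1_score(gold, pred, pos_label="reject")
        if f1 > best_f1:
            best_t, best_f1 = t, f1
    return best_t, best_f1

print(pick_threshold([0.2, 0.7, 0.95, 0.4], ["reject", "answer", "answer", "reject"],
                     [i / 20 for i in range(1, 20)]))
```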
103 | ## Runtime environment
104 | ```
105 | tensorflow version > r1.8
--------------------------------------------------------------------------------
/dec_mining/README.md:
--------------------------------------------------------------------------------
 49 | ...samples whose cluster-assignment probability exceeds the threshold are selected as extended questions and submitted for manual review before being added to the knowledge base.
50 |
51 |
52 |
53 | ## Scenario 2: iterative mining
54 |
55 |
56 |
57 | Iterative knowledge-base mining applies after the model is online: the knowledge base already contains a fair number of standard and extended questions, but online traffic keeps changing, so there are always standard questions and phrasings the model does not yet cover. The goal of iterative mining is to discover them in time, increase coverage of online samples, and thereby improve the model's precision and recall.
58 |
59 |
60 |
61 | The iterative mining flow is shown below:
62 |
63 |
64 |
65 | ![](images/迭代挖掘流程图.png)
66 |
67 |
68 |
69 | ### Iterative mining steps
70 |
71 |
72 |
73 | 1. Based on the current automatic QA flow (see [qa_match QA over a one-level knowledge base](https://github.com/wuba/qa_match/tree/master#%E5%9F%BA%E4%BA%8E%E4%B8%80%E5%B1%82%E7%BB%93%E6%9E%84%E7%9F%A5%E8%AF%86%E5%BA%93%E7%9A%84%E8%87%AA%E5%8A%A8%E9%97%AE%E7%AD%94)), extract new knowledge from online rejected questions and from newly categorized questions sampled and labeled manually each week (questions that the current standard questions do not cover and that are not rejections).
74 |
75 | 2. Roughly filter out a few kinds of questions:
76 |
77 | a) Very short queries (online questions shorter than 3 characters; this category is optional and depends on the scenario). Such queries are usually rejected in QA, and when not rejected they are mostly handled by matching.
78 |
79 | b) High-frequency questions. These must be covered, so they do not need mining; they are filtered out directly and sent for manual review before being added to the knowledge base. The remaining questions are fed into the DEC module for mining.
80 |
81 | 3. After this initial filtering leaves relatively clean queries, use the existing standard questions as custom cluster centers and select samples whose cluster-assignment probability exceeds the threshold as extended questions, to be reviewed manually and then added to the knowledge base. To mine standard questions for new categories, follow the cold-start approach.
82 |
83 | # Run guide
84 |
85 |
86 |
87 | This demo shows how to run the SPTM-representation-based DEC mining algorithm for a one-level knowledge base, together with the evaluation metrics and results on the test set.
88 |
89 |
90 |
91 | ## Data
92 |
93 |
94 |
95 | The required data files (under the dec_mining/data folder) are described below. To avoid leaking data, the original texts of standard and extended questions have been encoded; in a real application just prepare data in these formats. The clustering dataset comes from 58's production intelligent-QA traffic; only a small portion is included here, just enough to run the model, with about 10K samples to be clustered.
96 |
97 |
98 |
99 | * [trainset](./data/trainset): the data to be clustered. Two tab-separated columns: the first is the ground-truth question ID label in the format `__label__n` (if there is no ground-truth label, still set some n as a placeholder and ignore it when inspecting the clustering results); the second is the standard question text, split into characters (space-separated).
100 | * [topk.std.text.avg](./data/topk.std.text.avg): the custom cluster-center file. It contains standard questions summarized from the top-k questions, one cluster center per line. A center may be the average of several questions, written as slash-separated text, e.g. "你好这辆车还在么/车还在吗/这车还在吗" (see the sketch below).
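As noted above, each line of the center file is split on `/` and the member sentences' SPTM embeddings are averaged into one center vector (print_sen_embedding.py does this with the pre-trained model). A tiny illustration with a stand-in embedding function:

```python
import numpy as np

# `embed` stands in for the SPTM sentence-embedding step; it is a placeholder here.
def center_vector(line, embed):
    variants = [v for v in line.strip().split("/") if v]
    return np.mean([embed(v) for v in variants], axis=0)

print(center_vector("你好这辆车还在么/车还在吗/这车还在吗",
                    embed=lambda s: np.array([len(s), 1.0])))
```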
101 |
102 |
103 |
104 | ## Running the example
105 |
106 |
107 |
108 | Clustering with the DEC algorithm takes two steps: first train, i.e. fine-tune the representation, then run inference to obtain the clustering result.
109 |
110 |
111 |
112 |
113 | (1) Compute representations for the custom cluster-center texts. This step is optional: if you initialize with K-means instead, skip it and just set `n_clusters` in step (2).
114 |
115 |
116 | ```bash
117 | cd dec_mining && python3 print_sen_embedding.py --input_file=./topk.std.text.avg --vocab_file=./vocab --model_path=./pretrain_model/lm_pretrain.ckpt-1000000 --batch_size=512 --max_seq_len=50 --output_file=./topk.std.embedding.avg --embedding_way=max
118 | ```
119 |
120 | Parameters:
121 |
122 | input_file: custom cluster-center text file.
123 |
124 | vocab_file: vocabulary file (must include the required special tokens)
125 |
126 | model_path: SPTM pre-trained representation model; its embedding dimension must match the embedding_dim parameter in step (2)
127 |
128 | max_seq_len: maximum (truncation) length
129 |
130 | output_file: output file with the custom cluster-center representations
131 |
132 |
133 |
134 |
135 | (2) Fine-tune the representation with the data to be clustered
136 |
137 | ```bash
138 | cd dec_mining && python3 ./train.py --init_checkpoint=./pretrain_model/lm_pretrain.ckpt-1000000 --train_file=./data/trainset --epochs=30 --lstm_dim=128 --embedding_dim=256 --vocab_file=./vocab --external_cluster_center=./topk.std.embedding.avg --model_save_dir=./saved_model --learning_rate=0.03 --warmup_steps=5000
139 | ```
140 |
141 |
142 |
143 | Parameters:
144 |
145 |
146 |
147 | init_checkpoint: SPTM pre-trained representation model
148 |
149 | train_file: data to be clustered
150 |
151 | epochs: number of epochs
152 |
153 | n_clusters: number of cluster centers for K-means initialization. Pass either this or external_cluster_center (not both), and use the same choice in the inference step (3): n_clusters means K-means initialization, external_cluster_center means initialization from custom cluster centers
154 |
155 | lstm_dim: number of LSTM units in SPTM
156 |
157 | embedding_dim: SPTM word-embedding dimension; must be set to 2 x lstm_dim
158 |
159 | vocab_file: vocabulary file (must include the required special tokens)
160 |
161 | external_cluster_center: custom cluster-center file. Pass either this or n_clusters (not both), and use the same choice in the inference step (3): n_clusters means K-means initialization, external_cluster_center means initialization from custom cluster centers
162 |
163 | model_save_dir: where to save the DEC model
164 |
165 | learning_rate: learning rate
166 |
167 |
168 | warmup_steps: number of learning-rate warm-up steps
169 |
170 |
171 | (3) Run DEC clustering on the data to be clustered, using the fine-tuned representation
172 |
173 |
174 |
175 | ```bash
176 | cd dec_mining && python3 inference.py --model_path=./saved_model/finetune.ckpt-0 --train_file=./data/trainset --external_cluster_center=./topk.std.embedding.avg --lstm_dim=128 --embedding_dim=256 --vocab_file=./vocab --pred_score_path=./pred_score
177 | ```
178 |
179 |
180 |
181 | Parameters:
182 |
183 |
184 |
185 | model_path: the DEC model produced by the train step above
186 |
187 | train_file: data to be clustered
188 |
189 | lstm_dim: number of LSTM units in SPTM
190 |
191 | embedding_dim: SPTM word-embedding dimension
192 |
193 | vocab_file: vocabulary file (must include the required special tokens)
194 |
195 | pred_score_path: clustering score file, with the format `pred_label + \t + question + \t + groundtruth_label + \t + probability`, e.g. `__label__4	请添加车主阿涛微信详谈	__label__0	0.9888488`
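For example, assuming the four tab-separated fields described above, a small sketch (not part of the repo) that keeps only high-confidence cluster members for manual review, as used in the mining steps:

```python
# Sketch: keep cluster members whose assignment probability exceeds a threshold.
def filter_pred_score(path, threshold=0.9):
    kept = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            pred_label, question, gt_label, prob = line.rstrip("\n").split("\t")
            if float(prob) > threshold:
                kept.append((pred_label, question))
    return kept
```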
196 |
197 |
198 |
199 | ## Evaluation metrics and test-set results
200 |
201 |
202 |
203 | Clustering is usually evaluated externally or internally. External evaluation uses supervised labels when the dataset has ground-truth labels; internal evaluation does not rely on external trusted labels but judges the unsupervised result itself, following the principle of small intra-cluster and large inter-cluster distances. Here we use the silhouette coefficient for internal evaluation.
204 |
205 |
206 |
207 | **Silhouette coefficient**
208 |
209 |
210 |
211 | ![](images/轮廓系数公式.png)
212 |
213 |
214 |
215 | * a(i) = the average distance from sample i to all other points in its own cluster
216 |
217 | * b(i) = the minimum, over all other clusters, of the average distance from sample i to the points in that cluster
218 |
219 | * The value lies in [-1, 1]; the closer to 1, the better both cohesion and separation
220 |
221 | * Averaging the silhouette over all points gives the overall silhouette of the clustering result
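The repo's inference.py computes exactly this with scikit-learn; in essence:

```python
from sklearn.metrics import silhouette_score, silhouette_samples

# `embeddings` are the fine-tuned representations z, `labels` the predicted cluster ids.
def report_silhouette(embeddings, labels):
    avg = silhouette_score(embeddings, labels)        # overall clustering quality in [-1, 1]
    per_sample = silhouette_samples(embeddings, labels)
    return avg, per_sample
```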
222 |
223 |
224 |
225 | **Accuracy computation**
226 |
227 | Cluster accuracy is computed via Hungarian matching with [sklearn linear_assignment](https://www.kite.com/python/docs/sklearn.utils.linear_assignment_.linear_assignment) (see cluster_acc in dec_model.py).
228 |
229 |
230 |
231 | After summarizing the top-10 questions, the evaluation results are as follows (models were deployed with the open-source general deep-learning inference service [dl_inference](https://github.com/wuba/dl_inference) to measure inference time):
232 |
233 |
234 |
235 | | Dataset | Model | **Silhouette** | **Runtime** | **Inference Time** | **Accuracy** |
236 | | ------ | ---- | -------------- | ----------- |------------ | ------------ |
237 | | 10K | DEC | 0.7962 | 30min | 52s | 0.8437 |
238 | | 100K | DEC | 0.9302 | 3h 5min | 5min 55s | -- |
239 | | 1M | DEC | 0.849 | 11h 30min | 15min 28s | -- |
240 |
241 | **Tips:**
242 |
243 | 1. The labeled dataset in our experiments has fewer than 100K samples, so no accuracy value is reported for the 100K and 1M datasets
244 |
245 |
246 |
247 | ## Runtime environment
248 |
249 |
250 |
251 | ```
252 | tensorflow version > r1.8
--------------------------------------------------------------------------------
/dec_mining/data_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | data utils
4 | """
5 |
6 | import os
7 | import sys
8 | import codecs
9 | import collections
10 | import tensorflow as tf
11 | import numpy as np
12 | import six
13 |
14 |
15 |
16 | def convert_to_unicode(text):
17 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
18 | if six.PY3:
19 | if isinstance(text, str):
20 | return text
21 | elif isinstance(text, bytes):
22 | return text.decode("utf-8", "ignore")
23 | else:
24 | raise ValueError("Unsupported string type: %s" % (type(text)))
25 | elif six.PY2:
26 | if isinstance(text, str):
27 | return text.decode("utf-8", "ignore")
28 | elif isinstance(text, unicode):
29 | return text
30 | else:
31 | raise ValueError("Unsupported string type: %s" % (type(text)))
32 | else:
33 | raise ValueError("Not running on Python2 or Python 3?")
34 |
35 | class Sentence:
36 | def __init__(self, raw_tokens, raw_label):
37 | self.raw_tokens = raw_tokens
38 | self.raw_label = raw_label
39 | self.label_id = None
40 | self.token_ids = []
41 |
42 | def to_ids(self, word2id, label2id, max_len):
43 | self.label_id = label2id[self.raw_label]
44 | self.raw_tokens = self.raw_tokens[:max_len] # cut off to the max length
45 | all_unk = True
46 | for raw_token in self.raw_tokens:
47 | if raw_token not in ['', '', '', '', '', '', '']:
48 | raw_token = raw_token.lower()
49 | if raw_token in word2id:
50 | self.token_ids.append(word2id[raw_token])
51 | all_unk = False
52 | else:
53 | self.token_ids.append(word2id[""])
54 | if all_unk:
55 | tf.logging.info("all unk" + self.raw_tokens)
56 |
57 | self.token_ids = self.token_ids + [0] * (max_len - len(self.token_ids))
58 |
59 | def gen_ids(sens, word2id, label2id, max_len):
60 | for sen in sens:
61 | sen.to_ids(word2id, label2id, max_len)
62 |
63 | # convert dataset to tensor
64 | def make_full_tensors(sens):
65 | tokens = np.zeros((len(sens), len(sens[0].token_ids)), dtype=np.int32)
66 | labels = np.zeros((len(sens)), dtype=np.int32)
67 | length = np.zeros((len(sens)), dtype=np.int32)
68 | for idx, sen in enumerate(sens):
69 | tokens[idx] = sen.token_ids
70 | labels[idx] = sen.label_id
71 | length[idx] = len(sen.raw_tokens)
72 | return tokens, length, labels
73 |
74 | def gen_batchs(full_tensors, batch_size, is_shuffle):
75 | tokens, length, labels = full_tensors  # make_full_tensors returns (tokens, length, labels)
76 | # per = np.array([i for i in range(len(tokens))])
77 | per = np.array(list(range(len(tokens))))
78 | if is_shuffle:
79 | np.random.shuffle(per)
80 |
81 | cur_idx = 0
82 | token_batch = []
83 | label_batch = []
84 | length_batch = []
85 | while cur_idx < len(tokens):
86 | token_batch.append(tokens[per[cur_idx]])
87 | label_batch.append(labels[per[cur_idx]])
88 | length_batch.append(length[per[cur_idx]])
89 | if len(token_batch) == batch_size or cur_idx == len(tokens) - 1:
90 | yield token_batch, label_batch, length_batch
91 | token_batch = []
92 | label_batch = []
93 | length_batch = []
94 | cur_idx += 1
95 |
96 | def load_sentences(file_path, skip_invalid):
97 | sens = []
98 | invalid_num = 0
99 | max_len = 0
100 | for raw_l in codecs.open(file_path, 'r', 'utf-8'): # load as utf-8 encoding.
101 | if raw_l.strip() == "":
102 | continue
103 | file_s = raw_l.rstrip().split('\t')
104 | assert len(file_s) == 2
105 | tokens = file_s[1].split() # discard empty strings
106 | for token in tokens:
107 | assert token != ""
108 | label = file_s[0]
109 | if skip_invalid:
110 | if label.find(',') >= 0 or label.find('NONE') >= 0:
111 | invalid_num += 1
112 | continue
113 | if len(tokens) > max_len:
114 | max_len = len(tokens)
115 | sens.append(Sentence(tokens, label))
116 | tf.logging.info("invalid sen num : " + str(invalid_num))
117 | tf.logging.info("valid sen num : " + str(len(sens)))
118 | tf.logging.info("max_len : " + str(max_len))
119 | return sens
120 |
121 | def load_vocab(sens, vocab_file):
122 | label2id = {}
123 | id2label = {}
124 | for sen in sens:
125 | if sen.raw_label not in label2id:
126 | label2id[sen.raw_label] = len(label2id)
127 | id2label[len(id2label)] = sen.raw_label
128 |
129 | index = 0
130 | word2id = collections.OrderedDict()
131 | id2word = collections.OrderedDict()
132 | for l_raw in codecs.open(vocab_file, 'r', 'utf-8'):
133 | token = convert_to_unicode(l_raw)
134 | # if not token:
135 | # break
136 | token = token.strip()
137 | word2id[token] = index
138 | # id2word[index] = token
139 | index += 1
140 |
141 | for k, value in word2id.items():
142 | id2word[value] = k
143 |
144 | assert len(word2id) == len(id2word)
145 | tf.logging.info("token num : " + str(len(word2id)))
146 | tf.logging.info("label num : " + str(len(label2id)))
147 | tf.logging.info("labels: " + str(id2label))
148 | return word2id, id2word, label2id, id2label
149 |
150 | def evaluate(sess, full_tensors, args, model):
151 | total_num = 0
152 | right_num = 0
153 | for batch_data in gen_batchs(full_tensors, args.batch_size, is_shuffle=False):
154 | softmax_re = sess.run(model.softmax_op,
155 | feed_dict={model.ph_dropout_rate: 0,
156 | model.ph_tokens: batch_data[0],
157 | model.ph_labels: batch_data[1],
158 | model.ph_length: batch_data[2]})
159 | pred_re = np.argmax(softmax_re, axis=1)
160 | total_num += len(pred_re)
161 | right_num += np.sum(pred_re == batch_data[1])
162 | acc = 1.0 * right_num / (total_num + 1e-5)
163 |
164 | tf.logging.info("dev total num: " + str(total_num) + ", right num: " + str(right_num) + ", acc: " + str(acc))
165 | return acc
166 |
167 | def load_spec_centers(path):
168 | raw_f = open(path, "r", encoding="utf-8")
169 | f_lines = raw_f.readlines()
170 |
171 | res = []
172 | for line in f_lines:
173 | vec = [float(i) for i in line.strip().split(" ")]
174 | res.append(vec)
175 | return tf.convert_to_tensor(res), len(res)
176 |
177 | def write_file(out_path, out_str):
178 | exists = os.path.isfile(out_path)
179 | if exists:
180 | os.remove(out_path)
181 | tf.logging.info("File Removed!")
182 |
183 | raw_f = open(out_path, "w", encoding="utf-8")
184 | raw_f.write(out_str)
185 | raw_f.close()
186 |
187 | def load_vocab_file(vocab_file):
188 | word2id = {}
189 | id2word = {}
190 | for raw_l in codecs.open(vocab_file, 'r', 'utf8'):
191 | raw_l = raw_l.strip()
192 | assert raw_l != ""
193 | assert raw_l not in word2id
194 | word2id[raw_l] = len(word2id)
195 | id2word[len(id2word)] = raw_l
196 | tf.logging.info("uniq token num : " + str(len(word2id)) + "\n")
197 | return word2id, id2word
198 |
--------------------------------------------------------------------------------
/dec_mining/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | dataset class
4 | """
5 | import random
6 | import math
7 | import numpy as np
8 | import data_utils
9 |
10 |
11 | class Dataset():
12 | def __init__(self, train_x=None, train_y=None, test_x=None,
13 | test_y=None, train_length_x=None, test_length_x=None):
14 | self.train_x = train_x
15 | self.train_length_x = train_length_x
16 | self.test_length_x = test_length_x
17 | self.train_y = train_y
18 | self.test_x = test_x
19 | self.test_y = test_y
20 |
21 | def gen_next_batch(self, batch_size, is_train_set, epoch=None, iteration=None):
22 | if is_train_set is True:
23 | raw_x = self.train_x
24 | x_length = self.train_length_x
25 | raw_y = self.train_y
26 | else:
27 | raw_x = self.test_x
28 | x_length = self.test_length_x
29 | raw_y = self.test_y
30 |
31 | assert len(raw_x) >= batch_size,\
32 | "batch size must be smaller than data size {}.".format(len(raw_x))
33 |
34 | if epoch is not None:
35 | until = math.ceil(float(epoch * len(raw_x)) / float(batch_size))
36 | elif iteration is not None:
37 | until = iteration
38 | else:
39 | assert False, "epoch or iteration must be set."
40 |
41 | iter_ = 0
42 | index_list = list(range(len(raw_x)))
43 | while iter_ <= until:
44 | idxs = random.sample(index_list, batch_size)
45 | iter_ += 1
46 | yield (raw_x[idxs], raw_y[idxs], idxs, x_length[idxs])
47 |
48 |
49 | class ExpDataset(Dataset):
50 | def __init__(self, args):
51 | super().__init__()
52 |
53 | train_file = args.train_file
54 | vocab_file = args.vocab_file
55 |
56 | train_sens = data_utils.load_sentences(train_file, skip_invalid=True)
57 | word2id, id2word, label2id, id2label = data_utils.load_vocab(train_sens, vocab_file)
58 |
59 | data_utils.gen_ids(train_sens, word2id, label2id, 100)
60 | train_full_tensors = data_utils.make_full_tensors(train_sens)
61 |
62 | raw_x = train_full_tensors[0]
63 | x_length = train_full_tensors[1]
64 | x_labels = train_full_tensors[2]
65 |
66 | raw_f = lambda t: id2label[t]
67 | x_labels_true = np.array(list(map(raw_f, x_labels)))
68 |
69 | n_train = int(len(raw_x) * 1)
70 | self.train_x, self.test_x = raw_x[:n_train], raw_x[n_train:]
71 | self.train_length_x, self.test_length_x = x_length[:n_train], x_length[n_train:]
72 | self.train_y, self.test_y = x_labels[:n_train], x_labels[n_train:]
73 | self.gt_label = x_labels_true
74 | self.raw_q = ["".join(i.raw_tokens) for i in train_sens]
75 |
--------------------------------------------------------------------------------
/dec_mining/dec_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | modified DEC model
4 | """
5 | import collections
6 | import re
7 | import sys
8 | import tensorflow as tf
9 | import numpy as np
10 | from sklearn.cluster import KMeans
11 | from sklearn.utils.linear_assignment_ import linear_assignment
12 | sys.path.append("..")
13 | from sptm import models
14 | import data_utils
15 |
16 |
17 | """ modified DEC model
18 | References
19 | ----------
20 | https://github.com/HaebinShin/dec-tensorflow/blob/master/dec/model.py
21 | """
22 | class DEC:
23 | def __init__(self, params, other_params):
24 | # pass n_clusters or external_cluster_center
25 | if params.external_cluster_center == "":
26 | self.n_cluster = params.n_clusters
27 | self.kmeans = KMeans(n_clusters=params.n_clusters, n_init=10)
28 | else:
29 | # get cluster_center_embedding, n_cluster from external file
30 | self.external_cluster_center_vec, self.n_cluster = \
31 | data_utils.load_spec_centers(params.external_cluster_center)
32 |
33 | self.embedding_dim = params.embedding_dim
34 | # load SPTM pretrained model
35 | model = models.create_finetune_classification_model(params, other_params)
36 | self.pretrained_model = model
37 | self.alpha = params.alpha
38 |
39 | # mu: cluster center
40 | self.mu = tf.Variable(tf.zeros(shape=(self.n_cluster, self.embedding_dim)),
41 | name="mu") # [n_class, emb_dim]
42 |
43 | self.z = model.max_pool_output # [batch, emb_dim]
44 | with tf.name_scope("distribution"):
45 | self.q = self._soft_assignment(self.z, self.mu) # [, n_class]
46 | self.p = tf.placeholder(tf.float32, shape=(None, self.n_cluster)) # [, n_class]
47 | self.pred = tf.argmax(self.q, axis=1)
48 | self.pred_prob = tf.reduce_max(self.q, axis=1)
49 |
50 | with tf.name_scope("dec-train"):
51 | self.loss = self._kl_divergence(self.p, self.q)
52 | self.global_step_op = tf.train.get_or_create_global_step()
53 | self.lr = params.learning_rate
54 | warmup_steps = params.warmup_steps
55 | warmup_lr = (self.lr * tf.cast(self.global_step_op, tf.float32)
56 | / tf.cast(warmup_steps, tf.float32))
57 | self.warmup_learning_rate_op = \
58 | tf.cond(self.global_step_op < warmup_steps, lambda: warmup_lr, lambda: self.lr)
59 | self.optimizer = tf.train.AdamOptimizer(self.warmup_learning_rate_op)
60 | self.trainer = self.optimizer.minimize(self.loss, global_step=self.global_step_op)
61 |
62 | def get_assign_cluster_centers_op(self, features):
63 | # init mu
64 | tf.logging.info("Kmeans train start.")
65 | kmeans = self.kmeans.fit(features)
66 | tf.logging.info("Kmeans train end.")
67 | return tf.assign(self.mu, kmeans.cluster_centers_)
68 |
69 | # emb [batch, emb_dim] centroid [n_class, emb_dim]
70 | def _soft_assignment(self, embeddings, cluster_centers):
71 | """Implemented a soft assignment as the probability of assigning sample i to cluster j.
72 |
73 | Args:
74 | embeddings: (num_points, dim)
75 | cluster_centers: (num_cluster, dim)
76 |
77 | Return:
78 | q_i_j: (num_points, num_cluster)
79 | """
80 | def _pairwise_euclidean_distance(a, b):
81 | # p1 [batch, n_class]
82 | p1 = tf.matmul(
83 | tf.expand_dims(tf.reduce_sum(tf.square(a), 1), 1),
84 | tf.ones(shape=(1, self.n_cluster))
85 | )
86 | # p2 [batch, n_class]
87 | p2 = tf.transpose(tf.matmul(
88 | tf.reshape(tf.reduce_sum(tf.square(b), 1), shape=[-1, 1]),
89 | tf.ones(shape=(tf.shape(a)[0], 1)),
90 | transpose_b=True
91 | ))
92 | # [batch, n_class]
93 | res = tf.sqrt(
94 | tf.abs(tf.add(p1, p2) - 2 * tf.matmul(a, b, transpose_b=True)))
95 |
96 | return res
97 |
98 | dist = _pairwise_euclidean_distance(embeddings, cluster_centers)
99 | q = 1.0 / (1.0 + dist ** 2 / self.alpha) ** ((self.alpha + 1.0) / 2.0)
100 | q = (q / tf.reduce_sum(q, axis=1, keepdims=True))
101 | return q
102 |
103 | def target_distribution(self, q):
104 | p = q ** 2 / q.sum(axis=0)  # DEC target distribution: square q and normalize by soft cluster frequency
105 | p = p / p.sum(axis=1, keepdims=True)  # renormalize so each sample's distribution sums to 1
106 | return p
107 |
108 | def _kl_divergence(self, target, pred):
109 | return tf.reduce_mean(tf.reduce_sum(target * tf.log(target / (pred)), axis=1))
110 |
111 | def cluster_acc(self, y_true, y_pred):
112 | """
113 | Calculate clustering accuracy. Require scikit-learn installed
114 | # Arguments
115 | y: true labels, numpy.array with shape `(n_samples,)`
116 | y_pred: predicted labels, numpy.array with shape `(n_samples,)`
117 | # Return
118 | accuracy, in [0,1]
119 | """
120 | y_true = y_true.astype(np.int64)
121 | assert y_pred.size == y_true.size
122 | D = max(y_pred.max(), y_true.max()) + 1
123 | w = np.zeros((D, D), dtype=np.int64)
124 | for i in range(y_pred.size):
125 | w[y_pred[i], y_true[i]] += 1
126 | ind = linear_assignment(w.max() - w)
127 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size
128 |
--------------------------------------------------------------------------------
/dec_mining/images/modify_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/dec_mining/images/modify_1.gif
--------------------------------------------------------------------------------
/dec_mining/images/modify_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/dec_mining/images/modify_2.gif
--------------------------------------------------------------------------------
/dec_mining/images/冷启动流程图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/dec_mining/images/冷启动流程图.png
--------------------------------------------------------------------------------
/dec_mining/images/轮廓系数公式.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/dec_mining/images/轮廓系数公式.png
--------------------------------------------------------------------------------
/dec_mining/images/迭代挖掘流程图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/dec_mining/images/迭代挖掘流程图.png
--------------------------------------------------------------------------------
/dec_mining/inference.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | inference cluster labels
4 | """
5 | import argparse
6 | from datetime import datetime
7 | import numpy as np
8 | import tensorflow as tf
9 | from sklearn.metrics import silhouette_score, silhouette_samples
10 | import dataset
11 | import dec_model
12 | import data_utils
13 |
14 |
15 | def write_results(z, y_true, raw_q, out_name, prob):
16 | assert len(z) == len(raw_q)
17 | out_str = ""
18 | label_map = {} # sort samples order by y_pred
19 | for (y_pred, gt_label, q, pro) in zip(z, y_true, raw_q, prob):
20 | prob = -np.sort(-prob)
21 | if y_pred in label_map:
22 | label_map[y_pred].append("__label__" + str(y_pred) + "\t" + q +
23 | ": ground truth label is " + str(gt_label) + str(pro))
24 | else:
25 | label_map[y_pred] = []
26 | label_map[y_pred].append("__label__" + str(y_pred) + "\t" + q +
27 | ": ground truth label is" + str(gt_label) + str(pro))
28 |
29 | for _, lines in label_map.items():
30 | for line in lines:
31 | out_str += line + "\n"
32 | data_utils.write_file(out_name, out_str)
33 |
34 | def print_metrics(x, labels):
35 | sil_avg = silhouette_score(x, labels) # avg silhouette score
36 | sils = silhouette_samples(x, labels) # silhouette score of each sample
37 | tf.logging.info("avg silhouette:" + str(sil_avg))
38 |
39 | def inference(data, model, params):
40 | config = tf.ConfigProto()
41 | config.gpu_options.allow_growth = True
42 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=None)
43 | batch_size = params.batch_size
44 |
45 | with tf.Session(config=config) as sess:
46 | sess.run(tf.global_variables_initializer())
47 | saver.restore(sess, params.model_path)
48 |
49 | train_size = len(data.train_x)
50 | step_by_batch = train_size // batch_size + 1
51 | tf.logging.info("step by batch " + str(step_by_batch))
52 | z_total = [] # z: transformed representation
53 | prob_total = [] # predict cluster probability
54 | pred_total = [] # predict cluster label
55 |
56 | for idx in range(step_by_batch):
57 | if idx == step_by_batch - 1:
58 | tf.logging.info("start/ end idx " + str(idx * batch_size) + " " + str(idx * batch_size + batch_size))
59 | cur_pred, cur_prob, cur_z = sess.run(
60 | [model.pred, model.pred_prob, model.z], feed_dict={
61 | model.pretrained_model.ph_tokens: data.train_x[idx * batch_size:],
62 | model.pretrained_model.ph_length: data.train_length_x[idx * batch_size:],
63 | model.pretrained_model.ph_dropout_rate: 0
64 | })
65 | else:
66 | cur_pred, cur_prob, cur_z = sess.run(
67 | [model.pred, model.pred_prob, model.z], feed_dict={
68 | model.pretrained_model.ph_tokens:
69 | data.train_x[idx * batch_size: idx * batch_size + batch_size],
70 | model.pretrained_model.ph_length:
71 | data.train_length_x[idx * batch_size: idx * batch_size + batch_size],
72 | model.pretrained_model.ph_dropout_rate: 0
73 | })
74 |
75 | now = datetime.now()
76 | # tf.logging.info("sess run index " + str(idx) + " " + str(len(cur_pred)) + now.strftime("%H:%M:%S"))
77 | pred_total.extend(cur_pred)
78 | prob_total.extend(cur_prob)
79 | z_total.extend(cur_z)
80 | tf.logging.info("pred total " + str(len(pred_total)) + " , sample total " + str(len(data.train_x)))
81 | assert len(pred_total) == len(data.train_x)
82 | clust_label = np.array(pred_total)
83 | prob = np.array(prob_total)
84 | print_metrics(z_total, clust_label)
85 | # write inference result file
86 | write_results(clust_label, data.gt_label, data.raw_q, params.pred_score_path, prob)
87 | return clust_label
88 |
89 | if __name__=="__main__":
90 | tf.logging.set_verbosity(tf.logging.INFO)
91 | parser = argparse.ArgumentParser()
92 | parser.add_argument("--model_path", type=str, default="")
93 | parser.add_argument("--init_checkpoint", type=str, default="")
94 | parser.add_argument("--train_file", type=str, default="")
95 | parser.add_argument("--batch_size", type=int, default=32)
96 | parser.add_argument("--lstm_dim", type=int, default=500)
97 | parser.add_argument("--embedding_dim", type=int, default=1000)
98 | parser.add_argument("--vocab_file", type=str, default="./vocab")
99 | parser.add_argument("--external_cluster_center", type=str, default="")
100 | parser.add_argument("--n_clusters", type=int, default=20)
101 | parser.add_argument("--alpha", type=int, default=1)
102 | parser.add_argument("--layer_num", type=int, default=1)
103 | parser.add_argument("--token_num", type=int, default=7820)
104 | parser.add_argument("--learning_rate", type=float, default=0.01)
105 | parser.add_argument("--warmup_steps", type=int, default=1000)
106 | parser.add_argument("--epochs", type=int, default=5)
107 | parser.add_argument("--pred_score_path", type=str, default='')
108 | args = parser.parse_args()
109 |
110 | word2id, id2word = data_utils.load_vocab_file(args.vocab_file)
111 | TRAINSET_SIZE = len(data_utils.load_sentences(args.train_file, skip_invalid=True))
112 | other_arg_dict = {}
113 | other_arg_dict['token_num'] = len(word2id)
114 | other_arg_dict['trainset_size'] = TRAINSET_SIZE
115 |
116 | exp_data = dataset.ExpDataset(args)
117 | dec_model = dec_model.DEC(args, other_arg_dict)
118 | inference(exp_data, dec_model, args)
119 |
--------------------------------------------------------------------------------
/dec_mining/print_sen_embedding.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | print sentence embedding
4 | """
5 | import os
6 | import sys
7 | import argparse
8 | import codecs
9 | import tensorflow as tf
10 | import numpy as np
11 | import data_utils
12 |
13 | # get graph output in different way: max mean concat
14 | def get_output(g, embedding_way):
15 | if embedding_way == "concat": # here may have problem, this is just for 4 layers of biLM !
16 | t = g.get_tensor_by_name("concat_4:0")
17 | elif embedding_way == "max":
18 | t = g.get_tensor_by_name("Max:0")
19 | elif embedding_way == 'mean':
20 | t = g.get_tensor_by_name("Mean:0")
21 | else:
22 | assert False
23 | return {"sen_embedding": t}
24 |
25 | # get graph input
26 | def get_input(g):
27 | return {"tokens": g.get_tensor_by_name("ph_tokens:0"),
28 | "length": g.get_tensor_by_name("ph_length:0"),
29 | "dropout_rate": g.get_tensor_by_name("ph_dropout_rate:0")}
30 |
31 | def gen_test_data(input_file, word2id, max_seq_len):
32 | sens = []
33 | center_size = []
34 | for line in codecs.open(input_file, 'r', 'utf-8'):
35 | # separated by slash
36 | ls = line.strip().split("/")
37 | center_size.append(len(ls))
38 | for l in ls:
39 | l = l.replace(" ", "")
40 | l = l.replace("", " ")
41 | fs = l.rstrip().split()
42 | if len(fs) > max_seq_len:
43 | continue
44 | sen = []
45 | for f in fs:
46 | if f in word2id:
47 | sen.append(word2id[f])
48 | else:
49 | sen.append(word2id[''])
50 | sens.append(sen)
51 | return sens, center_size
52 |
53 | if __name__=="__main__":
54 |
55 | parser = argparse.ArgumentParser()
56 | parser.add_argument("--input_file", type=str, default="")
57 | parser.add_argument("--vocab_file", type=str, default="")
58 | parser.add_argument("--model_path", type=str, default="")
59 | parser.add_argument("--batch_size", type=int, default=256)
60 | parser.add_argument("--max_seq_len", type=int, default=100)
61 | parser.add_argument("--output_file", type=str, default="")
62 | # sentence representation output way : max mean concat
63 | parser.add_argument("--embedding_way", type=str, default="concat")
64 | args = parser.parse_args()
65 |
66 | word2id, id2word = data_utils.load_vocab_file(args.vocab_file)
67 | sys.stderr.write("vocab num : " + str(len(word2id)) + "\n")
68 | sens, center_size = gen_test_data(args.input_file, word2id, args.max_seq_len)
69 | sys.stderr.write("sens num : " + str(len(sens)) + "\n")
70 | tf.logging.info("embedding_way : ", args.embedding_way)
71 |
72 | # limit cpu resource
73 | cpu_num = int(os.environ.get('CPU_NUM', 15))
74 | config = tf.ConfigProto(device_count={"CPU": cpu_num},
75 | inter_op_parallelism_threads = cpu_num,
76 | intra_op_parallelism_threads = cpu_num,
77 | log_device_placement=True)
78 | config.gpu_options.allow_growth = True
79 | with tf.Session(config=config) as sess:
80 | saver = tf.train.import_meta_graph("{}.meta".format(args.model_path))
81 | saver.restore(sess, args.model_path)
82 |
83 | graph = tf.get_default_graph()
84 | input_dict = get_input(graph)
85 | output_dict = get_output(graph, args.embedding_way)
86 |
87 | caches = []
88 | idx = 0
89 | while idx < len(sens):
90 | batch_sens = sens[idx:idx + args.batch_size]
91 | batch_tokens = []
92 | batch_length = []
93 | for sen in batch_sens:
94 | batch_tokens.append(sen)
95 | batch_length.append(len(sen))
96 |
97 | real_max_len = max([len(b) for b in batch_tokens])
98 | for b in batch_tokens:
99 | b.extend([0] * (real_max_len - len(b)))
100 |
101 | re = sess.run(output_dict['sen_embedding'],
102 | feed_dict={input_dict['tokens']: batch_tokens,
103 | input_dict['length']: batch_length,
104 | input_dict["dropout_rate"]: 0.0})
105 | if len(caches) % 200 == 0:
106 | tf.logging.info(len(caches))
107 | caches.append(re)
108 | idx += len(batch_sens)
109 |
110 | sen_embeddings = np.concatenate(caches, 0)
111 | # calculate average embedding
112 | avg_centers = []
113 |
114 | idx = 0
115 | for size in center_size:
116 | avg_center_emb = np.average(sen_embeddings[idx: idx + size], axis=0)
117 | avg_centers.append(avg_center_emb)
118 | idx = idx + size
119 |
120 | np.savetxt(args.output_file, avg_centers, fmt='%.3e')
121 |
--------------------------------------------------------------------------------
/dec_mining/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | finetune on pretrained model with dataset to be clustered
4 | """
5 |
6 | import argparse
7 | import tensorflow as tf
8 | import numpy as np
9 | import dataset
10 | import dec_model
11 | import data_utils
12 |
13 |
14 | def train(data, model, args):
15 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=2)
16 | best_acc = 0
17 | config = tf.ConfigProto(allow_soft_placement=True)
18 | config.gpu_options.allow_growth = True
19 | with tf.Session(config=config) as sess:
20 | sess.run(tf.global_variables_initializer())
21 |
22 | train_size = len(data.train_x)
23 | batch_size = args.batch_size
24 |
25 | steps_in_epoch = train_size // args.batch_size + 1
26 | tf.logging.info("step by batch " + str(steps_in_epoch))
27 | z_total = []
28 |
29 | # z: transformed representation of x
30 | for idx in range(steps_in_epoch):
31 | if idx == steps_in_epoch - 1: # last batch
32 | tf.logging.info("start/ end idx " + str(idx * batch_size) + " " + str(idx * batch_size + batch_size))
33 | cur_z = sess.run(model.z, feed_dict={
34 | model.pretrained_model.ph_tokens: data.train_x[idx * batch_size:],
35 | model.pretrained_model.ph_length: data.train_length_x[idx * batch_size:],
36 | model.pretrained_model.ph_dropout_rate: 0
37 | })
38 | else:
39 | cur_z = sess.run(model.z, feed_dict={
40 | model.pretrained_model.ph_tokens: data.train_x[
41 | idx * batch_size: idx * batch_size + batch_size],
42 | model.pretrained_model.ph_length: data.train_length_x[
43 | idx * batch_size: idx * batch_size + batch_size],
44 | model.pretrained_model.ph_dropout_rate: 0
45 | })
46 | z_total.extend(cur_z)
47 |
48 | tf.logging.info("z total size : " + str(len(z_total))) # sample size
49 | assert len(z_total) == len(data.train_x)
50 | z = np.array(z_total)
51 |
52 | # Customize the cluster center
53 | if args.external_cluster_center != "":
54 | # load external centers file
55 | external_center = model.external_cluster_center_vec
56 | assign_mu_op = tf.assign(model.mu, external_center)
57 | else:
58 | # kmeans init centers
59 | assign_mu_op = model.get_assign_cluster_centers_op(z)
60 |
61 | amu = sess.run(assign_mu_op) # get cluster center
62 |
63 | for cur_epoch in range(args.epochs):
64 | q_list = []
65 |
66 | for idx in range(steps_in_epoch):
67 | start_idx = idx * batch_size
68 | end_idx = idx * batch_size + batch_size
69 | if idx == steps_in_epoch - 1:
70 | q_batch = sess.run(
71 | model.q, feed_dict={
72 | model.pretrained_model.ph_tokens: data.train_x[start_idx:],
73 | model.pretrained_model.ph_length: data.train_length_x[start_idx:],
74 | model.pretrained_model.ph_dropout_rate: 0
75 | })
76 | else:
77 | q_batch = sess.run(
78 | model.q, feed_dict={
79 | model.pretrained_model.ph_tokens: data.train_x[start_idx: end_idx],
80 | model.pretrained_model.ph_length: data.train_length_x[start_idx: end_idx],
81 | model.pretrained_model.ph_dropout_rate: 0
82 | })
83 |
84 | q_list.extend(q_batch)
85 |
86 | q = np.array(q_list)
87 | p = model.target_distribution(q)
88 |
89 | for iter_, (batch_x, batch_y, batch_idxs, batch_x_lengths) in enumerate(
90 | data.gen_next_batch(batch_size=batch_size, \
91 | is_train_set=True, epoch=1)):
92 | batch_p = p[batch_idxs]
93 | _, loss, pred, global_step, lr = sess.run([model.trainer, model.loss, model.pred, model.global_step_op, model.optimizer._lr], \
94 | feed_dict={model.pretrained_model.ph_tokens: batch_x, \
95 | model.pretrained_model.ph_length: batch_x_lengths, \
96 | model.p: batch_p, \
97 | model.pretrained_model.ph_dropout_rate: 0
98 | })
99 | # NOTE: acc is only for inspecting clustering quality when ground-truth labels exist; the labels never take part in training. For unsupervised data this acc is meaningless.
100 | acc = model.cluster_acc(batch_y, pred)
101 | tf.logging.info("[DEC] epoch: {}\tloss: {}\tacc: {}\t lr {} \t global_step {} ".format(cur_epoch, loss, acc, lr, global_step))
102 | if acc > best_acc:
103 | best_acc = acc
104 | tf.logging.info("!!!!!!!!!!!!! best acc got " + str(best_acc))
105 | # save model each epoch
106 | saver.save(sess, args.model_save_dir + '/finetune.ckpt', global_step=global_step)
107 |
108 | if __name__ == "__main__":
109 | tf.logging.set_verbosity(tf.logging.INFO)
110 | # pass params
111 | parser = argparse.ArgumentParser()
112 | # sptm pretrain model path
113 | parser.add_argument("--init_checkpoint", type=str, default='')
114 | parser.add_argument("--train_file", type=str, default="")
115 | parser.add_argument("--batch_size", type=int, default=32)
116 | # customized cluster centers file path, pass either of params 'external_cluster_center' or 'n_clusters'
117 | parser.add_argument("--external_cluster_center", type=str, default="")
118 | # number of clusters (init with kmeans)
119 | parser.add_argument("--n_clusters", type=int, default=20)
120 | parser.add_argument("--epochs", type=int, default=50)
121 | parser.add_argument("--warmup_steps", type=int, default=1000)
122 | parser.add_argument("--learning_rate", type=float, default=0.01)
123 | # DEC model q distribution param, alpha=1 in paper
124 | parser.add_argument("--alpha", type=int, default=1)
125 | parser.add_argument("--layer_num", type=int, default=1)
126 | parser.add_argument("--token_num", type=int, default=7820)
127 | parser.add_argument("--lstm_dim", type=int, default=500)
128 | parser.add_argument("--embedding_dim", type=int, default=1000)
129 | parser.add_argument("--vocab_file", type=str, default="./vocab")
130 | parser.add_argument("--model_save_dir", type=str, default="./saved_model")
131 | args = parser.parse_args()
132 |
133 | word2id, id2word = data_utils.load_vocab_file(args.vocab_file)
134 | trainset_size = len(data_utils.load_sentences(args.train_file, skip_invalid=True))
135 | other_arg_dict = {}
136 | other_arg_dict['token_num'] = len(word2id)
137 | other_arg_dict['trainset_size'] = trainset_size
138 |
139 | exp_data = dataset.ExpDataset(args)
140 | dec_model = dec_model.DEC(args, other_arg_dict)
141 | train(exp_data, dec_model, args)
142 |
--------------------------------------------------------------------------------
/docs/RUNDEMO.md:
--------------------------------------------------------------------------------
1 | # Run guide
2 | This demo shows how to run QA over one-level and two-level knowledge bases, the evaluation metrics, and the results on the test set.
3 |
4 | ## Data
5 | The pre-training set (pre_train_data), training set (train_data), validation set (valid_data), and test set (test_data) in [data_demo](../data_demo) come from real traffic of 58's production intelligent-QA system. Only a small portion is included, just enough to run the models: the pre-training set has 90K+ lines (it is much larger in real scenarios), the training set 90K+, and the validation and test sets 3000+ each. See the data description in the [README](../README.md) for the exact formats.
6 |
7 | ## Running QA over a one-level knowledge base
8 |
9 | ### QA with the DSSM model
10 |
11 | With the DSSM intent matching model, a one-level knowledge base only requires training the matching model, predicting on the test set with the trained model, and then assigning an answer type by score thresholds: a unique answer when the top matching score is above the upper threshold, a rejection when it is below the lower threshold, and a list answer when it falls in between.
12 |
13 | #### 1. Train the DSSM intent matching model
14 |
15 | ```bash
16 | mkdir model && python run_dssm.py --train_path=./data_demo/train_data --valid_path=./data_demo/valid_data --map_file_path=./data_demo/std_data --model_path=./model/model_min/ --softmax_r=45 --embedding_size=256 --learning_rate=0.001 --keep_prob=0.8 --batch_size=250 --num_epoches=30 --negative_size=200 --eval_every=10 --num_units=256 --use_same_cell=True --label2id_path=./model/model_min/min_label2id --vocab2id_path=./model/model_min/min_vocab2id
17 | ```
18 |
19 | Parameters:
20 |
21 | train_path: training set
22 |
23 | valid_path: validation set
24 |
25 | map_file_path: domain/intent mapping file
26 |
27 | model_path: where to store the model
28 |
29 | softmax_r: smoothing parameter for the cosine similarity
30 |
31 | embedding_size: size of the embedding layer
32 |
33 | learning_rate: learning rate
34 |
35 | keep_prob: probability of keeping a neuron during dropout
36 |
37 | batch_size: batch size
38 |
39 | num_epoches: number of epochs
40 |
41 | negative_size: number of negative samples
42 |
43 | eval_every: evaluate on the validation set every this many steps during training
44 |
45 | num_units: number of units in each LSTM cell
46 |
47 | use_same_cell: whether the forward and backward LSTMs use the same cell (share one set of parameters)
48 |
49 | label2id_path: <intent, id> mapping file
50 |
51 | vocab2id_path: vocabulary mapping file generated from the training data
52 |
53 | #### 2. Predict on the test set with the intent matching model
54 |
55 | ```bash
56 | python dssm_predict.py --map_file_path=./data_demo/std_data --model_path=./model/model_min/ --export_model_dir=./model/model_min/dssm_tf_serving/ --test_data_path=./data_demo/test_data --test_result_path=./model/model_min/result_min_test --softmax_r=45 --batch_size=250 --label2id_file=./model/model_min/min_label2id --vocab2id_file=./model/model_min/min_vocab2id
57 | ```
58 |
59 | #### 3. Assign answer types to the matching results by score thresholds
60 |
61 | ```bash
62 | python merge_classifier_match_label.py none ./model/model_min/result_min_test ./data_demo/merge_result_1_level none
63 | ```
64 |
65 | ### QA with the lightweight pre-trained language model (SPTM)
66 |
67 | When using SPTM for intent matching over a one-level knowledge base, first pre-train the language model, then fine-tune the intent matching model from the pre-trained model and the training set; after predicting on the test set with the trained model, assign answer types to the matching results by score thresholds, using the same threshold rules as the non-pre-trained setup.
68 |
69 | #### 1. Pre-train the language model
70 |
71 | ##### Pre-training with Bi-LSTM blocks
72 | ```bash
73 | cd sptm && mkdir -p model/pretrain && python run_pretraining.py --train_file="../data_demo/pre_train_data" --vocab_file="../data_demo/vocab" --model_save_dir="./model/pretrain" --batch_size=256 --print_step=100 --weight_decay=0 --embedding_dim=1000 --lstm_dim=500 --layer_num=1 --train_step=100000 --warmup_step=1000 --learning_rate=5e-5 --dropout_rate=0.1 --max_predictions_per_seq=10 --clip_norm=1.0 --max_seq_len=100 --use_queue=0 --representation_type=lstm
74 | ```
75 | Parameters:
76 |
77 | vocab: vocabulary file (must include the required special tokens)
78 |
79 | train_file/valid_data: training/validation set
80 |
81 | lstm_dim: number of LSTM units
82 |
83 | embedding_dim: word-embedding dimension
84 |
85 | dropout_rate: fraction of nodes dropped out
86 |
87 | layer_num: number of LSTM layers
88 |
89 | weight_decay: weight-decay coefficient for Adam
90 |
91 | max_predictions_per_seq: maximum number of tokens masked per sentence
92 |
93 | clip_norm: gradient-clipping threshold
94 |
95 | use_queue: whether to use a queue to generate pre-training data
96 |
97 | representation_type: which structure to train, lstm or transformer
98 |
99 | ##### Pre-training with parameter-shared Transformer blocks
100 | ```bash
101 | cd sptm && mkdir -p model/pretrain && python run_pretraining.py --train_file="../data_demo/pre_train_data" --vocab_file="../data_demo/vocab" --model_save_dir="./model/pretrain" --batch_size=64 --print_step=100 --embedding_dim=1000 --train_step=100000 --warmup_step=5000 --learning_rate=1e-5 --max_predictions_per_seq=10 --clip_norm=1.0 --max_len=100 --use_queue=0 --representation_type=transformer --initializer_range=0.02 --max_position_embeddings=140 --hidden_size=768 --num_hidden_layers=12 --num_attention_heads=12 --intermediate_size=1024
102 | ```
103 | Parameters:
104 |
105 | learning_rate: learning rate
106 |
107 | initializer_range: standard deviation of the normal distribution used to initialize model parameters
108 |
109 | max_position_embeddings: maximum position
110 |
111 | hidden_size: hidden-layer size
112 |
113 | num_hidden_layers: number of hidden layers
114 |
115 | num_attention_heads: number of attention heads
116 |
117 | intermediate_size: size of the feed-forward (FFN) layer
118 |
119 | representation_type: which structure to train, lstm or transformer
120 |
121 |
122 | #### 2. Fine-tune the intent matching model
123 |
124 | Note that `init_checkpoint` here should be chosen according to the pre-training results; if there is no pre-trained model it can be left out:
125 |
126 | ##### Fine-tuning with Bi-LSTM blocks
127 | ```bash
128 | cd sptm && python run_classifier.py --output_id2label_file="model/id2label.has_init" --vocab_file="../data_demo/vocab" --train_file="../data_demo/train_data" --dev_file="../data_demo/valid_data" --model_save_dir="model/finetune" --lstm_dim=500 --embedding_dim=1000 --opt_type=adam --batch_size=256 --epoch=20 --learning_rate=1e-4 --seed=1 --max_len=100 --print_step=10 --dropout_rate=0.1 --layer_num=1 --init_checkpoint="model/pretrain/lm_pretrain.ckpt-1400" --representation_type=lstm
129 | ```
130 | Parameters:
131 |
132 | output_id2label_file: (id, label) mapping file, used later for prediction
133 |
134 | opt_type: optimizer type, one of sgd/adagrad/adam
135 |
136 | seed: random seed; using the same seed keeps fine-tuning results reproducible
137 |
138 | init_checkpoint: checkpoint saved by pre-training
139 |
140 | ##### Fine-tuning with parameter-shared Transformer blocks
141 | ```bash
142 | cd sptm && python run_classifier.py --output_id2label_file="model/id2label.has_init" --vocab_file="../data_demo/vocab" --train_file="../data_demo/train_data" --dev_file="../data_demo/valid_data" --model_save_dir="model/finetune" --embedding_dim=1000 --opt_type=adam --batch_size=64 --epoch=10 --learning_rate=1e-4 --seed=1 --max_len=100 --print_step=100 --dropout_rate=0.1 --use_queue=0 --representation_type=transformer --initializer_range=0.02 --max_position_embeddings=140 --hidden_size=768 --num_hidden_layers=12 --num_attention_heads=12 --intermediate_size=1024 --init_checkpoint="model/pretrain/lm_pretrain.ckpt-100800" --representation_type=transformer
143 | ```
144 | Parameters:
145 |
146 | opt_type: optimizer type
147 |
148 | representation_type: which structure to train, lstm or transformer
149 |
150 |
151 | #### 3. Predict on the test set with the intent matching model
152 |
153 | ```bash
154 | cd sptm && python run_prediction.py --input_file="../data_demo/test_data" --vocab_file="../data_demo/vocab" --id2label_file="model/id2label.has_init" --model_dir="model/finetune" > "../data_demo/result_pretrain_raw"
155 | ```
156 |
157 | #### 4. Format the prediction results and assign answer types by score thresholds
158 |
159 | ```bash
160 | python sptm/format_result.py ./data_demo/test_data ./data_demo/result_pretrain_raw ./data_demo/result_pretrain_test
161 | ```
162 | Parameters:
163 |
164 | argv[1]: test set
165 |
166 | argv[2]: SPTM score file
167 |
168 | argv[3]: formatted SPTM score file
169 |
170 | ```bash
171 | python merge_classifier_match_label.py none ./data_demo/result_pretrain_test ./data_demo/merge_result_pretrain none
172 | ```
173 |
174 | ## Running QA over a two-level knowledge base
175 | For a two-level knowledge base, first train the domain classification model and the intent matching model, then predict on the test set with the trained models, and finally fuse the domain classification and intent matching results to decide the answer type. See the fusion diagram for two-level knowledge-base QA in the README for the exact strategy.
176 |
177 | ### 1. Train the LSTM domain classification model
178 |
179 | ```bash
180 | mkdir model && python run_bi_lstm.py --train_path=./data_demo/train_data --valid_path=./data_demo/valid_data --map_file_path=./data_demo/std_data --model_path=./model/model_max --vocab_file=./model/model_max/vocab_max --label_file=./model/model_max/label_max --embedding_size=256 --num_units=256 --batch_size=200 --seq_length=40 --num_epcho=30 --check_every=20 --lstm_layers=2 --lr=0.01 --dropout_keep_prob=0.8
181 | ```
182 | Parameters:
183 |
184 | train_path: training set
185 |
186 | valid_path: validation set
187 |
188 | map_file_path: domain/intent mapping file
189 |
190 | model_path: where to store the model
191 |
192 | vocab_file: vocabulary mapping file generated from the training data
193 |
194 | label_file: <domain, id> mapping file generated during training
195 |
196 | embedding_size: size of the embedding layer
197 |
198 | num_units: number of units in each LSTM cell
199 |
200 | batch_size: batch size
201 |
202 | seq_length: maximum sequence length used in training
203 |
204 | num_epcho: number of epochs
205 |
206 | check_every: evaluate on the validation set every this many steps during training
207 |
208 | lstm_layers: number of LSTM layers
209 |
210 | lr: learning rate
211 |
212 | dropout_keep_prob: probability of keeping a neuron during dropout
213 |
214 | ### 2. Predict on the test set with the domain classification model
215 |
216 | ```bash
217 | python lstm_predict.py --map_file_path=./data_demo/std_data --model_path=./model/model_max --test_data_path=./data_demo/test_data --test_result_path=./model/model_max/result_max_test --batch_size=250 --seq_length=40 --label2id_file=./model/model_max/label_max --vocab2id_file=./model/model_max/vocab_max
218 | ```
219 |
220 | Parameters:
221 |
222 | map_file_path: domain/intent mapping file
223 |
224 | model_path: model path
225 |
226 | test_data_path: test set
227 |
228 | test_result_path: test score file
229 |
230 | batch_size: batch size
231 |
232 | seq_length: maximum sequence length (must match the value used for training)
233 |
234 | label2id_file: <domain, id> mapping file
235 |
236 | vocab2id_file: vocabulary mapping file generated from the training data
237 |
238 | ### 3. Train the DSSM intent matching model
239 | ```bash
240 | python run_dssm.py --train_path=./data_demo/train_data --valid_path=./data_demo/valid_data --map_file_path=./data_demo/std_data --model_path=./model/model_min/ --result_file_path=./data/result_min --softmax_r=45 --embedding_size=256 --learning_rate=0.001 --keep_prob=0.8 --batch_size=250 --num_epoches=30 --negative_size=200 --eval_every=10 --num_units=256 --use_same_cell=False --label2id_path=./model/model_min/min_label2id --vocab2id_path=./model/model_min/min_vocab2id
241 | ```
242 |
243 | ### 4. Predict on the test set with the intent matching model
244 |
245 | ```bash
246 | python dssm_predict.py --map_file_path=./data_demo/std_data --model_path=./model/model_min/ --export_model_dir=./model/model_min/dssm_tf_serving/ --test_data_path=./data_demo/test_data --test_result_path=./model/model_min/result_min_test --softmax_r=45 --batch_size=250 --label2id_file=./model/model_min/min_label2id --vocab2id_file=./model/model_min/min_vocab2id
247 | ```
248 |
249 | ### 5. Fuse the domain classification and intent matching results
250 |
251 | ```bash
252 | python merge_classifier_match_label.py ./model/model_max/result_max_test ./model/model_min/result_min_test ./data_demo/merge_result_2_level ./data_demo/std_data
253 | ```
254 |
255 | Parameters:
256 |
257 | argv[1]: domain classification score file
258 |
259 | argv[2]: intent recognition score file
260 |
261 | argv[3]: fused output file
262 |
263 | argv[4]: domain/intent mapping file
264 |
265 | ## Evaluation metrics and test-set results
266 |
267 | qa_match's QA quality is currently evaluated like a classification model: with the proportions of the model's answer types (unique answer, list answer, rejection) kept close to the true proportions, we look at the precision, recall, and F1 of each answer type, defined as follows:
268 |
269 | ![](measurement.png)
270 |
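A small sketch (not part of the repo) of computing these per-answer-type metrics with scikit-learn; the English type names are placeholders.

```python
from sklearn.metrics import precision_recall_fscore_support

# Per answer type (unique / list / reject) precision, recall and F1,
# computed from gold vs. predicted answer types.
def answer_type_metrics(gold_types, pred_types):
    labels = ["unique", "list", "reject"]
    p, r, f1, _ = precision_recall_fscore_support(gold_types, pred_types, labels=labels)
    return dict(zip(labels, zip(p, r, f1)))

print(answer_type_metrics(["unique", "reject", "unique"], ["unique", "unique", "unique"]))
```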
271 | Running the one-level and two-level knowledge-base examples above (see [data_demo](../data_demo) for the datasets) gives the following results (models were deployed with the open-source general deep-learning inference service [dl_inference](https://github.com/wuba/dl_inference) to measure inference time):
272 |
273 | | Dataset | Model | **Unique-answer precision** | **Unique-answer recall** | **Unique-answer F1** | **Inference time on CPU** |
274 | | ---------------- | ------------------------------------------------------------ | ------------------ | ------------------ | -------------- | --------------------- |
275 | | One-level KB dataset | DSSM [[download](http://wos.58cdn.com.cn/nOlKjIhGntU/qamatch/model_min.zip)] | 0.8398 | 0.8326 | 0.8362 | **3ms** |
276 | | One-level KB dataset | SPTM (LSTM) [[download](http://wos.58cdn.com.cn/nOlKjIhGntU/qamatch/model_pretrain.zip)] | 0.8841 | 0.9002 | 0.8921 | 16ms |
277 | | One-level KB dataset | SPTM (Transformer, 12 layers, 12 heads) [[download](http://wos.58cdn.com.cn/nOlKjIhGntU/qamatch/model_pretrain_transformer.zip)] | 0.9275 | 0.9298 | **0.9287** | 17ms |
278 | | One-level KB dataset | SPTM (Transformer, 1 layer, 12 heads) [[download](http://wos.58cdn.com.cn/nOlKjIhGntU/qamatch/transorformer_1_layer_12_heads.zip)] | 0.9122 | 0.9105 | 0.9122 | 13ms |
279 | | Two-level KB dataset | LSTM+DSSM fusion model [[download](http://wos.58cdn.com.cn/nOlKjIhGntU/qamatch/model_merge.zip)] | 0.8957 | 0.9027 | 0.8992 | 18ms |
280 |
281 | Note: because list answers make up only a small share of the demo data, we mainly look at the precision, recall, and F1 of unique answers. Pre-trained models can also be used for the two-level knowledge-base dataset; we do not elaborate on that here.
282 |
--------------------------------------------------------------------------------
/docs/dssm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/dssm.png
--------------------------------------------------------------------------------
/docs/kg_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/kg_demo.png
--------------------------------------------------------------------------------
/docs/lstm_dssm_bagging.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/lstm_dssm_bagging.png
--------------------------------------------------------------------------------
/docs/measurement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/measurement.png
--------------------------------------------------------------------------------
/docs/pretrain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/pretrain.png
--------------------------------------------------------------------------------
/docs/sptm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/sptm.png
--------------------------------------------------------------------------------
/dssm_predict.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import shutil
4 | import os
5 | from utils.match_utils import DataHelper
6 |
7 | flags = tf.app.flags
8 | FLAGS = flags.FLAGS
9 |
10 | flags.DEFINE_string('map_file_path', None, 'standard data path')
11 | flags.DEFINE_string("model_path", None, "checkpoint dir from predicting")
12 | flags.DEFINE_string("export_model_dir", None, "export model dir")
13 | flags.DEFINE_string('test_data_path', None, 'test data path')
14 | flags.DEFINE_string('test_result_path', None, 'test data result path')
15 | flags.DEFINE_integer('softmax_r', 45, 'Smooth parameter for cosine similarity')  # must be the same value as used in training
16 | flags.DEFINE_integer('batch_size', 100, 'batch_size for train')
17 | flags.DEFINE_string('label2id_file', None, 'label2id file path')
18 | flags.DEFINE_string('vocab2id_file', None, 'vocab2id_file path')
19 |
20 | dh = DataHelper(None, None, FLAGS.test_data_path, FLAGS.map_file_path, FLAGS.batch_size, None, FLAGS.label2id_file,
21 | FLAGS.vocab2id_file, False)
22 |
23 | config = tf.ConfigProto()
24 | config.gpu_options.allow_growth = True
25 |
26 | with tf.Session(config=config) as sess:
27 | model_file = tf.train.latest_checkpoint(FLAGS.model_path)
28 | saver = tf.train.import_meta_graph("{}.meta".format(model_file))
29 | saver.restore(sess, model_file)
30 | graph = tf.get_default_graph()
31 | input_x = graph.get_tensor_by_name("input_x:0")
32 | length_x = graph.get_tensor_by_name("length_x:0")
33 | input_y = graph.get_tensor_by_name("input_y:0")
34 | length_y = graph.get_tensor_by_name("length_y:0")
35 | keep_prob = graph.get_tensor_by_name("keep_prob:0")
36 | q_y_raw = graph.get_tensor_by_name("representation/q_y_raw:0")
37 | qs_y_raw = graph.get_tensor_by_name("representation/qs_y_raw:0")
38 | # first get std tensor value
39 | length_y_value = [y[0] for y in dh.std_batch]
40 | input_y_value = [y[1] for y in dh.std_batch]
41 | # print("input_y_value: " + str(input_y_value))
42 | # print("input_y_value.shape: " + str(np.array(input_y_value, dtype=np.int32).shape))
43 | # print("length_y_value.shape: " + str(np.array(length_y_value, dtype=np.int32).shape))
44 | qs_y_raw_out = sess.run(qs_y_raw, feed_dict={input_y: np.array(input_y_value, dtype=np.int32),
45 | length_y: np.array(length_y_value, dtype=np.int32), keep_prob: 1.0})
46 |
47 | with tf.name_scope('cosine_similarity_pre'):
48 | # Cosine similarity
49 | q_norm_pre = tf.sqrt(tf.reduce_sum(tf.square(q_y_raw), 1, True)) # b*1
50 | qs_norm_pre = tf.transpose(tf.sqrt(tf.reduce_sum(tf.square(qs_y_raw_out), 1, True))) # 1*sb
51 | prod_nu_pre = tf.matmul(q_y_raw, tf.transpose(qs_y_raw_out)) # b*sb
52 | norm_prod_de = tf.matmul(q_norm_pre, qs_norm_pre) # b*sb
53 | cos_sim_pre = tf.truediv(prod_nu_pre, norm_prod_de) * FLAGS.softmax_r # b*sb
54 |
55 | with tf.name_scope('prob_pre'):
56 | prob_pre = tf.nn.softmax(cos_sim_pre) # b*sb
57 |
58 | test_batches = dh.test_batch_iterator()
59 | test_result_file = open(FLAGS.test_result_path, 'w', encoding='utf-8')
60 | # print(dh.predict_label_seq)
61 | for _, test_batch_q in enumerate(test_batches):
62 | # test_batch_q:[(l1, ws1, label1, line1), (l2, ws2, label2, line2), ...]
63 | length_x_value = [x[0] for x in test_batch_q]
64 | input_x_value = [x[1] for x in test_batch_q]
65 | test_prob = sess.run(prob_pre, feed_dict={input_x: np.array(input_x_value, dtype=np.int32),
66 | length_x: np.array(length_x_value, dtype=np.int32),
67 | keep_prob: 1.0}) # b*sb
68 | # print("test_prob: " + str(test_prob))
69 | for index, example in enumerate(test_batch_q):
70 | (_, _, real_label, words) = example
71 | result_str = str(real_label) + '\t' + str(words) + '\t'
72 | label_scores = {}
73 | # print(test_prob[index])
74 | sample_scores = test_prob[index]
75 | for score_index, real_label_score in enumerate(sample_scores):
76 | label_scores[dh.predict_label_seq[score_index]] = real_label_score
77 | sorted_list = sorted(label_scores.items(), key=lambda x: x[1], reverse=True)
78 | # print(str(sorted_list))
79 | for label, score in sorted_list:
80 | result_str = result_str + str(label) + ":" + str(score) + " "
81 | # write result
82 | test_result_file.write(result_str + '\n')
83 | test_result_file.close()
84 | # export model
85 | if os.path.isdir(FLAGS.export_model_dir):
86 | shutil.rmtree(FLAGS.export_model_dir)
87 | builder = tf.saved_model.builder.SavedModelBuilder(FLAGS.export_model_dir)
88 | pred_x = tf.saved_model.utils.build_tensor_info(input_x)
89 | pred_len_x = tf.saved_model.utils.build_tensor_info(length_x)
90 | drop_keep_prob = tf.saved_model.utils.build_tensor_info(keep_prob)
91 | probs = tf.saved_model.utils.build_tensor_info(prob_pre)
92 |     # define the signature name and its inputs/outputs
93 | signature_def_map = {
94 | "predict": tf.saved_model.signature_def_utils.build_signature_def(
95 | inputs={"input": pred_x, "length": pred_len_x, "keep_prob": drop_keep_prob},
96 | outputs={
97 | "probs": probs
98 | },
99 | method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
100 | )
101 | }
102 | builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING],
103 | signature_def_map=signature_def_map)
104 | builder.save()
105 |
--------------------------------------------------------------------------------
/lstm_predict.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils.classifier_utils import TextLoader
3 |
4 | flags = tf.app.flags
5 | FLAGS = flags.FLAGS
6 |
7 | flags.DEFINE_string('map_file_path', None, 'standard data path')
8 | flags.DEFINE_string("model_path", None, "checkpoint dir from predicting")
9 | flags.DEFINE_string('test_data_path', None, 'test data path')
10 | flags.DEFINE_string('test_result_path', None, 'test data result path')
11 | flags.DEFINE_integer('batch_size', 100, 'batch_size for train')
12 | flags.DEFINE_integer('seq_length', 40, 'seq_length')
13 | flags.DEFINE_string('label2id_file', None, 'label2id file path')
14 | flags.DEFINE_string('vocab2id_file', None, 'vocab2id_file path')
15 |
16 | # load vocab and label mapping
17 | vocab_id = {}
18 | vocab_file = open(FLAGS.vocab2id_file, 'r', encoding='utf-8')
19 | for line in vocab_file:
20 | word_ids = line.strip().split('\t')
21 | vocab_id[word_ids[0]] = word_ids[1]
22 | vocab_file.close()
23 | label_id = {}
24 | id_label = {}
25 | label_file = open(FLAGS.label2id_file, 'r', encoding='utf-8')
26 | for line in label_file:
27 | std_label_ids = line.strip().split('\t')
28 | label_id[std_label_ids[0]] = std_label_ids[1]
29 | id_label[std_label_ids[1]] = std_label_ids[0]
30 | # print("id_label: " + str(id_label))
31 |
32 | label_file.close()
33 | std_label_map = {}
34 | std_label_map_file = open(FLAGS.map_file_path, 'r', encoding='utf-8')
35 | for line in std_label_map_file:
36 | tokens = line.strip().split('\t')
37 | label = tokens[0]
38 | std_id = tokens[1]
39 | std_label_map[std_id] = label
40 |
41 | std_label_map_file.close()
42 |
43 | test_data_loader = TextLoader(False, FLAGS.test_data_path, FLAGS.map_file_path, FLAGS.batch_size, FLAGS.seq_length,
44 | vocab_id, label_id, std_label_map, 'utf8', False)
45 |
46 | config = tf.ConfigProto()
47 | config.gpu_options.allow_growth = True
48 |
49 | with tf.Session(config=config) as sess:
50 | model_file = tf.train.latest_checkpoint(FLAGS.model_path)
51 | saver = tf.train.import_meta_graph("{}.meta".format(model_file))
52 | saver.restore(sess, model_file)
53 | graph = tf.get_default_graph()
54 | input_x = graph.get_tensor_by_name("input_x:0")
55 | length_x = graph.get_tensor_by_name("x_len:0")
56 | keep_prob = graph.get_tensor_by_name("dropout_keep_prob:0")
57 | test_data_loader.reset_batch_pointer()
58 | prediction = graph.get_tensor_by_name("acc/prediction_softmax:0") # [batchsize, label_size]
59 | test_result_file = open(FLAGS.test_result_path, 'w', encoding='utf-8')
60 | for n in range(test_data_loader.num_batches):
61 | input_x_test, input_y_test, x_len_test, raw_lines = test_data_loader.next_batch()
62 | prediction_result = sess.run(prediction,
63 | feed_dict={input_x: input_x_test, length_x: x_len_test, keep_prob: 1.0})
64 | # print("n: " + str(n))
65 | # print("len(input_x_test): " + str(len(input_x_test)))
66 | # print("len(input_y_test): " + str(len(input_y_test)))
67 | # print("len(raw_lines): " + str(len(raw_lines)))
68 | assert len(input_y_test) == len(raw_lines)
69 | for i in range(len(raw_lines)):
70 | raw_line = raw_lines[i]
71 | # print("input_y_test[i]: " + str(input_y_test[i]))
72 | real_label = id_label.get(str(input_y_test[i]))
73 | label_scores = {}
74 | for j in range(len(prediction_result[i])):
75 | label = id_label.get(str(j))
76 | score = prediction_result[i][j]
77 | label_scores[label] = score
78 | sorted_list = sorted(label_scores.items(), key=lambda x: x[1], reverse=True)
79 | # print("real_label: " + str(type(real_label)))
80 | # print("raw_lines: " + str(raw_lines))
81 | result_str = str(real_label) + "\t" + str(raw_line) + "\t";
82 | for label, score in sorted_list:
83 | result_str = result_str + str(label) + ":" + str(score) + " "
84 | test_result_file.write(result_str + '\n')
85 | test_result_file.close()
86 |
--------------------------------------------------------------------------------
/merge_classifier_match_label.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # encoding=utf-8
3 | # author wangyong
4 |
5 | """
6 | merge result for domain identification and intent recognition
7 | """
8 |
9 | import sys
10 | import logging
11 | logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
12 | level=logging.DEBUG)
13 |
14 |
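# The thresholds below gate the merge logic in get_merge_result():
#   MAXCLASS_* thresholds are compared against the domain classifier's top score,
#   MINCLASS_* thresholds are compared against the intent matcher's top score,
# and together they decide between a unique answer, a list answer, and a refusal (__label__none).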
15 | # none
16 | MAXCLASS_NONE_HIGHSCORE = 0.93 # 0.6
17 | MAXCLASS_NONE_MIDSCORE = 0.9 # 0.4
18 |
19 | # attention: MAXCLASS_NONE_HIGHSCORE must be >= MAXCLASS_NONE_MIDSCORE
20 | assert MAXCLASS_NONE_HIGHSCORE >= MAXCLASS_NONE_MIDSCORE
21 |
22 | MINCLASS_NONE_HIGHSCORE1 = 0.8
23 | MINCLASS_NONE_HIGHSCORE2 = 0.75
24 | MINCLASS_NONE_MIDSCORE1 = 0.65
25 | MINCLASS_NONE_MIDSCORE2 = 0.65
26 |
27 | assert MINCLASS_NONE_HIGHSCORE1 >= MINCLASS_NONE_MIDSCORE1
28 | assert MINCLASS_NONE_HIGHSCORE2 >= MINCLASS_NONE_MIDSCORE2
29 | # list
30 | MAXCLASS_LIST_HIGHSCORE = 0.9
31 | MAXCLASS_LIST_MIDSCORE = 0.65
32 | assert MAXCLASS_LIST_HIGHSCORE >= MAXCLASS_LIST_MIDSCORE
33 |
34 | MINCLASS_LIST_HIGHSCORE1 = 0.9
35 | MINCLASS_LIST_HIGHSCORE2 = 0.9
36 | MINCLASS_LIST_MIDSCORE1 = 0.4
37 | MINCLASS_LIST_MIDSCORE2 = 0.5
38 | assert MINCLASS_LIST_HIGHSCORE1 >= MINCLASS_LIST_MIDSCORE1
39 | assert MINCLASS_LIST_HIGHSCORE2 >= MINCLASS_LIST_MIDSCORE2
40 |
41 | # only
42 | MAXCLASS_ONLY_HIGHSCORE = 0.9
43 | MINCLASS_ONLY_HIGHSCORE1 = 0.01
44 | MINCLASS_ONLY_HIGHSCORE2 = 0.01
45 | MINCLASS_ONLY_MIDSCORE1 = 0.01
46 | MINCLASS_ONLY_MIDSCORE2 = 0.01
47 | assert MINCLASS_ONLY_HIGHSCORE1 >= MINCLASS_ONLY_MIDSCORE1
48 | assert MINCLASS_ONLY_HIGHSCORE2 >= MINCLASS_ONLY_MIDSCORE2
49 |
50 | MINCLASS_ZERO_SCORE = 0.00001
51 |
52 | MINCLASS_HIGH_SCORE_ONLY = 0.65
53 | MINCLASS_MID_SCORE_ONLY = 0.55
54 | assert MINCLASS_HIGH_SCORE_ONLY >= MINCLASS_MID_SCORE_ONLY
55 |
56 |
57 | MINCLASS_NUM = 3
58 |
59 |
60 | class LabelScore:
61 | def __init__(self):
62 | self.label = ''
63 | self.score = 0
64 |
65 |
66 | class MergeObj(object):
67 | def __init__(self):
68 | self.real_max_label = ''
69 | self.real_min_label = ''
70 | self.pre_max_top_label = ''
71 | self.pre_max_top_score = 0
72 | self.pre_min_top_label = ''
73 | self.pre_min_top_score = 0
74 | self.pre_min_label_scores = [0]
75 | self.merge_result = ''
76 |
77 |
78 | def get_max2min_label(max_min_class_file_dir):
79 | min_max_m = {}
80 | max_min_class_file = open(max_min_class_file_dir, 'r', encoding='utf-8')
81 | for line in max_min_class_file.readlines():
82 | mems = line.split("\t")
83 | max_label = mems[0]
84 | min_label = mems[1]
85 | min_max_m[min_label] = max_label
86 | max_min_class_file.close()
87 | return min_max_m
88 |
89 |
90 | def get_pre_label_scores(max_pre_file_d, min_pre_file_d):
91 | merge_items = []
92 | max_pre_file = open(max_pre_file_d, 'r', encoding='utf-8')
93 | for line in max_pre_file.readlines():
94 | line_items = line.split("\t")
95 | real_top_max_label = line_items[0]
96 | pre_top_max_label = line_items[2].split(' ')[0].split(":")[0]
97 | pre_top_max_score = float(line_items[2].split(' ')[0].split(":")[1])
98 | mer_obj = MergeObj()
99 | mer_obj.pre_max_top_label = pre_top_max_label
100 | mer_obj.pre_max_top_score = float(pre_top_max_score)
101 | mer_obj.real_max_label = real_top_max_label
102 | merge_items.append(mer_obj)
103 | max_pre_file.close()
104 | min_pre_file = open(min_pre_file_d, 'r', encoding='utf-8')
105 | index = 0
106 | for line in min_pre_file.readlines():
107 | mer_obj = merge_items[index]
108 | index = index + 1
109 | line_items = line.split("\t")
110 | real_min_label = line_items[0]
111 | label_scores_list = line_items[2].split(" ")
112 | pre_top_min_label = label_scores_list[0].split(":")[0]
113 | pre_top_min_score = float(label_scores_list[0].split(":")[1])
114 | mer_obj.real_min_label = real_min_label
115 | mer_obj.pre_min_top_label = pre_top_min_label
116 | mer_obj.pre_min_top_score = pre_top_min_score
117 | mer_obj.pre_min_label_scores = []
118 | scores_list = mer_obj.pre_min_label_scores
119 |
120 | for i in range(len(label_scores_list)):
121 | label_score = LabelScore()
122 | temp_labels = label_scores_list[i].split(":")
123 | if len(temp_labels) < 2:
124 | continue
125 | label_score.label = temp_labels[0]
126 | label_score.score = float(temp_labels[1])
127 | scores_list.append(label_score)
128 | min_pre_file.close()
129 | return merge_items
130 |
131 |
132 | def get_merge_result_each(str_type, merge_item):
133 | assert str_type in ('__label__none', '__label__only', '__label__list')
134 | if str_type == "__label__none":
135 | merge_item.merge_result = "__label__none"
136 | elif str_type == "__label__only":
137 | merge_item.merge_result = merge_item.pre_min_top_label + ":" + str(merge_item.pre_min_top_score)
138 | elif str_type == "__label__list":
139 | merge_item.merge_result = ""
140 | for i in range(len(merge_item.pre_min_label_scores)):
141 | if i == MINCLASS_NUM:
142 | break
143 | label = merge_item.pre_min_label_scores[i].label
144 | score = merge_item.pre_min_label_scores[i].score
145 | if score < MINCLASS_ZERO_SCORE:
146 | break
147 | merge_item.merge_result = merge_item.merge_result + label + ":" + str(score) + ","
148 |
149 |
150 | def get_only_list_none_result(high_score, low_score, merge_item):
151 | if merge_item.pre_min_top_score >= high_score: # one answer
152 | get_merge_result_each("__label__only", merge_item)
153 | elif merge_item.pre_min_top_score < high_score and merge_item.pre_min_top_score >= low_score: # list answer
154 | get_merge_result_each("__label__list", merge_item)
155 | else: # refuse to answer
156 | get_merge_result_each("__label__none", merge_item)
157 |
158 |
159 | def get_merge_result(merge_items, min_max_m):
160 | for merge_item in merge_items:
161 | if merge_item.pre_max_top_label == "__label__none": # none
162 | if merge_item.pre_max_top_score >= MAXCLASS_NONE_HIGHSCORE: # direct rejection
163 | get_merge_result_each("__label__none", merge_item)
164 | elif merge_item.pre_max_top_score >= MAXCLASS_NONE_MIDSCORE and merge_item.pre_max_top_score < MAXCLASS_NONE_HIGHSCORE: # tendency to reject
165 | get_only_list_none_result(MINCLASS_NONE_HIGHSCORE1, MINCLASS_NONE_MIDSCORE1, merge_item)
166 | else: # not tendency to reject
167 | get_only_list_none_result(MINCLASS_NONE_HIGHSCORE2, MINCLASS_NONE_MIDSCORE2, merge_item)
168 | elif merge_item.pre_max_top_label == "__label__list": # list
169 | if merge_item.pre_max_top_score >= MAXCLASS_LIST_HIGHSCORE: # direct answer a list
170 | get_merge_result_each("__label__list", merge_item)
171 | elif merge_item.pre_max_top_score >= MAXCLASS_LIST_MIDSCORE and merge_item.pre_max_top_score < MAXCLASS_LIST_HIGHSCORE: # tendency to answer list
172 | get_only_list_none_result(MINCLASS_LIST_HIGHSCORE1, MINCLASS_LIST_MIDSCORE1, merge_item)
173 | else: # not tendency to answer list
174 | get_only_list_none_result(MINCLASS_LIST_HIGHSCORE2, MINCLASS_LIST_MIDSCORE2, merge_item)
175 | else: # only
176 | filter_pre_min_label_scores = []
177 | for label_score in merge_item.pre_min_label_scores:
178 | max_label = min_max_m[label_score.label]
179 | if max_label != merge_item.pre_max_top_label:
180 | continue
181 | filter_pre_min_label_scores.append(label_score)
182 | merge_item.pre_min_label_scores = filter_pre_min_label_scores
183 | if len(filter_pre_min_label_scores) == 0: # direct rejection
184 | get_merge_result_each("__label__none", merge_item)
185 | else:
186 | merge_item.pre_min_top_label = filter_pre_min_label_scores[0].label
187 | merge_item.pre_min_top_score = filter_pre_min_label_scores[0].score
188 | if merge_item.pre_max_top_score >= MAXCLASS_ONLY_HIGHSCORE: # not tendency to reject
189 | get_only_list_none_result(MINCLASS_ONLY_HIGHSCORE1, MINCLASS_ONLY_MIDSCORE1, merge_item)
190 | else: # not tendency to one answer
191 | get_only_list_none_result(MINCLASS_ONLY_HIGHSCORE2, MINCLASS_ONLY_MIDSCORE2, merge_item)
192 |
193 |
194 | def write_result(merge_items, result_file_d):
195 | min_pre_file = open(result_file_d, 'w', encoding='utf-8')
196 | for merge_item in merge_items:
197 | min_pre_file.write(merge_item.real_max_label + "\t" + merge_item.real_min_label
198 | + "\t" + merge_item.merge_result + "\n")
199 | min_pre_file.close()
200 |
201 |
202 | def get_result_by_min(min_pre_file_dir, result_file_dir):
203 | with open(min_pre_file_dir, 'r', encoding='utf-8') as f_pre_min:
204 | with open(result_file_dir, 'w', encoding='utf-8') as f_res:
205 | for line in f_pre_min:
206 | lines = line.strip().split('\t')
207 | real_label = lines[0]
208 | model_label_scores = lines[2].split(' ')
209 | temp_label_score_list = []
210 | write_str = '__label__0\t' + str(real_label) + '\t'
211 | for label_score in model_label_scores:
212 | label_scores = label_score.split(':')
213 | temp_label_score = LabelScore()
214 | temp_label_score.label = label_scores[0]
215 | temp_label_score.score = (float)(label_scores[1])
216 | temp_label_score_list.append(temp_label_score)
217 | if temp_label_score_list[0].score < MINCLASS_MID_SCORE_ONLY: # refuse answer
218 | write_str += '__label__none'
219 | elif temp_label_score_list[0].score >= MINCLASS_MID_SCORE_ONLY and temp_label_score_list[
220 | 0].score < MINCLASS_HIGH_SCORE_ONLY: # list answer
221 | for i in range(len(temp_label_score_list)):
222 | if i == MINCLASS_NUM:
223 | break
224 | write_str += str(temp_label_score_list[i].label) + ':' + str(
225 | temp_label_score_list[i].score) + ','
226 | else: # only answer
227 | write_str += str(temp_label_score_list[0].label) + ':' + str(temp_label_score_list[0].score)
228 | f_res.write(write_str + "\n")
229 |
230 |
231 | def get_acc_recall_f1(result_file_dir):
232 | only_real_num = 0
233 | only_model_num = 0
234 | only_right_num = 0
235 | list_real_num = 0
236 | list_model_num = 0
237 | list_right_num = 0
238 | none_real_num = 0
239 | none_model_num = 0
240 | none_right_num = 0
241 | num = 0
242 | with open(result_file_dir, 'r', encoding='utf-8') as f_pre:
243 | for line in f_pre:
244 | num = num + 1
245 | lines = line.strip().split('\t')
246 | if lines[1] == '0':
247 | none_real_num = none_real_num + 1
248 | elif ',' in lines[1]:
249 | list_real_num = list_real_num + 1
250 | else:
251 | only_real_num = only_real_num + 1
252 | model_label_scores = lines[2].split(',')
253 | if lines[2] == '__label__none':
254 | none_model_num = none_model_num + 1
255 | elif len(model_label_scores) == 1:
256 | only_model_num = only_model_num + 1
257 | else:
258 | list_model_num = list_model_num + 1
259 | real_labels_set = set(lines[1].split(','))
260 | if lines[1] == '0' and lines[2] == '__label__none':
261 | none_right_num = none_right_num + 1
262 | if len(real_labels_set) == 1 and len(model_label_scores) == 1 and lines[1] == lines[2].split(':')[0]:
263 | only_right_num = only_right_num + 1
264 | if len(real_labels_set) > 1 and len(model_label_scores) > 1:
265 | for i in range(len(model_label_scores)):
266 | label_scores = model_label_scores[i].split(":")
267 | if label_scores[0] in real_labels_set:
268 | list_right_num = list_right_num + 1
269 | break
270 | logging.info('none_right_num: ' + str(none_right_num) + ', list_right_num: ' + str(list_right_num)
271 | + ', only_right_num: ' + str(only_right_num))
272 | logging.info('none_real_num: ' + str(none_real_num) + ', list_real_num: ' + str(list_real_num)
273 | + ', only_real_num: ' + str(only_real_num))
274 | logging.info('none_model_num: ' + str(none_model_num) + ', list_model_num: ' + str(list_model_num)
275 | + ', only_model_num: ' + str(only_model_num))
276 | all_right_num = list_right_num + only_right_num
277 | all_real_num = list_real_num + only_real_num
278 | all_model_num = list_model_num + only_model_num
279 | logging.info('all_right_num: ' + str(all_right_num) + ', all_real_num: ' + str(all_real_num)
280 | + ', all_model_num: ' + str(all_model_num))
281 | all_acc = all_right_num / all_model_num
282 | all_recall = all_right_num / all_real_num
283 | all_f1 = 2 * all_acc * all_recall / (all_acc + all_recall)
284 | logging.info("all_acc: " + str(all_acc) + ", all_recall: " + str(all_recall) + ", all_f1: " + str(all_f1))
285 | only_acc = only_right_num / only_model_num
286 | only_recall = only_right_num / only_real_num
287 | only_f1 = 2 * only_acc * only_recall / (only_acc + only_recall)
288 | logging.info("only_acc: " + str(only_acc) + ", only_recall: " + str(only_recall) + ", only_f1: " + str(
289 | only_f1) + ", only_real_prop: " + str(only_real_num / num) + ", only_model_prop: " + str(only_model_num / num))
290 | list_acc = list_right_num / list_model_num
291 | list_recall = list_right_num / list_real_num
292 | list_f1 = 2 * list_acc * list_recall / (list_acc + list_recall)
293 | logging.info("list_acc: " + str(list_acc) + ", list_recall: " + str(list_recall) + ", list_f1: " + str(
294 | list_f1) + ", list_real_prop: " + str(list_real_num / num) + ", list_model_prop: " + str(list_model_num / num))
295 | none_acc = none_right_num / none_model_num
296 | none_recall = none_right_num / none_real_num
297 | none_f1 = 2 * none_acc * none_recall / (none_acc + none_recall)
298 | logging.info("none_acc: " + str(none_acc) + ", none_recall: " + str(none_recall) + ", none_f1: " + str(
299 | none_f1) + ", none_real_prop: " + str(none_real_num / num) + ", none_model_prop: " + str(none_model_num / num))
300 |
301 |
302 | if __name__ == "__main__":
303 | max_pre_file_dir = sys.argv[1]
304 | min_pre_file_dir = sys.argv[2]
305 | result_file_dir = sys.argv[3]
306 | std_label_ques = sys.argv[4]
307 | if max_pre_file_dir == 'none' or std_label_ques == 'none': # only use min_pre result
308 | get_result_by_min(min_pre_file_dir, result_file_dir)
309 | else: # merge max_pre result and min_pre result
310 | merge_items_list = get_pre_label_scores(max_pre_file_dir, min_pre_file_dir)
311 | min_max_map = get_max2min_label(std_label_ques)
312 | get_merge_result(merge_items_list, min_max_map)
313 | write_result(merge_items_list, result_file_dir)
314 | # get acc recall f1
315 | get_acc_recall_f1(result_file_dir)
316 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/models/__init__.py
--------------------------------------------------------------------------------
/models/bilstm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | """
4 | a BiLSTM implementation for short text classification using the tensorflow library
5 |
6 | """
7 |
8 | from __future__ import print_function
9 |
10 | import tensorflow as tf
11 | from tensorflow.contrib import rnn
12 |
13 |
14 | class BiLSTM(object):
15 |
16 | def __init__(self, FLAGS):
17 | """Constructor for BiLSTM
18 |
19 | Args:
20 | FLAGS: tf.app.flags, you can see the FLAGS of run_bi_lstm.py
21 | """
22 | self.input_x = tf.placeholder(tf.int64, [None, FLAGS.seq_length], name="input_x")
23 | self.input_y = tf.placeholder(tf.int64, [None, ], name="input_y")
24 | self.x_len = tf.placeholder(tf.int64, [None, ], name="x_len")
25 | self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
26 |
27 | with tf.variable_scope("embedding", initializer=tf.orthogonal_initializer()):
28 | with tf.device('/cpu:0'):
29 | # word embedding table
30 | self.vocab = tf.get_variable('w', [FLAGS.vocab_size, FLAGS.embedding_size])
31 | embedded = tf.nn.embedding_lookup(self.vocab, self.input_x) # [batch_size, seq_length, embedding_size]
32 | inputs = tf.split(embedded, FLAGS.seq_length,
33 | 1) # [[batch_size, 1, embedding_size], [batch_size, 1, embedding_size], number is seq_length]
34 | inputs = [tf.squeeze(input_, [1]) for input_ in
35 | inputs] # [[batch_size, embedding_size], [batch_size, embedding_size], number is seq_length]
36 |
37 | with tf.variable_scope("encoder", initializer=tf.orthogonal_initializer()):
38 | lstm_fw_cell = rnn.BasicLSTMCell(FLAGS.num_units)
39 | lstm_bw_cell = rnn.BasicLSTMCell(FLAGS.num_units)
40 | lstm_fw_cell_stack = rnn.MultiRNNCell([lstm_fw_cell] * FLAGS.lstm_layers, state_is_tuple=True)
41 | lstm_bw_cell_stack = rnn.MultiRNNCell([lstm_bw_cell] * FLAGS.lstm_layers, state_is_tuple=True)
42 | lstm_fw_cell_stack = rnn.DropoutWrapper(lstm_fw_cell_stack, input_keep_prob=self.dropout_keep_prob,
43 | output_keep_prob=self.dropout_keep_prob)
44 | lstm_bw_cell_stack = rnn.DropoutWrapper(lstm_bw_cell_stack, input_keep_prob=self.dropout_keep_prob,
45 | output_keep_prob=self.dropout_keep_prob)
46 | self.outputs, self.fw_st, self.bw_st = rnn.static_bidirectional_rnn(lstm_fw_cell_stack, lstm_bw_cell_stack,
47 | inputs, sequence_length=self.x_len,
48 | dtype=tf.float32) # multi-layer
49 | # only use the last layer
50 | last_layer_no = FLAGS.lstm_layers - 1
51 | self.states = tf.concat([self.fw_st[last_layer_no].h, self.bw_st[last_layer_no].h],
52 | 1) # [batchsize, (num_units * 2)]
53 |
54 | attention_size = 2 * FLAGS.num_units
55 | with tf.variable_scope('attention'):
56 | attention_w = tf.Variable(tf.truncated_normal([2 * FLAGS.num_units, attention_size], stddev=0.1),
57 | name='attention_w') # [num_units * 2, num_units * 2]
58 | attention_b = tf.get_variable("attention_b", initializer=tf.zeros([attention_size])) # [num_units * 2]
59 | u_list = []
60 | for index in range(FLAGS.seq_length):
61 | u_t = tf.tanh(tf.matmul(self.outputs[index], attention_w) + attention_b) # [batchsize, num_units * 2]
62 | u_list.append(u_t) # seq_length * [batchsize, num_units * 2]
63 | u_w = tf.Variable(tf.truncated_normal([attention_size, 1], stddev=0.1),
64 | name='attention_uw') # [num_units * 2, 1]
65 | attn_z = []
66 | for index in range(FLAGS.seq_length):
67 | z_t = tf.matmul(u_list[index], u_w)
68 | attn_z.append(z_t) # seq_length * [batchsize, 1]
69 | # transform to batch_size * sequence_length
70 | attn_zconcat = tf.concat(attn_z, axis=1) # [batchsize, seq_length]
71 | alpha = tf.nn.softmax(attn_zconcat) # [batchsize, seq_length]
72 | # transform to sequence_length * batch_size * 1 , same rank as outputs
73 | alpha_trans = tf.reshape(tf.transpose(alpha, [1, 0]),
74 | [FLAGS.seq_length, -1, 1]) # [seq_length, batchsize, 1]
75 | self.final_output = tf.reduce_sum(self.outputs * alpha_trans, 0) # [batchsize, num_units * 2]
76 |
77 | with tf.variable_scope("output_layer"):
78 | weights = tf.get_variable("weights", [2 * FLAGS.num_units, FLAGS.label_size])
79 | biases = tf.get_variable("biases", initializer=tf.zeros([FLAGS.label_size]))
80 |
81 | with tf.variable_scope("acc"):
82 | # use attention
83 | self.logits = tf.matmul(self.final_output, weights) + biases # [batchsize, label_size]
84 | # not use attention
85 | # self.logits = tf.matmul(self.states, weights) + biases
86 | self.prediction = tf.nn.softmax(self.logits, name="prediction_softmax") # [batchsize, label_size]
87 | self.loss = tf.reduce_mean(
88 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y))
89 | self.global_step = tf.train.get_or_create_global_step()
90 | self.correct = tf.equal(tf.argmax(self.prediction, 1), self.input_y)
91 | self.acc = tf.reduce_mean(tf.cast(self.correct, tf.float32))
92 | _, self.arg_index = tf.nn.top_k(self.prediction, k=FLAGS.label_size) # [batch_size, label_size]
93 |
94 | with tf.variable_scope('training'):
95 | # optimizer
96 | self.learning_rate = tf.train.exponential_decay(FLAGS.lr, self.global_step, 200, 0.96, staircase=True)
97 | self.train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss,
98 | global_step=self.global_step)
99 |
100 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)
101 |
102 | def export_model(self, export_path, sess):
103 | builder = tf.saved_model.builder.SavedModelBuilder(export_path)
104 | tensor_info_x = tf.saved_model.utils.build_tensor_info(self.input_x)
105 | tensor_info_y = tf.saved_model.utils.build_tensor_info(self.prediction)
106 | tensor_info_len = tf.saved_model.utils.build_tensor_info(self.x_len)
107 | tensor_dropout_keep_prob = tf.saved_model.utils.build_tensor_info(self.dropout_keep_prob) # 1.0 for inference
108 | prediction_signature = (
109 | tf.saved_model.signature_def_utils.build_signature_def(
110 | inputs={'input': tensor_info_x, 'sen_len': tensor_info_len,
111 | 'dropout_keep_prob': tensor_dropout_keep_prob},
112 | outputs={'output': tensor_info_y},
113 | method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
114 | legacy_init_op = None
115 | builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING],
116 | signature_def_map={'prediction': prediction_signature, },
117 | legacy_init_op=legacy_init_op, clear_devices=True, saver=self.saver)
118 | builder.save()
119 |
--------------------------------------------------------------------------------
/models/dssm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | """
4 | an LSTM + DSSM implementation for short text matching using the tensorflow library
5 |
6 | """
7 |
8 | import random
9 | import tensorflow as tf
10 | from tensorflow.contrib import rnn
11 |
12 |
13 | class Dssm(object):
14 | def __init__(self, num_lstm_units, batch_size, negtive_size, SOFTMAX_R, learning_rate, vocab_size,
15 | embedding_size=100, use_same_cell=False):
16 | """Constructor for Dssm
17 |
18 | Args:
19 | num_lstm_units: int, The number of units in the LSTM cell.
20 | batch_size: int, The number of examples in each batch
21 |             negtive_size: int, The number of negative examples sampled per query.
22 |             SOFTMAX_R: float, A smoothing factor applied to the cosine similarity.
23 |             learning_rate: float, The initial learning rate.
24 |             vocab_size: int, The size of the vocabulary.
25 |             embedding_size: int, The size of the vocab embedding.
26 |             use_same_cell: (optional) bool, whether to use the same cell for the forward and backward LSTM; default is False.
27 | """
28 | self.global_step = tf.Variable(0, trainable=False)
29 | self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
30 | self.input_x = tf.placeholder(tf.int32, [None, None], name="input_x") # [batch_size, seq_len]
31 | self.length_x = tf.placeholder(tf.int32, [None, ], name="length_x") # [batch_size, ]
32 | self.input_y = tf.placeholder(tf.int32, [None, None], name="input_y") # [batch_size, seq_len]
33 | self.length_y = tf.placeholder(tf.int32, [None, ], name="length_y") # [batch_size, ]
34 | self.lstm_fw_cell = rnn.BasicLSTMCell(num_lstm_units)
35 | if use_same_cell:
36 | self.lstm_bw_cell = self.lstm_fw_cell
37 | else:
38 | self.lstm_bw_cell = rnn.BasicLSTMCell(num_lstm_units)
39 | with tf.name_scope("keep_prob"):
40 | self.lstm_fw_cell = rnn.DropoutWrapper(self.lstm_fw_cell, input_keep_prob=self.keep_prob,
41 | output_keep_prob=self.keep_prob)
42 | self.lstm_bw_cell = rnn.DropoutWrapper(self.lstm_bw_cell, input_keep_prob=self.keep_prob,
43 | output_keep_prob=self.keep_prob)
44 |
45 | with tf.device('/cpu:0'), tf.name_scope("embedding"):
46 | # one_gram
47 | self.vocab = tf.get_variable('w', [vocab_size, embedding_size])
48 | self.lstm_input_embedding_x = tf.nn.embedding_lookup(self.vocab,
49 | self.input_x) # [batch_size, seq_len, embedding_size]
50 | self.lstm_input_embedding_y = tf.nn.embedding_lookup(self.vocab,
51 | self.input_y) # [batch_size, seq_len, embedding_size]
52 |
53 | with tf.name_scope('representation'):
54 | self.states_x = \
55 | tf.nn.bidirectional_dynamic_rnn(self.lstm_fw_cell, self.lstm_bw_cell, self.lstm_input_embedding_x,
56 | self.length_x,
57 | dtype=tf.float32)[1]
58 | self.output_x = tf.concat([self.states_x[0][1], self.states_x[1][1]], 1) # [batch_size, 2*num_lstm_units]
59 | self.states_y = \
60 | tf.nn.bidirectional_dynamic_rnn(self.lstm_fw_cell, self.lstm_bw_cell, self.lstm_input_embedding_y,
61 | self.length_y,
62 | dtype=tf.float32)[1]
63 | self.output_y = tf.concat([self.states_y[0][1], self.states_y[1][1]], 1) # [batch_size, 2*num_lstm_units]
64 | self.q_y_raw = tf.nn.relu(self.output_x, name="q_y_raw") # [batch_size, num_lstm_units*2]
65 | print("self.q_y_raw: " + str(self.q_y_raw))
66 | self.qs_y_raw = tf.nn.relu(self.output_y, name="qs_y_raw") # [batch_size, num_lstm_units*2]
67 | print("self.qs_y_raw: " + str(self.qs_y_raw))
68 |
69 | with tf.name_scope('rotate'):
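            # In-batch negative sampling: qs_y stacks the original standard-question vectors with
            # `negtive_size` rotated copies of the batch, so row block 0 holds each query's own
            # (positive) standard question and the remaining blocks hold negatives drawn from the
            # same batch; `rand` is forced to be >= 1 so a rotation never pairs an example with itself.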
70 | temp = tf.tile(self.qs_y_raw, [1, 1]) # [batch_size, num_lstm_units*2]
71 | self.qs_y = tf.tile(self.qs_y_raw, [1, 1]) # [batch_size, num_lstm_units*2]
72 | for i in range(negtive_size):
73 | rand = int((random.random() + i) * batch_size / negtive_size)
74 | if rand == 0:
75 | rand = rand + 1
76 | rand_qs_y1 = tf.slice(temp, [rand, 0], [batch_size - rand, -1]) # [batch_size - rand, num_lstm_units*2]
77 | rand_qs_y2 = tf.slice(temp, [0, 0], [rand, -1]) # [rand, num_lstm_units*2]
78 | self.qs_y = tf.concat(axis=0, values=[self.qs_y, rand_qs_y1,
79 | rand_qs_y2]) # [batch_size*(negtive_size+1), num_lstm_units*2]
80 |
81 | with tf.name_scope('sim'):
82 | # cosine similarity
83 | q_norm = tf.tile(tf.sqrt(tf.reduce_sum(tf.square(self.q_y_raw), 1, True)),
84 | [negtive_size + 1, 1]) # [(negtive_size + 1) * batch_size, 1]
85 | qs_norm = tf.sqrt(tf.reduce_sum(tf.square(self.qs_y), 1, True)) # [batch_size*(negtive_size+1), 1]
86 | prod = tf.reduce_sum(tf.multiply(tf.tile(self.q_y_raw, [negtive_size + 1, 1]), self.qs_y), 1,
87 | True) # [batch_size*(negtive_size + 1), 1]
88 | norm_prod = tf.multiply(q_norm, qs_norm) # [batch_size*(negtive_size + 1), 1]
89 | sim_raw = tf.truediv(prod, norm_prod) # [batch_size*(negtive_size + 1), 1]
90 | self.cos_sim = tf.transpose(tf.reshape(tf.transpose(sim_raw), [negtive_size + 1,
91 | batch_size])) * SOFTMAX_R # [batch_size, negtive_size + 1]
92 |
93 | with tf.name_scope('loss'):
94 | # train Loss
95 | self.prob = tf.nn.softmax(self.cos_sim) # [batch_size, negtive_size + 1]
96 | self.hit_prob = tf.slice(self.prob, [0, 0], [-1, 1]) # [batch_size, 1] #positive
97 | raw_loss = -tf.reduce_sum(tf.log(self.hit_prob)) / batch_size
98 | self.loss = raw_loss
99 |
100 | with tf.name_scope('training'):
101 | # optimizer
102 | self.learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 1000, 0.96, staircase=True)
103 | self.train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss,
104 | global_step=self.global_step)
105 |
106 | # acc for test data
107 | with tf.name_scope('cosine_similarity_pre'):
108 | # Cosine similarity
109 | self.q_norm_pre = tf.sqrt(tf.reduce_sum(tf.square(self.q_y_raw), 1, True)) # b*1
110 | self.qs_norm_pre = tf.transpose(tf.sqrt(tf.reduce_sum(tf.square(self.qs_y_raw), 1, True))) # 1*sb
111 | self.prod_nu_pre = tf.matmul(self.q_y_raw, tf.transpose(self.qs_y_raw)) # b*sb
112 | self.norm_prod_de = tf.matmul(self.q_norm_pre, self.qs_norm_pre) # b*sb
113 | self.cos_sim_pre = tf.truediv(self.prod_nu_pre, self.norm_prod_de) * SOFTMAX_R # b*sb
114 |
115 | with tf.name_scope('prob_pre'):
116 | self.prob_pre = tf.nn.softmax(self.cos_sim_pre) # b*sb
117 | # self.hit_prob_pre = tf.slice(self.prob_pre, [0, 0], [-1, 1]) # [batch_size, 1] #positive
118 | # self.test_loss = -tf.reduce_sum(tf.log(self.hit_prob_pre)) / batch_size
119 |
--------------------------------------------------------------------------------
/run_bi_lstm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | """
4 | running bi-lstm for short text classification
5 |
6 | """
7 |
8 | import os
9 | import tensorflow as tf
10 | import shutil
11 | from utils.classifier_utils import TextLoader
12 | from models.bilstm import BiLSTM
13 |
14 | flags = tf.flags
15 | FLAGS = flags.FLAGS
16 |
17 | flags.DEFINE_string("train_path", None, "dir for train data")
18 | flags.DEFINE_string("valid_path", None, "dir for valid data")
19 | flags.DEFINE_string("map_file_path", None, "dir for label std question mapping")
20 | flags.DEFINE_string("model_path", None, "dir for save checkpoint data")
21 | # flags.DEFINE_string("result_file", None, "file for valid result")
22 | flags.DEFINE_string("vocab_file", None, "file for vocab")
23 | flags.DEFINE_string("label_file", None, "file for label")
24 | flags.DEFINE_integer("embedding_size", 256, "size of word embedding")
25 | flags.DEFINE_integer("num_units", 256, "The number of units in the LSTM cell")
26 | flags.DEFINE_integer("vocab_size", 256, "The size of vocab")
27 | flags.DEFINE_integer("label_size", 256, "The num of label")
28 | flags.DEFINE_integer("batch_size", 128, "batch_size of train data")
29 | flags.DEFINE_integer("seq_length", 50, "the length of sequence")
30 | flags.DEFINE_integer("num_epcho", 30, "the epcho num")
31 | flags.DEFINE_integer("check_every", 100, "the epcho num")
32 | flags.DEFINE_integer("lstm_layers", 2, "the layers of lstm")
33 | flags.DEFINE_float("lr", 0.001, "learning rate")
34 | flags.DEFINE_float("dropout_keep_prob", 0.8, "drop_out keep prob")
35 |
36 |
37 | def main(_):
38 | tf.logging.set_verbosity(tf.logging.INFO)
39 | data_loader = TextLoader(True, FLAGS.train_path, FLAGS.map_file_path, FLAGS.batch_size, FLAGS.seq_length, None,
40 | None, None, 'utf8', False)
41 | valid_data_loader = TextLoader(False, FLAGS.valid_path, FLAGS.map_file_path, FLAGS.batch_size, FLAGS.seq_length,
42 | data_loader.vocab,
43 | data_loader.labels, data_loader.std_label_map, 'utf8', False)
44 | tf.logging.info("vocab_size: " + str(data_loader.vocab_size))
45 | FLAGS.vocab_size = data_loader.vocab_size
46 | tf.logging.info("label_size: " + str(data_loader.label_size))
47 | FLAGS.label_size = data_loader.label_size
48 | bilstm = BiLSTM(FLAGS)
49 | init = tf.global_variables_initializer()
50 | config = tf.ConfigProto(allow_soft_placement=True)
51 | config.gpu_options.allow_growth = True
52 | with tf.Session(config=config) as sess:
53 | sess.run(init)
54 | idx = 0
55 | test_best_acc = 0
56 | for epcho in range(FLAGS.num_epcho): # for each epoch
57 | data_loader.reset_batch_pointer()
58 | for train_batch_num in range(data_loader.num_batches): # for each batch
59 | input_x, input_y, x_len, _ = data_loader.next_batch()
60 | feed = {bilstm.input_x: input_x, bilstm.input_y: input_y, bilstm.x_len: x_len,
61 | bilstm.dropout_keep_prob: FLAGS.dropout_keep_prob}
62 | _, global_step_op, train_loss, train_acc = sess.run(
63 | [bilstm.train_step, bilstm.global_step, bilstm.loss, bilstm.acc], feed_dict=feed)
64 | tf.logging.info("training...........global_step = {}, epoch = {}, current_batch = {}, "
65 | "train_loss = {:.4f}, accuracy = {:.4f}".format(global_step_op, epcho, train_batch_num,
66 | train_loss, train_acc))
67 | idx += 1
68 | if idx % FLAGS.check_every == 0:
69 | all_num = 0
70 | acc_num = 0
71 | valid_data_loader.reset_batch_pointer()
72 | write_result = []
73 | for _ in range(valid_data_loader.num_batches):
74 | input_x_valid, input_y_valid, x_len_valid, _ = valid_data_loader.next_batch()
75 | feed = {bilstm.input_x: input_x_valid, bilstm.input_y: input_y_valid, bilstm.x_len: x_len_valid,
76 | bilstm.dropout_keep_prob: 1.0}
77 | prediction, arg_index = sess.run([bilstm.prediction, bilstm.arg_index], feed_dict=feed)
78 | all_num = all_num + len(input_y_valid)
79 | # write_str = ""
80 | for i, indexs in enumerate(arg_index):
81 | pre_label_id = indexs[0]
82 | real_label_id = input_y_valid[i]
83 | if pre_label_id == real_label_id:
84 | acc_num = acc_num + 1
85 | # if real_label_id in valid_data_loader.id_2_label:
86 | # write_str = valid_data_loader.id_2_label.get(real_label_id)
87 | # else:
88 | # write_str = "__label__unknown"
89 | # for index in indexs:
90 | # cur_label = valid_data_loader.id_2_label.get(index)
91 | # cur_score = prediction[i][index]
92 | # write_str = write_str + " " + cur_label + ":" + str(cur_score)
93 | # write_str = write_str + "\n"
94 | # write_result.append(write_str)
95 | test_acc = acc_num * 1.0 / all_num
96 | tf.logging.info(
97 | "testing...........global_step = {}, epoch = {}, accuracy = {:.4f}, cur_best_acc = {}".format(
98 | global_step_op, epcho, test_acc, test_best_acc))
99 | if test_best_acc < test_acc:
100 | test_best_acc = test_acc
101 | # save_model
102 | if not os.path.exists(FLAGS.model_path):
103 | os.makedirs(FLAGS.model_path)
104 | checkpoint_path = os.path.join(FLAGS.model_path, 'lstm_model')
105 | bilstm.saver.save(sess, checkpoint_path, global_step=global_step_op)
106 | # export model
107 | export_path = os.path.join(FLAGS.model_path, 'lstm_tf_serving')
108 | if os.path.isdir(export_path):
109 | shutil.rmtree(export_path)
110 | bilstm.export_model(export_path, sess)
111 | # resultfile = open(FLAGS.result_file, 'w', encoding='utf-8')
112 | # for pre_sen in write_result:
113 | # resultfile.write(pre_sen)
114 | tf.logging.info(
115 | "has saved model and write.result...................................................................")
116 | # resultfile.close()
117 | # save label and vocab
118 | vocabfile = open(FLAGS.vocab_file, 'w', encoding='utf-8')
119 | for key, value in data_loader.vocab.items():
120 | vocabfile.write(str(key) + "\t" + str(value) + '\n')
121 | vocabfile.close()
122 | labelfile = open(FLAGS.label_file, 'w', encoding='utf-8')
123 | for key, value in data_loader.labels.items():
124 | labelfile.write(str(key) + "\t" + str(value) + '\n')
125 | labelfile.close()
126 | # break
127 |
128 |
129 | if __name__ == "__main__":
130 | tf.app.run()
131 |
--------------------------------------------------------------------------------
/run_dssm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | """
4 | running lstm + dssm for short text matching
5 |
6 | """
7 |
8 | import numpy as np
9 | import tensorflow as tf
10 | from utils.match_utils import DataHelper
11 | from models.dssm import Dssm
12 |
13 | flags = tf.app.flags
14 | FLAGS = flags.FLAGS
15 |
16 | # data parameters
17 | flags.DEFINE_string('train_path', None, 'dir for train data')
18 | flags.DEFINE_string('valid_path', None, 'dir for valid data')
19 | flags.DEFINE_string('map_file_path', None, 'dir for label std question mapping')
20 | flags.DEFINE_string('model_path', None, 'Model path')
21 | flags.DEFINE_string('label2id_path', None, 'label2id file path')
22 | flags.DEFINE_string('vocab2id_path', None, 'vocab2id file path')
23 |
24 | # training parameters
25 | flags.DEFINE_integer('softmax_r', 45, 'Smooth parameter for cosine similarity')
26 | flags.DEFINE_integer('embedding_size', 200, 'size of the word embedding')
27 | flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
28 |
29 | flags.DEFINE_float('keep_prob', 0.8, 'Dropout keep prob.')
30 | flags.DEFINE_integer('num_epoches', 10, "Number of epochs.")
31 | flags.DEFINE_integer('batch_size', 50, "Size of one batch.")
32 | flags.DEFINE_integer('negative_size', 5, "Size of negtive sample.")
33 | flags.DEFINE_integer('eval_every', 50, "Record summaries every n steps.")
34 | flags.DEFINE_integer('num_units', 100, "Number of units of lstm(default: 100)")
35 | flags.DEFINE_bool('use_same_cell', True, "whether to use the same cell for the forward and backward LSTM")
36 |
37 |
38 | def feed_dict_builder(batch, keep_prob, dssm):
39 | # batch: ([(q1_len, [q1_w1, q1_w2,...]), (q2_len, [q2_w1, q2_w2,...]), ...], [(std1_len, [std1_w1, std1_w2,...]), (std2_len, [std2_w1, std2_w2,...]), ...])
40 | length_x = [x[0] for x in batch[0]]
41 | input_x = [x[1] for x in batch[0]]
42 | length_y = [y[0] for y in batch[1]]
43 | input_y = [y[1] for y in batch[1]]
44 | feed_dict = {
45 | dssm.input_x: np.array(input_x, dtype=np.int32),
46 | dssm.length_x: np.array(length_x, dtype=np.int32),
47 | dssm.input_y: np.array(input_y, dtype=np.int32),
48 | dssm.length_y: np.array(length_y, dtype=np.int32),
49 | dssm.keep_prob: keep_prob
50 | }
51 | return feed_dict
52 |
53 |
54 | def cal_predict_acc_num(predict_prob, test_batch_q, predict_label_seq):
55 | # calculate acc
56 | assert (len(test_batch_q) == len(predict_prob))
57 | real_labels = []
58 | for ques in test_batch_q:
59 | label = ques[2]
60 | real_labels.append(label)
61 | acc_num = 0
62 | sorted_scores = []
63 | for i, scores in enumerate(predict_prob):
64 | label_scores = {}
65 | for index, score in enumerate(scores):
66 | label_scores[predict_label_seq[index]] = score
67 | # sort
68 | label_scores = sorted(label_scores.items(), key=lambda x: x[1], reverse=True)
69 | sorted_scores.append(label_scores)
70 | top_label = label_scores[0][0]
71 | if top_label == real_labels[i]:
72 | acc_num = acc_num + 1
73 | return acc_num, real_labels, sorted_scores
74 |
75 |
76 | def main(_):
77 | tf.logging.set_verbosity(tf.logging.INFO)
78 | data_help = DataHelper(FLAGS.train_path, FLAGS.valid_path, None, FLAGS.map_file_path, FLAGS.batch_size,
79 | FLAGS.num_epoches, None, None, True)
80 | dssm = Dssm(FLAGS.num_units, FLAGS.batch_size, FLAGS.negative_size, FLAGS.softmax_r, FLAGS.learning_rate,
81 | data_help.vocab_size, FLAGS.embedding_size, use_same_cell=False)
82 | config = tf.ConfigProto(allow_soft_placement=True)
83 | config.gpu_options.allow_growth = True
84 | saver = tf.train.Saver(max_to_keep=1)
85 | train_batches = data_help.train_batch_iterator(data_help.train_id_ques, data_help.std_id_ques)
86 | best_valid_acc = 0
87 | # run_num = 0
88 | with tf.Session(config=config) as sess:
89 | sess.run(tf.global_variables_initializer())
90 | for train_batch_step, train_batch in enumerate(train_batches):
91 | _, step, train_lr, train_loss = sess.run([dssm.train_step, dssm.global_step, dssm.learning_rate, dssm.loss],
92 | feed_dict=feed_dict_builder(train_batch, FLAGS.keep_prob, dssm))
93 | tf.logging.info("Training...... global_step {}, epcho {}, train_batch_step {}, learning rate {} "
94 | "loss {}".format(step, round(step * FLAGS.batch_size / data_help.train_num, 2),
95 | train_batch_step, round(train_lr, 4), train_loss))
96 | if (train_batch_step + 1) % FLAGS.eval_every == 0:
97 | # run_num = run_num + 1
98 | # if run_num % 2 == 0:
99 | # break
100 | all_valid_acc_num = 0
101 | all_valid_num = 0
102 | valid_batches = data_help.valid_batch_iterator()
103 | for _, valid_batch_q in enumerate(valid_batches):
104 | all_valid_num = all_valid_num + len(valid_batch_q)
105 | valid_batch = (valid_batch_q, data_help.std_batch)
106 | valid_prob = sess.run([dssm.prob_pre], feed_dict=feed_dict_builder(valid_batch, 1.0, dssm))
107 | valid_acc_num, real_labels, _ = cal_predict_acc_num(valid_prob[0], valid_batch_q,
108 | data_help.id2label)
109 | all_valid_acc_num = all_valid_acc_num + valid_acc_num
110 | current_acc = all_valid_acc_num * 1.0 / all_valid_num
111 | tf.logging.info(
112 |                 "validating...... global_step {}, valid_acc {}, current_best_acc {}".format(step, current_acc,
113 | best_valid_acc))
114 | if current_acc > best_valid_acc:
115 | tf.logging.info(
116 |                     "validating...... new best acc {}; saving model and results".format(current_acc))
117 | saver.save(sess, FLAGS.model_path + "dssm_{}".format(train_batch_step))
118 | best_valid_acc = current_acc
119 | # save label2id, vocab2id
120 | vocabfile = open(FLAGS.vocab2id_path, 'w', encoding='utf-8')
121 | for key, value in data_help.vocab2id.items():
122 | vocabfile.write(str(key) + "\t" + str(value) + '\n')
123 | vocabfile.close()
124 | labelfile = open(FLAGS.label2id_path, 'w', encoding='utf-8')
125 | for key, value in data_help.label2id.items():
126 | labelfile.write(str(key) + "\t" + str(value) + '\n')
127 | labelfile.close()
128 | # break
129 |
130 |
131 | if __name__ == "__main__":
132 | tf.app.run()
133 |
--------------------------------------------------------------------------------
/sptm/format_result.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | format result to qa_match standard format
5 | USAGE:
6 | python format_result.py [test_file] [model_result] [standard_format_output]
7 | """
8 | import codecs
9 | import sys
10 |
11 | if __name__ == "__main__":
12 | # real result: label\tidx\tsentence
13 | real_result_file = sys.argv[1]
14 | # model result: __label__1|score_1,... \tsentence
15 | model_result_file = sys.argv[2]
16 | result_file = sys.argv[3]
17 |
18 | real_labels = []
19 | for line in codecs.open(real_result_file, encoding='utf-8'):
20 | if len(line.split('\t')) == 3:
21 | real_labels.append(line.split('\t')[0])
22 |
23 | fout = codecs.open(result_file, encoding='utf-8', mode='w+')
24 | with codecs.open(model_result_file, encoding='utf-8') as f:
25 | for idx, line in enumerate(f):
26 | line = line.strip()
27 | s_line = line.split('\t')
28 | if len(s_line) >= 2:
29 | model_res = s_line[0]
30 | sentence = s_line[1]
31 |
32 | model_res = model_res.replace("__label__", "") \
33 | .replace('|', ":").replace(",", " ")
34 | # real_label\tsentence\tmodel_labels
35 | fout.write("{}\t{}\t{}\n".format(real_labels[idx],
36 | sentence, model_res))
37 | f.close()
38 | fout.close()
39 |
--------------------------------------------------------------------------------
/sptm/models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | model implementations for pretraining and fine-tuning language models using the tensorflow library
4 |
5 | """
6 | import tensorflow as tf
7 | import collections
8 | import re
9 | import math
10 |
11 |
12 | class BiDirectionalLmModel(object):
13 | """Constructor for BiDirectionalLmModel
14 | Args:
15 | lstm_dim: int, The number of units in the LSTM cell.
16 | embedding_dim: int, The size of vocab embedding
17 | layer_num: int, The number of LSTM layer.
18 | token_num: int, The number of Token
19 | input_arg: dict, Args of inputs
20 | """
21 | def __init__(self, input_arg, other_arg_dict):
22 | self.lstm_dim = input_arg.lstm_dim
23 | self.embedding_dim = input_arg.embedding_dim
24 | self.layer_num = input_arg.layer_num
25 | self.token_num = other_arg_dict["token_num"]
26 | self.input_arg = input_arg
27 | if 2 * self.lstm_dim != self.embedding_dim:
28 | tf.logging.info('please set the 2 * lstm_dim == embedding_dim')
29 | assert False
30 |
31 | #Build graph for SPTM
32 | def build(self):
33 | assert self.input_arg.representation_type in ["lstm", "transformer"]
34 | self.ph_tokens = tf.placeholder(dtype=tf.int32, shape=[None, None], name="ph_tokens") # [batch_size, seq_length]
35 | self.ph_length = tf.placeholder(dtype=tf.int32, shape=[None], name="ph_length") # [batch_size]
36 | self.ph_dropout_rate = tf.placeholder(dtype=tf.float32, shape=None, name="ph_dropout_rate")
37 | self.ph_input_mask = tf.placeholder(dtype=tf.int32, shape=[None, None], name="ph_input_mask") #[batch_size, seq_length]
38 |
39 | self.v_token_embedding = tf.get_variable(name="v_token_embedding",
40 | shape=[self.token_num, self.embedding_dim],
41 | dtype=tf.float32,
42 | initializer=tf.contrib.layers.xavier_initializer()) #[token_num, embedding_dim]
43 | seq_embedding = tf.nn.embedding_lookup(self.v_token_embedding,
44 | self.ph_tokens) # [batch_size, seq_length, embedding_dim]
45 |
46 | if self.input_arg.representation_type == "lstm":
47 | tf.logging.info("representation using lstm ...........................")
48 | with tf.variable_scope(tf.get_variable_scope(), reuse=False):
49 | seq_embedding = tf.nn.dropout(seq_embedding, keep_prob=1 - self.ph_dropout_rate)
50 | last_output = seq_embedding
51 | cur_state = None
52 | for layer in range(1, self.layer_num + 1):
53 | fw_cell = tf.nn.rnn_cell.LSTMCell(self.lstm_dim, name="fw_layer_" + str(layer))
54 | bw_cell = tf.nn.rnn_cell.LSTMCell(self.lstm_dim, name="bw_layer_" + str(layer))
55 | cur_output, cur_state = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, last_output, self.ph_length,
56 | dtype=tf.float32) # [batch, length, dim]
57 | cur_output = tf.concat(cur_output, -1) # [batch, length, 2 * dim]
58 | cur_output = tf.nn.dropout(cur_output, keep_prob=1 - self.ph_dropout_rate)
59 | last_output = tf.contrib.layers.layer_norm(last_output + cur_output, begin_norm_axis=-1) # add and norm
60 |
61 | output = tf.layers.dense(last_output, self.embedding_dim, activation=tf.tanh,
62 | kernel_initializer=tf.contrib.layers.xavier_initializer()) # [batch, length, 2 * dim]
63 | output = tf.nn.dropout(output, keep_prob=1 - self.ph_dropout_rate)
64 |
65 | # sequence output
66 | self.output = tf.contrib.layers.layer_norm(last_output + output,
67 | begin_norm_axis=-1) # add and norm [batch, length, 2 * dim]
68 |
69 | # max pool output
70 | seq_len = tf.shape(self.ph_tokens)[1]
71 | mask = tf.expand_dims(tf.cast(tf.sequence_mask(self.ph_length, maxlen=seq_len), tf.float32),
72 | axis=2) # [batch, len, 1]
73 | mask = (1 - mask) * -1e5
74 | self.max_pool_output = tf.reduce_max(self.output + mask, axis=1, keepdims=False) # [batch, 2 * dim]
75 | elif self.input_arg.representation_type == "transformer":
76 | tf.logging.info("representation using transformer ...........................")
77 | input_shape = self.get_shape_list(seq_embedding) # [batch_size, seq_length, embedding_size]
78 | batch_size = input_shape[0]
79 | seq_length = input_shape[1]
80 | embedding_size = input_shape[2]
81 |
82 | with tf.variable_scope("pos_embeddings"):
83 | all_position_embeddings = tf.get_variable(
84 | name="position_embeddings",
85 | shape=[self.input_arg.max_position_embeddings, embedding_size],
86 | initializer=tf.truncated_normal_initializer(stddev=self.input_arg.initializer_range)) # [max_position_embeddings, embedding_size]
87 | position_embeddings = tf.slice(all_position_embeddings, [0, 0], [seq_length, -1]) # [seq_length, embedding_size]
88 | position_embeddings = tf.reshape(position_embeddings, [1, seq_length, embedding_size]) # [1, seq_length, embedding_size]
89 | seq_embedding += position_embeddings # [batch_size, seq_length, embedding_size]
90 | seq_embedding = tf.contrib.layers.layer_norm(seq_embedding, begin_norm_axis=-1, begin_params_axis=-1) #[batch_size, seq_length, embedding_size]
91 |
92 | with tf.variable_scope("encoder"):
93 | assert self.input_arg.hidden_size % self.input_arg.num_attention_heads == 0
94 | attention_head_size = self.input_arg.hidden_size // self.input_arg.num_attention_heads
95 | self.all_layer_outputs = []
96 | if embedding_size != self.input_arg.hidden_size:
97 | hidden_output = self.dense_layer_2d(seq_embedding, self.input_arg.hidden_size, None, name="embedding_to_hidden")
98 | else:
99 | hidden_output = seq_embedding # [batch_size, seq_length, hidden_size]
100 |
101 | with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE):
102 | for layer_id in range(self.input_arg.num_hidden_layers):
103 | with tf.name_scope("layer_%d" % layer_id):
104 | with tf.variable_scope("self_attention"):
105 | q = self.dense_layer_3d(hidden_output, self.input_arg.num_attention_heads, attention_head_size, None, "query") # [B, F, N, D]
106 | k = self.dense_layer_3d(hidden_output, self.input_arg.num_attention_heads, attention_head_size, None, "key") # [B, F, N, D]
107 | v = self.dense_layer_3d(hidden_output, self.input_arg.num_attention_heads, attention_head_size, None, "value") # [B, F, N, D]
108 | q = tf.transpose(q, [0, 2, 1, 3]) # [B, N, F, D]
109 | k = tf.transpose(k, [0, 2, 1, 3]) # [B, N, F, D]
110 | v = tf.transpose(v, [0, 2, 1, 3]) # [B, N, F, D]
111 | attention_mask = tf.reshape(self.ph_input_mask, [batch_size, 1, seq_length, 1]) # [B, 1, F, 1]
112 | logits = tf.matmul(q, k, transpose_b=True) # q*k => [B, N, F, F]
113 | logits = tf.multiply(logits, 1.0 / math.sqrt(float(self.get_shape_list(q)[-1]))) # q*k/sqrt(Dk) => [B, N, F, F]
114 | from_shape = self.get_shape_list(q) # [B, N, F, D]
115 | broadcast_ones = tf.ones([from_shape[0], 1, from_shape[2], 1], tf.float32) # [B, 1, F, 1]
116 | attention_mask = tf.matmul(broadcast_ones, tf.cast(attention_mask, tf.float32), transpose_b=True) # [B, 1, F, 1] * [B, 1, F, 1] => [B, 1, F, F]
117 | adder = (1.0 - attention_mask) * -10000.0 # [B, 1, F, F]
118 | logits += adder # [B, N, F, F]
119 | attention_probs = tf.nn.softmax(logits, name="attention_probs") # softmax(q*k/sqrt(Dk)), [B, N, F, F]
120 | attention_output = tf.matmul(attention_probs, v) # softmax(q*k/sqrt(Dk))*v , [B, N, F, F] * [B, N, F, D] => [B, N, F, D]
121 | attention_output = tf.transpose(attention_output, [0, 2, 1, 3]) #[B, F, N, D]
122 | attention_output = self.dense_layer_3d_proj(attention_output, self.input_arg.hidden_size, attention_head_size, None, name="dense") # [B, F, H]
123 | attention_output = tf.contrib.layers.layer_norm(inputs=attention_output + hidden_output, begin_norm_axis=-1, begin_params_axis=-1) # [B, F, H]
124 |
125 | with tf.variable_scope("ffn"):
126 | intermediate_output = self.dense_layer_2d(attention_output, self.input_arg.intermediate_size, tf.nn.relu, name="dense") # [B, F, intermediate_size]
127 | hidden_output = self.dense_layer_2d(intermediate_output, self.input_arg.hidden_size, None, name="output_dense") # [B, F, hidden_size]
128 | hidden_output = tf.contrib.layers.layer_norm(inputs=hidden_output + attention_output, begin_norm_axis=-1, begin_params_axis=-1) # [B, F, H]
129 | layer_output = self.dense_layer_2d(hidden_output, embedding_size, None, name="layer_output_dense") # [B, F, embedding_size]
130 | self.all_layer_outputs.append(layer_output)
131 | self.output = self.all_layer_outputs[-1] # [B, F, embedding_size]
132 | # max pool output
133 | mask = tf.expand_dims(tf.cast(tf.sequence_mask(self.ph_length, maxlen=seq_length), tf.float32), axis=2) # [B, F, 1]
134 | mask = (1 - mask) * -1e5
135 | self.max_pool_output = tf.reduce_max(self.output + mask, axis=1, keepdims=False) # [B, embedding_size]
136 |
137 | # project a 4-D multi-head tensor [B, F, N, D] back to a 3-D tensor [B, F, H]
138 | def dense_layer_3d_proj(self, input_tensor, hidden_size, head_size, activation, name=None):
139 | input_shape = self.get_shape_list(input_tensor) # [B,F,N,D]
140 | num_attention_heads = input_shape[2]
141 | with tf.variable_scope(name):
142 | w = tf.get_variable(name="kernel", shape=[num_attention_heads * head_size, hidden_size], initializer=tf.truncated_normal_initializer(stddev=self.input_arg.initializer_range))
143 | w = tf.reshape(w, [num_attention_heads, head_size, hidden_size])
144 | b = tf.get_variable(name="bias", shape=[hidden_size], initializer=tf.zeros_initializer)
145 | output = tf.einsum("BFND,NDH->BFH", input_tensor, w) # [B, F, H]
146 | output += b
147 | if activation is not None:
148 | return activation(output)
149 | else:
150 | return output
151 |
152 | # dense transformation over the last dimension of a 3-D tensor: [B, F, H] -> [B, F, O]
153 | def dense_layer_2d(self, input_tensor, output_size, activation, name=None):
154 | input_shape = self.get_shape_list(input_tensor) # [B, F, H]
155 | hidden_size = input_shape[2]
156 | with tf.variable_scope(name):
157 | w = tf.get_variable(name="kernel", shape=[hidden_size, output_size], initializer=tf.truncated_normal_initializer(stddev=self.input_arg.initializer_range))
158 | b = tf.get_variable(name="bias", shape=[output_size], initializer=tf.zeros_initializer)
159 | output = tf.einsum("BFH,HO->BFO", input_tensor, w) # [B, F, O]
160 | output += b
161 | if activation is not None:
162 | return activation(output)
163 | else:
164 | return output
165 |
166 | # project a 3-D tensor [B, F, H] to a 4-D per-head tensor [B, F, N, D]
167 | def dense_layer_3d(self, input_tensor, num_attention_heads, head_size, activation, name=None):
168 | input_shape = self.get_shape_list(input_tensor) # [B, F, H]
169 | hidden_size = input_shape[2]
170 | with tf.variable_scope(name):
171 | w = tf.get_variable(name="kernel", shape=[hidden_size, num_attention_heads * head_size], initializer=tf.truncated_normal_initializer(stddev=self.input_arg.initializer_range))
172 | w = tf.reshape(w, [hidden_size, num_attention_heads, head_size])
173 | b = tf.get_variable(name="bias", shape=[num_attention_heads * head_size], initializer=tf.zeros_initializer)
174 | b = tf.reshape(b, [num_attention_heads, head_size])
175 | output = tf.einsum("BFH,HND->BFND", input_tensor, w) #[B, F, N, D]
176 | output += b
177 | if activation is not None:
178 | return activation(output)
179 | else:
180 | return output
181 |
182 | def get_shape_list(self, tensor):
183 | """Returns a list of the shape of tensor, preferring static dimensions.
184 | """
185 | tensor_shape = tensor.shape.as_list()
186 | none_indexes = []
187 | for (index, dim) in enumerate(tensor_shape):
188 | if dim is None:
189 | none_indexes.append(index)
190 | if not none_indexes:
191 | return tensor_shape
192 | dynamic_shape = tf.shape(tensor)
193 | for index in none_indexes:
194 | tensor_shape[index] = dynamic_shape[index]
195 | return tensor_shape
196 |
197 |
198 |
199 | def create_bidirectional_lm_training_op(input_arg, other_arg_dict):
200 | loss_op, model = create_bidirectional_lm_model(input_arg, other_arg_dict)
201 | train_op, learning_rate_op = create_optimizer(loss_op, input_arg.learning_rate, input_arg.train_step,
202 | input_arg.warmup_step, input_arg.clip_norm, input_arg.weight_decay)
203 | model.loss_op = loss_op
204 | model.train_op = train_op
205 | model.learning_rate_op = learning_rate_op
206 | return model
207 |
208 |
209 | def create_bidirectional_lm_model(input_arg, other_arg_dict):
210 | model = BiDirectionalLmModel(input_arg, other_arg_dict)
211 | model.build()
212 | max_predictions_per_seq = input_arg.max_predictions_per_seq
213 |
214 | model.global_step = tf.train.get_or_create_global_step()
215 | model.ph_labels = tf.placeholder(dtype=tf.int32, shape=[None, max_predictions_per_seq],
216 | name="ph_labels") # [batch, max_predictions_per_seq]
217 | model.ph_positions = tf.placeholder(dtype=tf.int32, shape=[None, max_predictions_per_seq],
218 | name="ph_positions") # [batch, max_predictions_per_seq]
219 | model.ph_weights = tf.placeholder(dtype=tf.float32, shape=[None, max_predictions_per_seq],
220 | name="ph_weights") # [batch, max_predictions_per_seq]
221 |
222 | real_output = gather_indexes(model.output, model.ph_positions) # [batch * max_predictions_per_seq, embedding_dim]
223 |
224 | bias = tf.get_variable("bias", shape=[model.token_num], initializer=tf.zeros_initializer())
225 | logits = tf.matmul(real_output, model.v_token_embedding, transpose_b=True) #[batch * max_predictions_per_seq, token_num]
226 | logits = tf.nn.bias_add(logits, bias)
227 |
228 | log_probs = tf.nn.log_softmax(logits, axis=-1) #[batch * max_predictions_per_seq, token_num]
229 | one_hot_labels = tf.one_hot(tf.reshape(model.ph_labels, [-1]), depth=model.token_num, dtype=tf.float32) #[batch * max_predictions_per_seq, token_num]
230 |
231 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch * max_predictions_per_seq]
232 | weights = tf.reshape(model.ph_weights, [-1]) # [batch * max_predictions_per_seq]
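   |     # cross-entropy is averaged over the real masked positions only; ph_weights zeroes out the padded prediction slots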
233 | loss = (tf.reduce_sum(weights * per_example_loss)) / (tf.reduce_sum(weights) + 1e-5)
234 |
235 | return loss, model
236 |
237 |
238 | def create_optimizer(loss, init_lr=5e-5, num_train_steps=1000000, num_warmup_steps=20000, clip_norm=1.0,
239 | weight_decay=0.01):
240 | global_step = tf.train.get_or_create_global_step()
241 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
242 | learning_rate = tf.train.polynomial_decay(learning_rate, global_step, num_train_steps, end_learning_rate=0.0,
243 |                                               power=1.0, cycle=False) # linear decay to 0 over num_train_steps
244 |
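   |     # during the first num_warmup_steps the learning rate ramps linearly from 0 to init_lr, then follows the linear decay above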
245 | if num_warmup_steps:
246 | global_steps_int = tf.cast(global_step, tf.int32)
247 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
248 | warmup_learning_rate = init_lr * tf.cast(global_steps_int, tf.float32) / tf.cast(warmup_steps_int, tf.float32)
249 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
250 | learning_rate = ((1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
251 |
252 | optimizer = tf.contrib.opt.AdamWOptimizer(weight_decay=weight_decay, learning_rate=learning_rate)
253 | tvars = tf.trainable_variables()
254 | grads = tf.gradients(loss, tvars)
255 |     (grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
256 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)
257 | return train_op, learning_rate
258 |
259 |
260 | def gather_indexes(seq_output, positions):
261 | #seq_output:[batch, length, 2 * dim]
262 | #positions:[batch, max_predictions_per_seq]
263 |
264 | batch_size = tf.shape(seq_output)[0]
265 | length = tf.shape(seq_output)[1]
266 | dim = tf.shape(seq_output)[2]
267 |
268 | flat_offsets = tf.reshape(tf.range(0, batch_size, dtype=tf.int32) * length, [-1, 1]) #[batch_size, 1]
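   |     # row offsets map per-sequence positions to indices into the flattened [batch * length, dim] tensor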
269 | output_tensor = tf.gather(tf.reshape(seq_output, [-1, dim]), tf.reshape(positions + flat_offsets, [-1]))
270 | return output_tensor # [batch * max_predictions_per_seq, dim]
271 |
272 |
273 | # finetune model utils
274 | def create_finetune_classification_training_op(input_arg, other_arg_dict):
275 | model = create_finetune_classification_model(input_arg, other_arg_dict)
276 | repre = model.max_pool_output # [batch, 2 * dim]
277 |
278 | model.ph_labels = tf.placeholder(dtype=tf.int32, shape=[None], name="ph_labels") # [batch]
279 | logits = tf.layers.dense(repre, other_arg_dict["label_num"],
280 | kernel_initializer=tf.contrib.layers.xavier_initializer(), name="logits")
281 | model.softmax_op = tf.nn.softmax(logits, -1, name="softmax_pre")
282 | model.loss_op = tf.reduce_mean(
283 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=model.ph_labels), -1)
284 | model.global_step_op = tf.train.get_or_create_global_step()
285 |
286 | tf.logging.info("learning_rate : {}".format(input_arg.learning_rate))
287 | if input_arg.opt_type == "sgd":
288 | tf.logging.info("use sgd")
289 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=input_arg.learning_rate)
290 | elif input_arg.opt_type == "adagrad":
291 | tf.logging.info("use adagrad")
292 | optimizer = tf.train.AdagradOptimizer(learning_rate=input_arg.learning_rate)
293 | elif input_arg.opt_type == "adam":
294 | tf.logging.info("use adam")
295 | optimizer = tf.train.AdamOptimizer(learning_rate=input_arg.learning_rate)
296 | else:
297 |         assert False, "unsupported opt_type: " + str(input_arg.opt_type)
298 |
299 | list_g_v_pair = optimizer.compute_gradients(model.loss_op)
300 | model.train_op = optimizer.apply_gradients(list_g_v_pair, global_step=model.global_step_op)
301 |
302 | return model
303 |
304 | #create finetune model for classification
305 | def create_finetune_classification_model(input_arg, other_arg_dict):
306 | model = BiDirectionalLmModel(input_arg, other_arg_dict)
307 |
308 | model.build()
309 |
310 | tvars = tf.trainable_variables()
311 | initialized_variable_names = {}
312 | if input_arg.init_checkpoint:
313 | tf.logging.info("init from checkpoint!")
314 | assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars,
315 | input_arg.init_checkpoint)
316 | tf.train.init_from_checkpoint(input_arg.init_checkpoint, assignment_map)
317 |
318 | tf.logging.info("**** Trainable Variables ****")
319 | for var in tvars:
320 | init_string = ""
321 | if var.name in initialized_variable_names:
322 | init_string = ", *INIT_FROM_CKPT*"
323 | tf.logging.info("name = {}, shape = {}{}".format(var.name, var.shape, init_string))
324 |
325 | return model
326 |
327 |
328 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
329 | name_to_variable = collections.OrderedDict() # trainable variables
330 | for var in tvars:
331 | name = var.name
332 | m = re.match("^(.*):\\d+$", name)
333 | if m is not None:
334 | name = m.group(1)
335 | name_to_variable[name] = var
336 |
337 | assignment_map = collections.OrderedDict()
338 | initialized_variable_names = {}
339 | init_vars = tf.train.list_variables(init_checkpoint) # variables in checkpoint
340 | for x in init_vars:
341 | (name, var) = (x[0], x[1])
342 | if name not in name_to_variable:
343 | continue
344 | assignment_map[name] = name
345 | initialized_variable_names[name] = 1
346 | initialized_variable_names[name + ":0"] = 1
347 |
348 | return assignment_map, initialized_variable_names
349 |
--------------------------------------------------------------------------------
/sptm/run_classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | fine-tune a classifier on top of a pretrained model, using the train set and dev set
4 | """
5 |
6 | import sys
7 | import os
8 | import tensorflow as tf
9 | import numpy as np
10 | import argparse
11 | import models
12 | import utils
13 |
14 |
15 | def evaluate(sess, full_tensors, args, model):
16 | total_num = 0
17 | right_num = 0
18 | for batch_data in utils.gen_batchs(full_tensors, args.batch_size, is_shuffle=False):
19 | softmax_re = sess.run(model.softmax_op,
20 | feed_dict={model.ph_dropout_rate: 0,
21 | model.ph_tokens: batch_data[0],
22 | model.ph_labels: batch_data[1],
23 | model.ph_length: batch_data[2],
24 | model.ph_input_mask: batch_data[3]})
25 | pred_re = np.argmax(softmax_re, axis=1)
26 | total_num += len(pred_re)
27 | right_num += np.sum(pred_re == batch_data[1])
28 | acc = 1.0 * right_num / (total_num + 1e-5)
29 |
30 | tf.logging.info("dev total num: " + str(total_num) + ", right num: " + str(right_num) + ", acc: " + str(acc))
31 | return acc
32 |
33 |
34 | def main(_):
35 | tf.logging.set_verbosity(tf.logging.INFO)
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("--train_file", type=str, default="", help="Input train file.")
38 | parser.add_argument("--dev_file", type=str, default="", help="Input dev file.")
39 | parser.add_argument("--vocab_file", type=str, default="", help="Input vocab file.")
40 | parser.add_argument("--output_id2label_file", type=str, default="./id2label",
41 | help="File containing (id, class label) map.")
42 | parser.add_argument("--model_save_dir", type=str, default="",
43 |                         help="Specify the directory in which the model will be stored.")
44 | parser.add_argument("--lstm_dim", type=int, default=500, help="Dimension of LSTM cell.")
45 | parser.add_argument("--embedding_dim", type=int, default=1000, help="Dimension of word embedding.")
46 | parser.add_argument("--opt_type", type=str, default='adam', help="Type of optimizer.")
47 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size.")
48 | parser.add_argument("--epoch", type=int, default=20, help="Epoch.")
49 | parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
50 | parser.add_argument("--dropout_rate", type=float, default=0.1, help="Dropout rate")
51 | parser.add_argument("--seed", type=int, default=1, help="Random seed value.")
52 | parser.add_argument("--print_step", type=int, default=1000, help="Print log every x step.")
53 | parser.add_argument("--init_checkpoint", type=str, default='',
54 | help="Initial checkpoint (usually from a pre-trained model).")
55 | parser.add_argument("--max_len", type=int, default=100, help="Max seqence length.")
56 | parser.add_argument("--layer_num", type=int, default=2, help="LSTM layer num.")
57 |
58 | parser.add_argument("--representation_type", type=str, default="lstm",
59 |                         help="Representation type: lstm or transformer.")
60 |
61 | # transformer args
62 | parser.add_argument("--initializer_range", type=float, default="0.02", help="Embedding initialization range")
63 | parser.add_argument("--max_position_embeddings", type=int, default=512, help="max position num")
64 | parser.add_argument("--hidden_size", type=int, default=768, help="hidden size")
65 | parser.add_argument("--num_hidden_layers", type=int, default=12, help="num hidden layer")
66 | parser.add_argument("--num_attention_heads", type=int, default=12, help="num attention heads")
67 | parser.add_argument("--intermediate_size", type=int, default=3072, help="intermediate_size")
68 |
69 | args = parser.parse_args()
70 |
71 | np.random.seed(args.seed)
72 | tf.set_random_seed(args.seed)
73 | tf.logging.info(str(args))
74 | if not os.path.exists(args.model_save_dir):
75 | os.mkdir(args.model_save_dir)
76 |
77 | tf.logging.info("load training sens")
78 | train_sens = utils.load_training_data(args.train_file, skip_invalid=True)
79 | tf.logging.info("\nload dev sens")
80 | dev_sens = utils.load_training_data(args.dev_file, skip_invalid=True)
81 |
82 | word2id, id2word, label2id, id2label = utils.load_vocab(train_sens + dev_sens, args.vocab_file)
83 | fw = open(args.output_id2label_file, 'w+')
84 | for k, v in id2label.items():
85 | fw.write(str(k) + "\t" + v + "\n")
86 | fw.close()
87 |
88 | utils.gen_ids(train_sens, word2id, label2id, args.max_len)
89 | utils.gen_ids(dev_sens, word2id, label2id, args.max_len)
90 |
91 | train_full_tensors = utils.make_full_tensors(train_sens)
92 | dev_full_tensors = utils.make_full_tensors(dev_sens)
93 |
94 | other_arg_dict = {}
95 | other_arg_dict['token_num'] = len(word2id)
96 | other_arg_dict['label_num'] = len(label2id)
97 | model = models.create_finetune_classification_training_op(args, other_arg_dict)
98 |
99 | steps_in_epoch = int(len(train_sens) // args.batch_size)
100 | tf.logging.info("batch size: " + str(args.batch_size) + ", training sample num : " + str(
101 | len(train_sens)) + ", print step : " + str(args.print_step))
102 | tf.logging.info(
103 | "steps_in_epoch : " + str(steps_in_epoch) + ", epoch num :" + str(args.epoch) + ", total steps : " + str(
104 | args.epoch * steps_in_epoch))
105 | print_step = min(args.print_step, steps_in_epoch)
106 | tf.logging.info("eval dev every {} step".format(print_step))
107 |
108 | save_vars = [v for v in tf.global_variables() if
109 | v.name.find('adam') < 0 and v.name.find('Adam') < 0 and v.name.find('ADAM') < 0]
110 | tf.logging.info(str(save_vars))
111 | tf.logging.info(str(tf.all_variables()))
112 |
113 | saver = tf.train.Saver(max_to_keep=2)
114 | config = tf.ConfigProto(allow_soft_placement=True)
115 | config.gpu_options.allow_growth = True
116 | with tf.Session(config=config) as sess:
117 | sess.run(tf.global_variables_initializer())
118 | total_loss = 0
119 | dev_best_so_far = 0
120 | for epoch in range(1, args.epoch + 1):
121 | tf.logging.info("\n" + "*" * 20 + "epoch num :" + str(epoch) + "*" * 20)
122 | for batch_data in utils.gen_batchs(train_full_tensors, args.batch_size, is_shuffle=True):
123 | _, global_step, loss = sess.run([model.train_op, model.global_step_op, model.loss_op],
124 | feed_dict={model.ph_dropout_rate: args.dropout_rate,
125 | model.ph_tokens: batch_data[0],
126 | model.ph_labels: batch_data[1],
127 | model.ph_length: batch_data[2],
128 | model.ph_input_mask: batch_data[3]})
129 | total_loss += loss
130 | if global_step % print_step == 0:
131 | tf.logging.info(
132 | "\nglobal step : " + str(global_step) + ", avg loss so far : " + str(total_loss / global_step))
133 | tf.logging.info("begin to eval dev set: ")
134 | acc = evaluate(sess, dev_full_tensors, args, model)
135 | if acc > dev_best_so_far:
136 | dev_best_so_far = acc
137 | tf.logging.info("!" * 20 + "best got : " + str(acc))
138 | # constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["scores"])
139 | saver.save(sess, args.model_save_dir + '/finetune.ckpt', global_step=global_step)
140 |
141 | tf.logging.info("\n----------------------eval after one epoch: ")
142 | tf.logging.info(
143 | "global step : " + str(global_step) + ", avg loss so far : " + str(total_loss / global_step))
144 | tf.logging.info("begin to eval dev set: ")
145 | sys.stdout.flush()
146 | acc = evaluate(sess, dev_full_tensors, args, model)
147 | if acc > dev_best_so_far:
148 | dev_best_so_far = acc
149 | tf.logging.info("!" * 20 + "best got : " + str(acc))
150 | saver.save(sess, args.model_save_dir + '/finetune.ckpt', global_step=global_step)
151 |
152 |
153 | if __name__ == "__main__":
154 | tf.app.run()
155 |
--------------------------------------------------------------------------------
/sptm/run_prediction.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | predict on the test set with a fine-tuned model
4 | """
5 |
6 | import sys
7 | import tensorflow as tf
8 | import numpy as np
9 | import argparse
10 | import utils
11 |
12 |
13 | def get_output(g):
14 | return {"softmax": g.get_tensor_by_name("softmax_pre:0")}
15 |
16 |
17 | def get_input(g):
18 | return {"tokens": g.get_tensor_by_name("ph_tokens:0"),
19 | "length": g.get_tensor_by_name("ph_length:0"),
20 | "dropout_rate": g.get_tensor_by_name("ph_dropout_rate:0"),
21 | "input_mask": g.get_tensor_by_name("ph_input_mask:0")}
22 |
23 |
24 | def main(_):
25 | tf.logging.set_verbosity(tf.logging.INFO)
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument("--input_file", type=str, default="", help="Input file for prediction.")
28 | parser.add_argument("--vocab_file", type=str, default="", help="Input train file.")
29 | parser.add_argument("--model_path", type=str, default="", help="Path to model file.")
30 | parser.add_argument("--model_dir", type=str, default="", help="Directory which contains model.")
31 | parser.add_argument("--id2label_file", type=str, default="./id2label",
32 | help="File containing (id, class label) map.")
33 | args = parser.parse_args()
34 |
35 | word2id, id2word = utils.load_vocab_file(args.vocab_file)
36 | sys.stderr.write("vocab num : " + str(len(word2id)) + "\n")
37 |
38 | sens = utils.gen_test_data(args.input_file, word2id)
39 | sys.stderr.write("sens num : " + str(len(sens)) + "\n")
40 |
41 | id2label = utils.load_id2label_file(args.id2label_file)
42 | sys.stderr.write('label num : ' + str(len(id2label)) + "\n")
43 |
44 | # use latest checkpoint
45 | if "" == args.model_path:
46 | args.model_path = tf.train.latest_checkpoint(checkpoint_dir=args.model_dir)
47 |
48 | config = tf.ConfigProto()
49 | config.gpu_options.allow_growth = True
50 | with tf.Session(config=config) as sess:
51 | saver = tf.train.import_meta_graph("{}.meta".format(args.model_path))
52 | saver.restore(sess, args.model_path)
53 |
54 | graph = tf.get_default_graph()
55 | input_dict = get_input(graph)
56 | output_dict = get_output(graph)
57 |
58 | for sen in sens:
59 | re = sess.run(output_dict['softmax'], feed_dict={input_dict['tokens']: [sen[0]],
60 | input_dict['input_mask']: [sen[1]],
61 | input_dict['length']: [len(sen[0])],
62 | input_dict["dropout_rate"]: 0.0})
63 | sorted_idx = np.argsort(-1 * re[0]) # sort by desc
64 | s = ""
65 | for i in sorted_idx[:3]:
66 | s += id2label[i] + "|" + str(re[0][i]) + ","
67 | print(s + "\t" + " ".join([id2word[t] for t in sen[0]]))
68 |
69 |
70 | if __name__ == "__main__":
71 | tf.app.run()
72 |
--------------------------------------------------------------------------------
/sptm/run_pretraining.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | pretrain the specified language model (modified bi-LSTM by default)
4 | """
5 |
6 | from __future__ import print_function
7 |
8 | import os
9 | import tensorflow as tf
10 | import numpy as np
11 | import argparse
12 | import models
13 | import utils
14 | import gc
15 | import time
16 |
17 |
18 | def main(_):
19 | tf.logging.set_verbosity(tf.logging.INFO)
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--train_file", type=str, default="", help="Input train file.")
22 | parser.add_argument("--vocab_file", default="", help="Input vocab file.")
23 | parser.add_argument("--model_save_dir", type=str, default="",
24 |                         help="Specify the directory in which the model will be stored.")
25 | parser.add_argument("--lstm_dim", type=int, default=500, help="Dimension of LSTM cell.")
26 | parser.add_argument("--embedding_dim", type=int, default=1000, help="Dimension of word embedding.")
27 | parser.add_argument("--layer_num", type=int, default=2, help="LSTM layer num.")
28 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size.")
29 | parser.add_argument("--train_step", type=int, default=10000, help="Number of training steps.")
30 | parser.add_argument("--warmup_step", type=int, default=1000, help="Number of warmup steps.")
31 | parser.add_argument("--learning_rate", type=float, default=0.001, help="The initial learning rate")
32 | parser.add_argument("--dropout_rate", type=float, default=0.5, help="Dropout rate")
33 | parser.add_argument("--seed", type=int, default=0, help="Random seed value.")
34 | parser.add_argument("--print_step", type=int, default=1000, help="Print log every x step.")
35 | parser.add_argument("--max_predictions_per_seq", type=int, default=10,
36 | help="For each sequence, predict x words at most.")
37 | parser.add_argument("--weight_decay", type=float, default=0, help="Weight decay rate")
38 | parser.add_argument("--clip_norm", type=float, default=1, help='Clip normalization rate.')
39 | parser.add_argument("--max_seq_len", type=int, default=100, help="Max seqence length.")
40 | parser.add_argument("--use_queue", type=int, default=0, help="Whether or not using a queue for input.")
41 | parser.add_argument("--init_checkpoint", type=str, default="", help="Initial checkpoint")
42 | parser.add_argument("--enqueue_thread_num", type=int, default=5, help="Enqueue thread count.")
43 | parser.add_argument("--representation_type", type=str, default="lstm", help="representation type include:lstm, transformer")
44 |
45 | #transformer args
46 | parser.add_argument("--initializer_range", type=float, default="0.02", help="Embedding initialization range")
47 | parser.add_argument("--max_position_embeddings", type=int, default=512, help="max position num")
48 | parser.add_argument("--hidden_size", type=int, default=768, help="hidden size")
49 | parser.add_argument("--num_hidden_layers", type=int, default=12, help ="num hidden layer")
50 | parser.add_argument("--num_attention_heads", type=int, default=12, help="num attention heads")
51 | parser.add_argument("--intermediate_size", type=int, default=3072, help="intermediate_size")
52 |
53 | args = parser.parse_args()
54 |
55 | np.random.seed(args.seed)
56 | tf.set_random_seed(args.seed)
57 | tf.logging.info(args)
58 | if not os.path.exists(args.model_save_dir):
59 | os.mkdir(args.model_save_dir)
60 |
61 | # load data
62 | word2id, id2word = utils.load_vocab_file(args.vocab_file)
63 | training_sens = utils.load_pretraining_data(args.train_file, args.max_seq_len)
64 |
65 | if not args.use_queue:
66 | utils.to_ids(training_sens, word2id, args, id2word)
67 |
68 | other_arg_dict = {}
69 | other_arg_dict['token_num'] = len(word2id)
70 |
71 | # load model
72 | model = models.create_bidirectional_lm_training_op(args, other_arg_dict)
73 |
74 | gc.collect()
75 | saver = tf.train.Saver(max_to_keep=2)
76 | config = tf.ConfigProto(allow_soft_placement=True)
77 | config.gpu_options.allow_growth = True
78 | with tf.Session(config=config) as sess:
79 | sess.run(tf.global_variables_initializer())
80 |
81 | if args.init_checkpoint:
82 | tf.logging.info('restore the checkpoint : ' + str(args.init_checkpoint))
83 | saver.restore(sess, args.init_checkpoint)
84 |
85 | total_loss = 0
86 | num = 0
87 | global_step = 0
88 | while global_step < args.train_step:
89 | if not args.use_queue:
90 | iterator = utils.gen_batches(training_sens, args.batch_size)
91 | else:
92 | iterator = utils.queue_gen_batches(training_sens, args, word2id, id2word)
93 | assert iterator is not None
94 | for batch_data in iterator:
95 | feed_dict = {model.ph_tokens: batch_data[0],
96 | model.ph_length: batch_data[1],
97 | model.ph_labels: batch_data[2],
98 | model.ph_positions: batch_data[3],
99 | model.ph_weights: batch_data[4],
100 | model.ph_input_mask: batch_data[5],
101 | model.ph_dropout_rate: args.dropout_rate}
102 | _, global_step, loss, learning_rate = sess.run([model.train_op, \
103 | model.global_step, model.loss_op,
104 | model.learning_rate_op], feed_dict=feed_dict)
105 |
106 | total_loss += loss
107 | num += 1
108 | if global_step % args.print_step == 0:
109 | tf.logging.info("\nglobal step : " + str(global_step) +
110 | ", avg loss so far : " + str(total_loss / num) +
111 | ", instant loss : " + str(loss) +
112 | ", learning_rate : " + str(learning_rate) +
113 | ", time :" + str(time.strftime('%Y-%m-%d %H:%M:%S')))
114 | tf.logging.info("save model ...")
115 | saver.save(sess, args.model_save_dir + '/lm_pretrain.ckpt', global_step=global_step)
116 | gc.collect()
117 |
118 | if not args.use_queue:
119 | utils.to_ids(training_sens, word2id, args, id2word) # MUST run this for randomization for each sentence
120 | gc.collect()
121 |
122 |
123 | if __name__ == "__main__":
124 | tf.app.run()
125 |
--------------------------------------------------------------------------------
/sptm/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | tools for running pretraining and fine-tuning of language models
4 |
5 | """
6 | import sys
7 | import numpy as np
8 | import codecs
9 | from collections import namedtuple
10 | import queue
11 | import threading
12 | import tensorflow as tf
13 |
14 |
15 | # sample representation
16 | class Sentence(object):
17 | def __init__(self, raw_tokens, raw_label=None):
18 | self.raw_tokens = raw_tokens
19 | self.raw_label = raw_label
20 | self.label_id = None
21 | self.token_ids = []
22 |
23 | # for pretrain
24 | def to_id(self, word2id, args):
25 | # for each epoch, this should be rerun to get random results for each sentence.
26 | self.positions = []
27 | self.labels = []
28 | self.weights = []
29 | self.fw_labels = []
30 | self.bw_labels = []
31 | self.token_ids = []
32 | self.input_masks = []
33 |
34 | for t in self.raw_tokens:
35 | self.token_ids.append(word2id[t])
36 | self.input_masks.append(1)
37 |
38 | for ta in self.bidirectional_targets:
39 | # predict itself
40 | self.labels.append(self.token_ids[ta.position])
41 | # in-place modify the target token in the sentence
42 | self.token_ids[ta.position] = word2id[ta.replace_token]
43 | self.positions.append(ta.position)
44 | self.weights.append(1.0)
45 |
46 | # fix to tensors for the predictions of LM
47 | cur_len = len(self.labels)
48 | self.labels = self.labels + [0] * (args.max_predictions_per_seq - cur_len)
49 | self.positions = self.positions + [0] * (args.max_predictions_per_seq - cur_len)
50 | self.weights = self.weights + [0] * (args.max_predictions_per_seq - cur_len)
51 |
52 | # for finetune
53 | def to_ids(self, word2id, label2id, max_len):
54 | self.label_id = label2id[self.raw_label]
55 | self.raw_tokens = self.raw_tokens[:max_len] # cut off to the max length
56 | self.input_masks = []
57 | all_unk = True
58 | for t in self.raw_tokens:
59 | if t in word2id:
60 | self.token_ids.append(word2id[t])
61 | all_unk = False
62 | else:
63 |                 self.token_ids.append(word2id["<unk>"])  # map OOV words to the unknown token (symbol assumed)
64 | self.input_masks.append(1)
65 | assert not all_unk
66 |
67 | self.token_ids = self.token_ids + [0] * (max_len - len(self.token_ids))
68 | self.input_masks = self.input_masks + [0] * (max_len - len(self.input_masks))
69 |
70 | # file utils
71 |
72 |
73 | def load_vocab_file(vocab_file):
74 | word2id = {}
75 | id2word = {}
76 | for l in codecs.open(vocab_file, 'r', 'utf-8'):
77 | l = l.strip()
78 | assert l != ""
79 | assert l not in word2id
80 | word2id[l] = len(word2id)
81 | id2word[len(id2word)] = l
82 | sys.stderr.write("uniq token num : " + str(len(word2id)) + "\n")
83 | return word2id, id2word
84 |
85 |
86 | def load_vocab(sens, vocab_file):
87 | label2id = {}
88 | id2label = {}
89 | for sen in sens:
90 | if sen.raw_label not in label2id:
91 | label2id[sen.raw_label] = len(label2id)
92 | id2label[len(id2label)] = sen.raw_label
93 |
94 | word2id, id2word = load_vocab_file(vocab_file)
95 | assert len(word2id) == len(id2word)
96 | tf.logging.info("\ntoken num : " + str(len(word2id)))
97 | tf.logging.info(", label num : " + str(len(label2id)))
98 | tf.logging.info(", labels: " + str(id2label))
99 | return word2id, id2word, label2id, id2label
100 |
101 |
102 | def load_id2label_file(id2label_file):
103 | di = {}
104 | for l in open(id2label_file, 'r'):
105 | fs = l.rstrip().split('\t')
106 | assert len(fs) == 2
107 | di[int(fs[0])] = fs[1]
108 | return di
109 |
110 |
111 | def gen_test_data(test_file, word2id):
112 | """
113 | read and encode test file.
114 | """
115 | sens = []
116 | for l in codecs.open(test_file, 'r', 'utf-8'):
117 | fs = l.rstrip().split('\t')[-1].split()
118 | sen = []
119 | mask = []
120 | for f in fs:
121 | if f in word2id:
122 | sen.append(word2id[f])
123 | else:
124 |                     sen.append(word2id['<unk>'])  # map OOV words to the unknown token (symbol assumed)
125 | mask.append(1)
126 | sens.append((sen, mask))
127 | return sens
128 |
129 |
130 | def load_pretraining_data(train_file, max_seq_len):
131 | sens = []
132 | for l in codecs.open(train_file, 'r', 'utf-8'):
133 | sen = Sentence(l.rstrip().split("\t")[-1].split()[:max_seq_len])
134 | if len(sen.raw_tokens) == 0:
135 | continue
136 | sens.append(sen)
137 | if len(sens) % 2000000 == 0:
138 | tf.logging.info("load sens :" + str(len(sens)))
139 | tf.logging.info("training sens num :" + str(len(sens)))
140 | return sens
141 |
142 |
143 | def load_training_data(file_path, skip_invalid=True):
144 | sens = []
145 | invalid_num = 0
146 | max_len = 0
147 | for l in codecs.open(file_path, 'r', 'utf-8'): # load as utf-8 encoding.
148 | if l.strip() == "":
149 | continue
150 | fs = l.rstrip().split('\t')
151 | assert len(fs) == 3
152 | tokens = fs[2].split() # discard empty strings
153 | for t in tokens:
154 | assert t != ""
155 | label = "__label__{}".format(fs[0])
156 | if skip_invalid:
157 | if label.find(',') >= 0 or label.find('NONE') >= 0:
158 | invalid_num += 1
159 | continue
160 | if len(tokens) > max_len:
161 | max_len = len(tokens)
162 | sens.append(Sentence(tokens, label))
163 | tf.logging.info("invalid sen num : " + str(invalid_num))
164 | tf.logging.info("valid sen num : " + str(len(sens)))
165 | tf.logging.info("max_len : " + str(max_len))
166 | return sens
167 |
168 |
169 | # pretrain utils
170 | BiReplacement = namedtuple("BiReplacement", ["position", "replace_token"])
171 |
172 |
173 | def gen_pretrain_targets(raw_tokens, id2word, max_predictions_per_seq):
174 | assert max_predictions_per_seq > 0
175 | assert len(raw_tokens) > 0
176 | pred_num = min(max_predictions_per_seq, max(1, int(round(len(raw_tokens) * 0.15))))
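   |     # BERT-style masking: ~15% of tokens are chosen; of these, 80% get the mask token, 10% keep the original token, 10% get a random token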
177 |
178 | re = []
179 | covered_pos_set = set()
180 | for _ in range(pred_num):
181 | cur_pos = np.random.randint(0, len(raw_tokens))
182 | if cur_pos in covered_pos_set:
183 | continue
184 | covered_pos_set.add(cur_pos)
185 |
186 | prob = np.random.uniform()
187 | if prob < 0.8:
188 |             replace_token = '<mask>'  # mask token (symbol assumed)
189 | elif prob < 0.9:
190 | replace_token = raw_tokens[cur_pos] # itself
191 | else:
192 | while True:
193 | fake_pos = np.random.randint(0, len(id2word)) # random one
194 | replace_token = id2word[fake_pos]
195 | if raw_tokens[cur_pos] != replace_token:
196 | break
197 | re.append(BiReplacement(position=cur_pos, replace_token=replace_token))
198 | return re
199 |
200 |
201 | def gen_ids(sens, word2id, label2id, max_len):
202 | for sen in sens:
203 | sen.to_ids(word2id, label2id, max_len)
204 |
205 |
206 | def to_ids(sens, word2id, args, id2word):
207 | num = 0
208 | for sen in sens:
209 | if num % 2000000 == 0:
210 | tf.logging.info("to_ids handling num : " + str(num))
211 | num += 1
212 | sen.bidirectional_targets = gen_pretrain_targets(sen.raw_tokens, id2word, args.max_predictions_per_seq)
213 | sen.to_id(word2id, args)
214 |
215 |
216 | def gen_batches(sens, batch_size):
217 | per = np.array([i for i in range(len(sens))])
218 | np.random.shuffle(per)
219 |
220 | cur_idx = 0
221 | token_batch = []
222 | input_mask_batch = []
223 | length_batch = []
224 |
225 | position_batch = []
226 | label_batch = []
227 | weight_batch = []
228 |
229 | while cur_idx < len(sens):
230 | token_batch.append(sens[per[cur_idx]].token_ids)
231 | length_batch.append(len(sens[per[cur_idx]].token_ids))
232 | input_mask_batch.append(sens[per[cur_idx]].input_masks)
233 |
234 | label_batch.append(sens[per[cur_idx]].labels)
235 | position_batch.append(sens[per[cur_idx]].positions)
236 | weight_batch.append(sens[per[cur_idx]].weights)
237 | if len(token_batch) == batch_size or cur_idx == len(sens) - 1:
238 | max_len = max(length_batch)
239 |             for ts in token_batch: ts.extend([0] * (max_len - len(ts)))
240 | for im in input_mask_batch: im.extend([0] * (max_len - len(im)))
241 |
242 | yield token_batch, length_batch, label_batch, position_batch, weight_batch, input_mask_batch
243 |
244 | del token_batch
245 | del input_mask_batch
246 | del length_batch
247 | del label_batch
248 | del position_batch
249 | del weight_batch
250 | token_batch = []
251 | input_mask_batch = []
252 | length_batch = []
253 | label_batch = []
254 | position_batch = []
255 | weight_batch = []
256 | cur_idx += 1
257 |
258 |
259 | def queue_gen_batches(sens, args, word2id, id2word):
260 | def enqueue(sens, q):
261 | permu = np.arange(len(sens))
262 | np.random.shuffle(permu)
263 | idx = 0
264 | tf.logging.info("thread started!")
265 | while True:
266 | sen = sens[permu[idx]]
267 | sen.bidirectional_targets = gen_pretrain_targets(sen.raw_tokens, id2word,
268 | args.max_predictions_per_seq)
269 | sen.to_id(word2id, args)
270 | q.put(sen)
271 | idx += 1
272 | if idx >= len(sens):
273 | np.random.shuffle(permu)
274 | idx = idx % len(sens)
275 |
276 | q = queue.Queue(maxsize=50000)
277 |
278 | for i in range(args.enqueue_thread_num):
279 | tf.logging.info("enqueue thread started : " + str(i))
280 |         enqueue_thread = threading.Thread(target=enqueue, args=(sens, q))
281 |         enqueue_thread.daemon = True
282 |         enqueue_thread.start()
283 |
284 | qu_sens = []
285 | while True:
286 | cur_sen = q.get()
287 | qu_sens.append(cur_sen)
288 | if len(qu_sens) >= args.batch_size:
289 | for data in gen_batches(qu_sens, args.batch_size):
290 | yield data
291 | qu_sens = []
292 |
293 |
294 | def make_full_tensors(sens):
295 | tokens = np.zeros((len(sens), len(sens[0].token_ids)), dtype=np.int32)
296 | masks = np.zeros((len(sens), len(sens[0].input_masks)), dtype=np.int32)
297 | labels = np.zeros((len(sens)), dtype=np.int32)
298 | length = np.zeros((len(sens)), dtype=np.int32)
299 | for idx, sen in enumerate(sens):
300 | tokens[idx] = sen.token_ids
301 | masks[idx] = sen.input_masks
302 | labels[idx] = sen.label_id
303 | length[idx] = len(sen.raw_tokens)
304 | return tokens, labels, length, masks
305 |
306 |
307 | def gen_batchs(full_tensors, batch_size, is_shuffle):
308 | tokens, labels, length, masks = full_tensors
309 | per = np.array([i for i in range(len(tokens))])
310 | if is_shuffle:
311 | np.random.shuffle(per)
312 |
313 | cur_idx = 0
314 | token_batch = []
315 | mask_batch = []
316 | label_batch = []
317 | length_batch = []
318 | while cur_idx < len(tokens):
319 | token_batch.append(tokens[per[cur_idx]])
320 | mask_batch.append(masks[per[cur_idx]])
321 | label_batch.append(labels[per[cur_idx]])
322 | length_batch.append(length[per[cur_idx]])
323 |
324 | if len(token_batch) == batch_size or cur_idx == len(tokens) - 1:
325 | # make the tokens to real max length
326 | real_max_len = max(length_batch)
327 | for i in range(len(token_batch)):
328 | token_batch[i] = token_batch[i][:real_max_len]
329 | mask_batch[i] = mask_batch[i][:real_max_len]
330 |
331 | yield token_batch, label_batch, length_batch, mask_batch
332 | token_batch = []
333 | label_batch = []
334 | length_batch = []
335 | mask_batch = []
336 | cur_idx += 1
337 |
338 |
339 | if __name__ == "__main__":
340 | pass
341 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/utils/__init__.py
--------------------------------------------------------------------------------
/utils/classifier_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | """
4 | tools for running bi-lstm short text classification
5 | """
6 |
7 | import numpy as np
8 | import math
9 |
10 |
11 | class TextLoader(object):
12 | def __init__(self, is_training, data_path, map_file_path, batch_size, seq_length, vocab, labels, std_label_map,
13 | encoding='utf8', is_reverse=False):
14 | self.data_path = data_path
15 | self.map_file_path = map_file_path
16 | self.batch_size = batch_size
17 | self.seq_length = seq_length
18 | self.is_train = is_training
19 | self.encoding = encoding
20 | # load label std mapping index
21 | self.std_label_map = {}
22 | self.label_num_map = {}
23 | label_set = set()
24 | word_set = set()
25 | if is_training:
26 | with open(map_file_path, 'r', encoding=encoding) as map_index_file:
27 | for line in map_index_file:
28 | tokens = line.strip().split('\t')
29 | assert len(tokens) == 3
30 | label = tokens[0]
31 | label_set.add(label)
32 | std_id = tokens[1]
33 | self.std_label_map[std_id] = label
34 | words = tokens[2].split(" ")
35 | for token in words:
36 | word_set.add(token)
37 |
38 | train_file = data_path
39 | with open(train_file, 'r', encoding=encoding) as fin:
40 | for line in fin:
41 | tokens = line.strip().split('\t')
42 | assert len(tokens) == 3
43 | std_ids = tokens[0].split(",")
44 | words = tokens[2].split(" ")
45 | if len(std_ids) > 1:
46 | label = '__label__list' # answer list
47 | label_set.add(label)
48 | elif std_ids[0] == '0':
49 | label = '__label__none' # refuse answer
50 | label_set.add(label)
51 | else:
52 | assert std_ids[0] in self.std_label_map
53 | label = self.std_label_map.get(std_ids[0]) # __label__xx:some label
54 | for token in words:
55 | word_set.add(token)
56 | if label not in self.label_num_map:
57 | self.label_num_map[label] = 1
58 | else:
59 | self.label_num_map[label] = self.label_num_map[label] + 1
60 |
61 | self.labels = dict(
62 | zip(list(label_set), range(0, len(label_set)))) # {__label__1:0, __label_2:1, __label_3:2, ...}
63 | # print("self.labels: " + str(self.labels))
64 | # print("self.std_label_map: " + str(self.std_label_map))
65 | self.id_2_label = {value: key for key, value in self.labels.items()}
66 | self.label_size = len(self.labels)
67 | self.vocab = dict(zip(list(word_set), range(1, len(word_set) + 1)))
68 | self.id_2_vocab = {value: key for key, value in self.vocab.items()}
69 | self.vocab_size = len(
70 |                     self.vocab) + 1  # vocab size + 1: id 0 is reserved for padding; unknown words are not encoded (modify here if needed)
71 | self.load_preprocessed(data_path, is_reverse)
72 | elif vocab is not None and labels is not None and std_label_map is not None:
73 | self.vocab = vocab
74 | self.id_2_vocab = {value: key for key, value in self.vocab.items()}
75 | self.vocab_size = len(vocab) + 1
76 | self.labels = labels
77 | self.id_2_label = {value: key for key, value in self.labels.items()}
78 | self.label_size = len(self.labels)
79 | self.std_label_map = std_label_map
80 | self.load_preprocessed(data_path, is_reverse)
81 | self.num_batches = 1
82 | self.x_batches = None
83 | self.y_batches = None
84 | self.len_batches = None
85 | self.reset_batch_pointer()
86 |
87 | def load_preprocessed(self, data_path, is_reverse):
88 | train_file = data_path
89 | self.raw_lines = []
90 | with open(train_file, 'r', encoding=self.encoding) as fin:
91 | train_x = []
92 | train_y = []
93 | train_len = []
94 | for line in fin:
95 | temp_x = []
96 | temp_y = []
97 | x_len = []
98 | tokens = line.strip().split('\t')
99 | assert len(tokens) == 3
100 | std_ids = tokens[0].split(",")
101 | words = tokens[2].split(" ")
102 | if len(std_ids) > 1:
103 | label = '__label__list' # answer list
104 | elif std_ids[0] == '0':
105 | label = '__label__none' # refuse answer
106 | else:
107 | if std_ids[0] not in self.std_label_map:
108 | label = '__label__none'
109 | else:
110 | label = self.std_label_map.get(std_ids[0]) # __label__xx:some label
111 | # if label not in self.labels:
112 | # print("label: <" + label + ">")
113 | # print("self.labels: ")
114 | # print(str(self.labels))
115 | temp_y.append(self.labels[label])
116 | for item in words:
117 |                     if item in self.vocab:  # unknown words are skipped; modify here if you need to encode them
118 | temp_x.append(self.vocab[item])
119 | if len(temp_x) == 0:
120 | print("all word in line is not in vocab, line: " + line)
121 | continue
122 | if len(temp_x) >= self.seq_length:
123 | x_len.append(self.seq_length)
124 | temp_x = temp_x[:self.seq_length]
125 | if is_reverse:
126 | temp_x.reverse()
127 | else:
128 | x_len.append(len(temp_x))
129 | if is_reverse:
130 | temp_x.reverse()
131 | temp_x = temp_x + [0] * (self.seq_length - len(temp_x))
132 | train_x.append(temp_x)
133 | train_y.append(temp_y)
134 | train_len.append(x_len)
135 | self.raw_lines.append(tokens[2])
136 | tensor_x = np.array(train_x)
137 | tensor_y = np.array(train_y)
138 | tensor_len = np.array(train_len)
139 | # print("tensor_x.shape: " + str(tensor_x.shape))
140 | # print("len(self.raw_lines): " + str(len(self.raw_lines)))
141 |
142 |         self.tensor = np.c_[tensor_x, tensor_y, tensor_len].astype(int)  # shape: [num_samples, seq_length + 1 + 1]
143 |
144 | def list_split_n(self, raw_items, split_len_batches):
145 | split_items = []
146 | j = 0
147 | for i in range(len(split_len_batches)):
148 | split_items.append(raw_items[j: j + split_len_batches[i]])
149 | j += split_len_batches[i]
150 | return split_items
151 |
152 | def create_batches(self):
153 | self.num_batches = int(self.tensor.shape[0] / self.batch_size)
154 | if int(self.tensor.shape[0] % self.batch_size):
155 | self.num_batches = self.num_batches + 1
156 | if self.num_batches == 0:
157 |             assert False, 'Not enough data. Make batch_size smaller.'
158 | if self.is_train:
159 | np.random.shuffle(self.tensor)
160 | # print("self.num_batches: " + str(self.num_batches))
161 | tensor = self.tensor[:self.num_batches * self.batch_size]
162 | # print("len(tensor): " + str(len(tensor)))
163 | raw_lines = self.raw_lines[
164 |                     :self.num_batches * self.batch_size]  # note: during training raw_lines keeps its original order while self.tensor has been shuffled
165 | # print("len(raw_lines): " + str(len(raw_lines)))
166 |
167 | self.x_batches = np.array_split(tensor[:, :-2], self.num_batches, 0)
168 | self.y_batches = np.array_split(tensor[:, -2], self.num_batches, 0)
169 | self.len_batches = np.array_split(tensor[:, -1], self.num_batches, 0)
170 | split_len_batches = []
171 | for i in range(len(self.x_batches)):
172 | split_len_batches.append(len(self.x_batches[i]))
173 |         self.raw_lines_batches = self.list_split_n(raw_lines, split_len_batches)  # split sizes must match those produced by np.array_split above
174 | sum = 0
175 | # for i in range(len(self.x_batches)):
176 | # print("i: " + str(i) + "len(self.x_batches[i]): " + str(len(self.x_batches[i])))
177 | # sum += len(self.x_batches[i])
178 | # print("sum: " + str(sum))
179 | #
180 | # print("len(self.x_batches): " + str(len(self.x_batches)))
181 | # print("len(self.y_batches): " + str(len(self.y_batches)))
182 | # print("len(self.len_batches): " + str(len(self.len_batches)))
183 | # print("len(self.raw_lines_batches): " + str(len(self.raw_lines_batches)))
184 |
185 | def next_batch(self):
186 | batch_x = self.x_batches[self.pointer]
187 | batch_y = self.y_batches[self.pointer]
188 | xlen = self.len_batches[self.pointer]
189 | batch_line = self.raw_lines_batches[self.pointer]
190 |
191 | self.pointer += 1
192 | return batch_x, batch_y, xlen, batch_line
193 |
194 | def reset_batch_pointer(self):
195 | self.create_batches()
196 | self.pointer = 0
197 |
198 |
199 | if __name__ == "__main__":
200 | pass
201 |
--------------------------------------------------------------------------------
/utils/match_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | """
4 | tools for running bi-lstm + dssm short text matching
5 | """
6 |
7 | import random
8 | import math
9 | import copy
10 |
11 |
12 | class DataHelper(object):
13 |
14 | def __init__(self, train_path, valid_path, test_path, standard_path, batch_size, epcho_num, label2id_file, vocab2id_file, is_train):
15 | if is_train:
16 | self.train_valid_generator(train_path, valid_path, standard_path, batch_size, epcho_num)
17 | else:
18 | self.test_generator(test_path, standard_path, batch_size, label2id_file, vocab2id_file)
19 |
20 | def test_generator(self, test_path, standard_path, batch_size, label2id_file, vocab2id_file):
21 | self.label2id = {}
22 | self.id2label = {}
23 | self.vocab2id = {}
24 | self.id2vocab = {}
25 | self.std_id_ques = {}
26 | label_file = open(label2id_file, 'r', encoding='utf-8')
27 | for line in label_file.readlines():
28 | label_ids = line.strip().split('\t')
29 | self.label2id[label_ids[0]] = label_ids[1]
30 | self.id2label[label_ids[1]] = label_ids[0]
31 | label_file.close()
32 | vocab_file = open(vocab2id_file, 'r', encoding='utf-8')
33 | for line in vocab_file.readlines():
34 | vocab_ids = line.strip().split('\t')
35 | self.vocab2id[vocab_ids[0]] = vocab_ids[1]
36 | self.id2vocab[vocab_ids[1]] = vocab_ids[0]
37 | vocab_file.close()
38 | std_file = open(standard_path, 'r', encoding='utf-8')
39 | max_std_len = 0
40 | for line in std_file:
41 | label_words = line.strip().split("\t")
42 | label = label_words[1]
43 | w_temp = []
44 | words = label_words[2].split(" ")
45 | for word in words:
46 | w_temp.append(self.vocab2id[word])
47 | if max_std_len < len(w_temp):
48 | max_std_len = len(w_temp)
49 | self.std_id_ques[self.label2id[label]] = (len(w_temp), w_temp, label, line.strip())
50 | std_file.close()
51 | self.std_batch = []
52 | self.predict_label_seq = []
53 | self.predict_id_seq = []
54 |         # predictions on the test data must follow this ordering
55 | for std_id, ques_info in self.std_id_ques.items():
56 | self.std_batch.append((ques_info[0], ques_info[1]))
57 | self.predict_label_seq.append(self.id2label[std_id])
58 | self.predict_id_seq.append(std_id)
59 |
60 | # std question padding
61 | for ques_info in self.std_batch:
62 | for _ in range(max_std_len - ques_info[0]):
63 | ques_info[1].append(self.vocab2id['PAD'])
64 |
65 | file = open(test_path, 'r', encoding='utf-8')
66 | self.test_num = 0
67 | self.test_batch = []
68 | for line in file.readlines():
69 | label_words = line.strip().split('\t')
70 |             #labels include list answers (like "id1,id2,...") and refused answers (label "0")
71 | label = label_words[0]
72 | w_temp = []
73 | words = label_words[2].split(' ')
74 | for word in words:
75 | if word not in self.vocab2id:
76 | w_temp.append(self.vocab2id['UNK'])
77 | else:
78 | w_temp.append(self.vocab2id[word])
79 | self.test_batch.append((len(w_temp), w_temp, label, label_words[2]))
80 | self.test_num = self.test_num + 1
81 | file.close()
82 | self.batch_size = batch_size
83 | self.test_num_batch = math.ceil(self.test_num / self.batch_size)
84 |
85 | def train_valid_generator(self, train_path, valid_path, standard_path, batch_size, epcho_num):
86 | self.label2id = {}
87 | self.id2label = {}
88 | self.vocab2id = {}
89 | self.vocab2id['PAD'] = 0
90 | self.vocab2id['UNK'] = 1
91 | self.id2vocab = {}
92 | self.id2vocab[0] = 'PAD'
93 | self.id2vocab[1] = 'UNK'
94 | #standard question
95 | file = open(standard_path, 'r', encoding='utf-8')
96 | self.std_id_ques = {}
97 | max_std_len = 0
98 | for line in file.readlines():
99 | label_words = line.strip().split("\t")
100 | label = label_words[1]
101 | if label not in self.label2id:
102 | self.label2id[label] = len(self.label2id)
103 | self.id2label[self.label2id[label]] = label
104 | w_temp = []
105 | words = label_words[2].split(" ")
106 | for word in words:
107 | if word not in self.vocab2id:
108 | self.vocab2id[word] = len(self.vocab2id)
109 | self.id2vocab[self.vocab2id[word]] = word
110 | w_temp.append(self.vocab2id[word])
111 | if max_std_len < len(w_temp):
112 | max_std_len = len(w_temp)
113 | self.std_id_ques[self.label2id[label]] = (len(w_temp), w_temp, label, line.strip())
114 | file.close()
115 | self.std_batch = []
116 | self.predict_label_seq = []
117 | self.predict_id_seq = []
118 |         #predictions on the valid data must follow this ordering
119 | for std_id, ques_info in self.std_id_ques.items():
120 | self.std_batch.append((ques_info[0], ques_info[1]))
121 | self.predict_label_seq.append(self.id2label[std_id])
122 | self.predict_id_seq.append(std_id)
123 | self.train_num = 0
124 | self.train_id_ques = {}
125 | file = open(train_path, 'r', encoding='utf-8')
126 | for line in file.readlines():
127 | label_words = line.strip().split('\t')
128 | label = label_words[0]
129 | if ',' in label or '0' == label:
130 | continue
131 | assert label in self.label2id
132 | w_temp = []
133 | words = label_words[2].split(' ')
134 | for word in words:
135 | if word not in self.vocab2id:
136 | self.vocab2id[word] = len(self.vocab2id)
137 | self.id2vocab[self.vocab2id[word]] = word
138 | w_temp.append(self.vocab2id[word])
139 | label_id = self.label2id[label]
140 | if label_id not in self.train_id_ques:
141 | self.train_id_ques[label_id] = []
142 | self.train_id_ques[label_id].append((len(w_temp), w_temp))
143 | else:
144 | self.train_id_ques[label_id].append((len(w_temp), w_temp))
145 | self.train_num = self.train_num + 1
146 | file.close()
147 | self.vocab_size = len(self.vocab2id)
148 | #std question padding
149 | for ques_info in self.std_batch:
150 | for _ in range(max_std_len - ques_info[0]):
151 | ques_info[1].append(self.vocab2id['PAD'])
152 | file = open(valid_path, 'r', encoding='utf-8')
153 | self.valid_num = 0
154 | self.valid_batch = []
155 | for line in file.readlines():
156 | label_words = line.strip().split('\t')
157 | label = label_words[0]
158 |             #drop list answers and refused answers
159 | if ',' in label or '0' == label:
160 | continue
161 | assert label in self.label2id
162 | w_temp = []
163 | words = label_words[2].split(' ')
164 | for word in words:
165 | if word not in self.vocab2id:
166 | w_temp.append(self.vocab2id['UNK'])
167 | else:
168 | w_temp.append(self.vocab2id[word])
169 | self.valid_batch.append((len(w_temp), w_temp, label, label_words[2]))
170 | self.valid_num = self.valid_num + 1
171 | file.close()
172 |
173 | self.batch_size = batch_size
174 | self.train_num_epcho = epcho_num
175 | self.train_num_batch = math.ceil(self.train_num / self.batch_size)
176 | self.valid_num_batch = math.ceil(self.valid_num / self.batch_size)
177 |
178 |
179 | def weight_random(self, label_questions, batch_size):
180 | def index_choice(weight):
181 | index_sum_weight = random.randint(0, sum(weight) - 1)
182 | for i, val in enumerate(weight):
183 | index_sum_weight -= val
184 | if index_sum_weight < 0:
185 | return i
186 | return 0
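   |         # draw batch_size distinct classes, with probability proportional to each class's question count (sampling without replacement)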
187 | batch_keys = []
188 | keys = list(label_questions.keys()).copy()
189 | weights = [len(label_questions[key]) for key in keys]
190 | for _ in range(batch_size):
191 | index = index_choice(weights)
192 | key = keys.pop(index)
193 | batch_keys.append(key)
194 | weights.pop(index)
195 | return batch_keys
196 |
197 | def train_batch_iterator(self, label_questions, standard_label_question):
198 | '''
199 |         for each sampled class, pair a randomly chosen question (query) with the class's standard question (doc)
200 | '''
201 | num_batch = self.train_num_batch
202 | num_epcho = self.train_num_epcho
203 | for _ in range(num_batch * num_epcho):
204 | query_batch = []
205 | doc_batch = []
206 | batch_keys = self.weight_random(label_questions, self.batch_size)
207 | batch_query_max_num = 0
208 | for key in batch_keys:
209 | questions = copy.deepcopy(random.sample(label_questions[key], 1)[0])
210 | current_num = questions[0]
211 | if current_num > batch_query_max_num:
212 | batch_query_max_num = current_num
213 | query_batch.append(questions)
214 | doc = standard_label_question[key]
215 | doc_batch.append(doc)
216 | #padding
217 | for query, doc in zip(query_batch, doc_batch):
218 | for _ in range(batch_query_max_num - query[0]):
219 | query[1].append(self.vocab2id['PAD'])
220 | yield query_batch, doc_batch
221 |
222 | def valid_batch_iterator(self):
223 | num_batch = self.valid_num_batch
224 | num_epcho = 1
225 | for i in range(num_batch * num_epcho):
226 | if i * self.batch_size + self.batch_size < self.valid_num:
227 | query_batch = copy.deepcopy(self.valid_batch[i * self.batch_size : i * self.batch_size + self.batch_size])
228 | else:
229 | query_batch = copy.deepcopy(self.valid_batch[i * self.batch_size : ])
230 | batch_query_max_num = 0
231 | for q_len, _, _, _ in query_batch:
232 | if batch_query_max_num < q_len:
233 | batch_query_max_num = q_len
234 | #padding
235 | for q_len, label_words, _, _ in query_batch:
236 | for _ in range(batch_query_max_num - q_len):
237 | label_words.append(self.vocab2id['PAD'])
238 | yield query_batch
239 |
240 | def test_batch_iterator(self):
241 | num_batch = self.test_num_batch
242 | num_epcho = 1
243 | for i in range(num_batch * num_epcho):
244 | if i * self.batch_size + self.batch_size < self.test_num:
245 | query_batch = copy.deepcopy(self.test_batch[i * self.batch_size : i * self.batch_size + self.batch_size])
246 | else:
247 | query_batch = copy.deepcopy(self.test_batch[i * self.batch_size : ])
248 | batch_query_max_num = 0
249 | for q_len, _, _, _ in query_batch:
250 | if batch_query_max_num < q_len:
251 | batch_query_max_num = q_len
252 | #padding
253 | for q_len, label_words, _, _ in query_batch:
254 | for _ in range(batch_query_max_num - q_len):
255 | label_words.append(self.vocab2id['PAD'])
256 | yield query_batch
257 |
258 | if __name__ == "__main__":
259 | pass
260 |
--------------------------------------------------------------------------------