├── .gitignore ├── LICENSE ├── README.md ├── data_demo ├── pre_train_data ├── std_data ├── test_data ├── train_data ├── valid_data └── vocab ├── dec_mining ├── README.md ├── data │ ├── topk.std.text.avg │ ├── trainset │ └── vocab ├── data_utils.py ├── dataset.py ├── dec_model.py ├── images │ ├── modify_1.gif │ ├── modify_2.gif │ ├── 冷启动流程图.png │ ├── 轮廓系数公式.png │ └── 迭代挖掘流程图.png ├── inference.py ├── print_sen_embedding.py └── train.py ├── docs ├── RUNDEMO.md ├── dssm.png ├── kg_demo.png ├── lstm_dssm_bagging.png ├── measurement.png ├── pretrain.png └── sptm.png ├── dssm_predict.py ├── lstm_predict.py ├── merge_classifier_match_label.py ├── models ├── __init__.py ├── bilstm.py └── dssm.py ├── run_bi_lstm.py ├── run_dssm.py ├── sptm ├── format_result.py ├── models.py ├── run_classifier.py ├── run_prediction.py ├── run_pretraining.py └── utils.py └── utils ├── __init__.py ├── classifier_utils.py └── match_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | dec_mining/.idea/* 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2005-present, 58.com. All rights reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 项目介绍 2 | qa_match是一款基于深度学习的问答匹配工具,支持一层和两层结构知识库问答。qa_match通过意图匹配模型支持一层结构知识库问答,通过融合领域分类模型和意图匹配模型的结果支持两层结构知识库问答。qa_match同时支持无监督预训练功能,通过轻量级预训练语言模型(SPTM,Simple Pre-trained Model)可以提升基于知识库问答等下游任务的效果。 3 | 4 | ## 知识库问答 5 | 在实际场景中,知识库一般是通过人工总结、标注、机器挖掘等方式进行构建,知识库中包含大量的标准问题,每个标准问题有一个标准答案和一些扩展问法,我们称这些扩展问法为扩展问题。对于一层结构知识库,仅包含标准问题和扩展问题,我们把标准问题称为意图。对于两层结构知识库,每个标准问题及其扩展问题都有一个类别,我们称为领域,一个领域包含多个意图。 6 | 7 | qa_match支持知识库结构如下: 8 | 9 | ![知识库结构](docs/kg_demo.png) 10 | 11 | 对于输入的问题,qa_match能够结合知识库给出三种回答: 12 | 1. 唯一回答(识别为用户具体的意图) 13 | 2. 列表回答(识别为用户可能的多个意图) 14 | 3. 拒识(没有识别出具体的用户意图) 15 | 16 | 在两种知识库结构下,qa_match的使用方式存在差异,以下分别说明: 17 | 18 | ### 基于两层结构知识库的自动问答 19 | 20 | ![两层结果融合](docs/lstm_dssm_bagging.png) 21 | 22 | 对于两层结构知识库问答,qa_match会对用户问题先进行领域分类和意图识别,然后对两者的结果进行融合,获取用户的真实意图进行相应回答(唯一回答、列表回答、拒绝回答)。 23 | 举个例子:如上述知识库问答中[知识库结构图](#知识库问答)所示,我们有一个两层结构知识库,它包括”信息“和”账号“两个领域”。其中“信息”领域下包含两个意图:“如何发布信息”、“如何删除信息”,“账号”领域下包含意图:“如何注销账号”。当用户输入问题为:“我怎么发布帖子?”时,qa_match会进行如下问答逻辑: 24 | 25 | 1. 分别用LSTM领域分类模型和DSSM意图匹配模型对输入问题进行打分,如:领域分类模型最高打分为0.99且识别为“信息”领域,意图匹配模型最高打分为0.98且识别为“如何发布信息”意图。由于领域分类模型最高打分对应的label为信息类,所以进入判断为某一类分支。 26 | 2. 进入判断为某一分类分支后,用领域分类模型的最高打分0.99与两层结构知识库问答图中阈值b1(如b1=0.9)进行对比,由于0.99>=b1,判断为走“严格DSSM意图匹配”子分支。 27 | 3. 
进入“严格DSSM意图匹配”分支后,用意图匹配模型的最高打分0.98与阈值x1(例如x1=0.8),x2(如x2=0.95)做比较,发现0.98>x2,由此用如何发布信息对应的答案进行唯一回答(其他分支回答类似)。 28 | 29 | ### 基于一层结构知识库的自动问答 30 | 31 | 实际场景中,我们也会遇到一层结构知识库问答问题,用DSSM意图匹配模型与SPTM轻量级预训练语言模型均可以解决此类问题,两者对比: 32 | 33 | | 模型 | 使用方法 | 优点 | 缺点 | 34 | | ------------------------ | ---------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | 35 | | DSSM意图匹配模型 | DSSM匹配模型直接匹配 | ①使用简便,模型占用空间小
②训练/预测速度快 | 无法利用上下文信息 | 36 | | SPTM轻量级预训练语言模型 | 预训练LSTM/Transformer语言模型
+微调LSTM/Transformer匹配模型 | ①能够充分利用无监督预训练数据提升效果
②语言模型可用于其他下游任务 | ①预训练需要大量无标签数据
②操作较复杂(需两个步骤得到匹配模型) | 37 | 38 | #### 基于DSSM模型的自动问答 39 | ![一层结果融合](docs/dssm.png) 40 | 41 | 对于一层结构知识库问答,只需用DSSM意图匹配模型对输入问题进行打分,根据意图匹配的最高分值与上图中的x1,x2进行比较决定回答类型(唯一回答、列表回答、拒识)。 42 | 43 | #### 基于SPTM模型的自动问答 44 | 45 | ##### 轻量级预训练语言模型(SPTM,Simple Pre-trained Model)介绍 46 | 47 | 考虑到实际使用中往往存在大量的无标签数据,在知识库数据有限时,可使用无监督预训练语言模型提升匹配模型的效果。参考[BERT](https://github.com/google-research/bert)预训练过程,2019年5月我们开发了SPTM模型,该模型相对于BERT主要改进了三方面:一是去掉了效果不明显的NSP(Next Sentence Prediction),二是为了提高线上推理性能将Transformer替换成了LSTM,三是为了保证模型效果降低参数量也提供了BLOCK间共享参数的Transformer,模型原理如下: 48 | 49 | ###### 数据预处理 50 | 51 | 预训练模型时,生成训练数据需要使用无标签单句作为数据集,并参考了BERT来构建样本:每个单句作为一个样本,句子中15%的字参与预测,参与预测的字中80%进行mask,10%随机替换成词典中一个其他的字,10%不替换。 52 | 53 | ###### 预训练 54 | 55 | 预训练阶段的模型结构如下图所示: 56 | 57 | 58 | 59 | ![模型结构](docs/sptm.png) 60 | 61 | 为提升模型的表达能力,保留更多的浅层信息,引入了残差Bi-LSTM网络(Residual LSTM)作为模型主体。该网络将每一层Bi-LSTM的输入和该层输出求和归一化后,结果作为下一层的输入。此外将最末层Bi-LSTM输出作为一个全连接层的输入,与全连接层输出求和归一化后,结果作为整个网络的输出。 62 | 63 | 预训练任务耗时示例如下表所示: 64 | 65 | | **指标名称** | **指标值** | **指标值** | **指标值** | 66 | | ---------------- | -------------------------------- | ----------------------------------------------- | ----------------------------------------------- | 67 | |模型结构 | LSTM | 共享参数的Transformer | 共享参数的Transformer | 68 | | 预训练数据集大小 | 10Million | 10Million | 10Million | 69 | | 预训练资源 | 10台Nvidia K40 / 12G Memory | 10台Nvidia K40 / 12G Memory | 10台Nvidia K40 / 12G Memory | 70 | | 预训练参数 | step = 100000 / batch size = 128 | step = 100000 / batch size = 128 / 1 layers / 12 heads | step = 100000 / batch size = 128 / 12 layers / 12 heads | 71 | | 预训练耗时 | 8.9 hours | 13.5 hours | 32.9 hours | 72 | | 预训练模型大小 | 81M | 80.6M | 121M | 73 | 74 | ##### SPTM自动问答流程 75 | 76 | ![预训练语言模型](docs/pretrain.png) 77 | 78 | 引入SPTM后,对于一层结构知识库问答,先使用基于语言模型微调的意图匹配模型对输入问题进行打分,再根据与DSSM意图匹配模型相同的策略决定回答类型(唯一回答、列表回答、拒识)。 79 | 80 | ## 如何使用 81 | ### 数据介绍 82 | 需要使用到的数据文件(data_demo文件夹下)格式说明如下,这里为了不泄露数据,我们对标准问题和扩展问题原始文本做了编码,在实际应用场景中直接按照以下格式准备数据即可。 83 | - std_data:类别和标准问题对应关系,包含类别ID、标准问题ID、标准问题文本三列 84 | - pre_train_data:无标签的预训练数据集,每行是一段文本 85 | - vocab:预训练数据字典,每行是一个词(字典应包含``、`、``) 86 | - train_data:训练集,包含标准问题ID、扩展问题ID、扩展问题文本三列 87 | - valid_data:验证集,包含标准问题ID、扩展问题ID、扩展问题文本三列 88 | - test_data:测试集,包含标准问题ID、扩展问题ID、扩展问题文本三列 89 | 90 | 数据以\t分隔,问题编码以空格分隔,字之间以空格分隔。注意在本项目的数据示例中,对原始文本做了编码,将每个字替换为一个数字, 例如`205 19 90 417 41 44` 对应的实际文本是`如何删除信息`,**在实际使用时不需要做该编码操作**;若知识库结构为一级,需要把std_data文件中的类别id全部设置为`__label__0`。 91 | 92 | ### 知识库半自动挖掘流程 93 | 94 | 知识库半自动挖掘流程,是在qa match自动问答流程的基础上(参考qa match 基于一层知识库结构的自动问答)构建的一套知识库半自动挖掘方案,帮助提升知识库规模与知识库质量,一方面提高线上匹配的能力;一方面提高离线模型训练数据的质量,进而提高模型性能。知识库半自动挖掘流程可以用于冷启动挖掘和模型上线后迭代挖掘两个场景。详情参见[知识库挖掘说明文档](./dec_mining/README.md) 95 | 96 | ### 怎么运行 97 | 详情见[运行说明](docs/RUNDEMO.md) 98 | 99 | ### tips 100 | 1. 由于DSSM模型训练选取负样本时是将原样本对应标签随机打散,所以模型参数需要满足`batch_size >= negitive_size`,否则模型无法有效训练。 101 | 2. 模型融合参数选取方法:目前参数的选取是基于统计的,首先在测试集上计算同一参数(如两层结构知识库问答图中a1)不同值所对应的模型label(如拒识)的f1值,然后选取较大的f1值对应的数值做为该参数的取值。如:在选取两层结构知识库问答图中参数a1的最终取值时,首先会在测试集上得到不同a1候选值对应的模型label(如拒识,非拒识),然后根据样本的真实label计算f1值,最后选取合适的f1值(根据项目需求可偏重准确率/召回率)对应的候选值作为a1的最终取值。 102 | 103 | ## 运行环境 104 | ``` 105 | tensorflow 版本>r1.8 阈值的样本点作为扩展问题,进行人工审核入库。 50 | 51 | 52 | 53 | ## 第二个场景:迭代挖掘流程 54 | 55 | 56 | 57 | 知识库迭代挖掘场景具体指模型已经上线后,知识库中已经有了一定数量的标准问题和扩展问题,但是由于线上数据是动态变化的,所以存在模型覆盖不到的标准问题和扩展问法,迭代挖掘的目的就是及时的将它们挖掘出来,增加线上样本覆盖度,从而提高模型准召。 58 | 59 | 60 | 61 | 迭代挖掘流程图如下: 62 | 63 | 64 | 65 | ![冷启动挖掘流程](./images/迭代挖掘流程图.png) 66 | 67 | 68 | 69 | ### 迭代挖掘步骤 70 | 71 | 72 | 73 | 1. 
基于目前自动问答流程(参考[qa match 基于一层知识库结构的自动问答](https://github.com/wuba/qa_match/tree/master#%E5%9F%BA%E4%BA%8E%E4%B8%80%E5%B1%82%E7%BB%93%E6%9E%84%E7%9F%A5%E8%AF%86%E5%BA%93%E7%9A%84%E8%87%AA%E5%8A%A8%E9%97%AE%E7%AD%94)),从线上拒识问题以及每周人工抽样标注的新分类问题(目前标准问题没有覆盖到并且非拒识的问题)中提取新知识。 74 | 75 | 2. 粗略筛除几类问题: 76 | 77 | a)超短query(长度小于3的线上问题,这个类别是optional的,根据具体场景实现)。此类query 在问答场景通常会被拒识,若不拒识,大部分通过匹配实现。 78 | 79 | b)高频问题。对于高频问题,一定要覆盖,不需要再经过挖掘,直接筛选出来交给人工审核进行入库,剩余问题送入DEC算法模块进行挖掘。 80 | 81 | 3. 初步筛出比较纯粹的query 之后,使用已有标准问题作为自定义聚类中心,选取聚类结果概率值 > 阈值的样本点作为扩展问题,进行人工审核入库;对于挖掘新类别的标准问题,可以参考冷启动场景的方法进行挖掘。 82 | 83 | # 运行说明 84 | 85 | 86 | 87 | 本示例给出了支持一层结构知识库的基于SPTM表征的DEC挖掘算法运行demo、评测指标及在测试集的效果。 88 | 89 | 90 | 91 | ## 数据介绍 92 | 93 | 94 | 95 | 需要使用到的数据文件(dec_mining/data 文件夹下)格式说明如下,这里为了不泄露数据,我们对标准问题和扩展问题原始文本做了编码,在实际应用场景中直接按照以下格式准备数据即可。所给出聚类数据集,取自58智能问答生产环境下的真实数据,这里仅为了跑通模型,因此只取了少部分数据,其中待聚类数据1w。 96 | 97 | 98 | 99 | * [trainset](./data/trainset):待聚类数据。两列\t分隔,第一列为 ground truth 标签问题ID ,格式 `__label__n`;如果没有ground truth 标签,也需要设置一个n进行占位,查看聚类结果时忽略掉即可;第二列为标准问题文本,分字格式(空格切分)。 100 | * [topk.std.text.avg](./data/topk.std.text.avg) : 自定义聚类中心文件。该文件为从topk问题总结出的标准问,每行为一个聚类中心,支持多个问题的平均作为聚类中心的格式 使用斜杠/分隔,如:“你好这辆车还在么/车还在吗/这车还在吗” 101 | 102 | 103 | 104 | ## 运行示例 105 | 106 | 107 | 108 | 使用DEC算法进行聚类需要两步,先train 也就是先微调表征,然后再做inference得到聚类结果。 109 | 110 | 111 | 112 | 113 | (1) 根据自定义聚类中心文本得到表征,此步骤可选,如果选择使用K-means做初始化,则不需要此步骤,在步骤(2)指定 `n_clusters` 即可 114 | 115 | 116 | ```bash 117 | cd dec_mining && python3 print_sen_embedding.py --input_file=./topk.std.text.avg --vocab_file=./vocab --model_path=./pretrain_model/lm_pretrain.ckpt-1000000 --batch_size=512 --max_seq_len=50 --output_file=./topk.std.embedding.avg --embedding_way=max 118 | ``` 119 | 120 | 参数说明: 121 | 122 | input_file:自定义聚类中心文本。 123 | 124 | vocab_file : 词典文件(需要包含 `` ) 125 | 126 | model_path: SPTM预训练表征模型,预训练模型的embedding的维度要跟第(2)步的embedding_dim参数保持一致 127 | 128 | max_seq_len: 截断的最大长度 129 | 130 | output_file: 输出的自定义聚类中心表征文件 131 | 132 | 133 | 134 | 135 | (2) 使用待聚类数据进行微调表征 136 | 137 | ```bash 138 | cd dec_mining && python3 ./train.py --init_checkpoint=./pretrain_model/lm_pretrain.ckpt-1000000 --train_file=./data/trainset --epochs=30 --lstm_dim=128 --embedding_dim=256 --vocab_file=./vocab --external_cluster_center=./topk.std.embedding.avg --model_save_dir=./saved_model --learning_rate=0.03 --warmup_steps=5000 139 | ``` 140 | 141 | 142 | 143 | 参数说明: 144 | 145 | 146 | 147 | init_checkpoint: SPTM预训练表征模型 148 | 149 | train_file: 待聚类数据 150 | 151 | epochs: epoch 数量 152 | 153 | n_clusters: K-means 方法指定聚类中心数;此参数与external_cluster_center 只传入一个即可,需要与步骤(3)中inference 过程使用的参数一致,指定`n_clusters`表示使用K-means做初始化,指定external_cluster_center表示使用自定义聚类中心做初始化 154 | 155 | lstm_dim: SPTM lstm的门控单元数 156 | 157 | embedding_dim: SPTM 词嵌入维度,需要设置为lstm_dim 参数的2倍 158 | 159 | vocab_file: 词典文件(需要包含 ``) 160 | 161 | external_cluster_center: 自定义聚类中心文件,此参数与n_clusters 只传入一个即可,需要与步骤(3)中inference 过程使用的参数一致,指定`n_clusters`表示使用K-means做初始化,指定external_cluster_center表示使用自定义聚类中心做初始化 162 | 163 | model_save_dir: DEC模型保存路径 164 | 165 | learning_rate: 学习率 166 | 167 | 168 | warmup_steps:学习率 warm up 步数 169 | 170 | 171 | (3) 根据微调好的表征对待聚类数据进行DEC聚类 172 | 173 | 174 | 175 | ```bash 176 | cd dec_mining && python3 inference.py --model_path=./saved_model/finetune.ckpt-0 --train_file=./data/trainset --external_cluster_center=./topk.std.embedding.avg --lstm_dim=128 --embedding_dim=256 --vocab_file=./vocab --pred_score_path=./pred_score 177 | ``` 178 | 179 | 180 | 181 | 参数说明: 182 | 183 | 184 | 185 | model_path: 上一步train 得到的DEC模型 186 | 187 | train_file: 待聚类数据 188 | 189 | lstm_dim: SPTM 
lstm的门控单元数 190 | 191 | embedding_dim: SPTM 词嵌入维度 192 | 193 | vocab_file: 词典文件(需要包含 `` ) 194 | 195 | pred_score_path: 聚类结果打分文件,格式:`pred_label + \t + question + \t + groundtruth_label + \t + probability` 例如:`__label__4` 请添加车主阿涛微信详谈 `__label__0 ` 00.9888488 196 | 197 | 198 | 199 | ## 算法评测指标及测试集效果 200 | 201 | 202 | 203 | 聚类算法的评估一般分为外部评估和内部评估,外部评估是指数据集有ground truth label 时通过有监督标签进行评估; 内部评估是不借助外部可信标签,单纯从无监督数据集内部评估,内部评估的原则是类内距小,类间距大,这里我们使用轮廓系数(silhouette)来评估。 204 | 205 | 206 | 207 | **轮廓系数 silhouette coefficient** 208 | 209 | 210 | 211 | ![冷启动挖掘流程](./images/轮廓系数公式.png) 212 | 213 | 214 | 215 | * a(i) = avg(i 向量到所有它属于的簇中其他点的距离) 216 | 217 | * b(i) = min(i 向量与其他的簇内的所有点的平均) 218 | 219 | * 值介于 [-1,1] ,越趋近于1代表内聚度和分离度都相对较优 220 | 221 | * 将所有点的轮廓系数求平均,就是该聚类结果总的轮廓系数 222 | 223 | 224 | 225 | **准确率accuracy计算** 226 | 227 | [sklearn linear_assignment](https://www.kite.com/python/docs/sklearn.utils.linear_assignment_.linear_assignment) 228 | 229 | 230 | 231 | 归纳top10 问题后,运行后评测效果如下(使用通用深度学习推理服务[dl_inference](https://github.com/wuba/dl_inference)开源项目部署模型来评测推理耗时): 232 | 233 | 234 | 235 | | 数据集 | 模型 | **Silhouette** | **Runtime** | **Inference Time** | **Accuracy** | 236 | | ------ | ---- | -------------- | ----------- |------------ | ------------ | 237 | | 1w | DEC | 0.7962 | 30min | 52s |0.8437 | 238 | | 10W | DEC | 0.9302 | 3h 5min | 5min 55s |-- | 239 | | 100W | DEC | 0.849 | 11h30min |15min 28s | -- | 240 | 241 | **tips:** 242 | 243 | 1. 由于实验场景有标签数据集数量 < 10w,因此10w, 100w数据集上没有accuracy的数值 244 | 245 | 246 | 247 | ## 运行环境 248 | 249 | 250 | 251 | ``` 252 | tensorflow 版本>r1.8 1556 | 1557 | -------------------------------------------------------------------------------- /dec_mining/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | data utils 4 | """ 5 | 6 | import os 7 | import sys 8 | import codecs 9 | import collections 10 | import tensorflow as tf 11 | import numpy as np 12 | import six 13 | 14 | 15 | 16 | def convert_to_unicode(text): 17 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 18 | if six.PY3: 19 | if isinstance(text, str): 20 | return text 21 | elif isinstance(text, bytes): 22 | return text.decode("utf-8", "ignore") 23 | else: 24 | raise ValueError("Unsupported string type: %s" % (type(text))) 25 | elif six.PY2: 26 | if isinstance(text, str): 27 | return text.decode("utf-8", "ignore") 28 | elif isinstance(text, unicode): 29 | return text 30 | else: 31 | raise ValueError("Unsupported string type: %s" % (type(text))) 32 | else: 33 | raise ValueError("Not running on Python2 or Python 3?") 34 | 35 | class Sentence: 36 | def __init__(self, raw_tokens, raw_label): 37 | self.raw_tokens = raw_tokens 38 | self.raw_label = raw_label 39 | self.label_id = None 40 | self.token_ids = [] 41 | 42 | def to_ids(self, word2id, label2id, max_len): 43 | self.label_id = label2id[self.raw_label] 44 | self.raw_tokens = self.raw_tokens[:max_len] # cut off to the max length 45 | all_unk = True 46 | for raw_token in self.raw_tokens: 47 | if raw_token not in ['', '', '', '', '', '', '']: 48 | raw_token = raw_token.lower() 49 | if raw_token in word2id: 50 | self.token_ids.append(word2id[raw_token]) 51 | all_unk = False 52 | else: 53 | self.token_ids.append(word2id[""]) 54 | if all_unk: 55 | tf.logging.info("all unk" + self.raw_tokens) 56 | 57 | self.token_ids = self.token_ids + [0] * (max_len - len(self.token_ids)) 58 | 59 | def gen_ids(sens, word2id, label2id, max_len): 60 | for sen in sens: 61 | sen.to_ids(word2id, 
label2id, max_len) 62 | 63 | # convert dataset to tensor 64 | def make_full_tensors(sens): 65 | tokens = np.zeros((len(sens), len(sens[0].token_ids)), dtype=np.int32) 66 | labels = np.zeros((len(sens)), dtype=np.int32) 67 | length = np.zeros((len(sens)), dtype=np.int32) 68 | for idx, sen in enumerate(sens): 69 | tokens[idx] = sen.token_ids 70 | labels[idx] = sen.label_id 71 | length[idx] = len(sen.raw_tokens) 72 | return tokens, length, labels 73 | 74 | def gen_batchs(full_tensors, batch_size, is_shuffle): 75 | tokens, labels, length = full_tensors 76 | # per = np.array([i for i in range(len(tokens))]) 77 | per = np.array(list(range(len(tokens)))) 78 | if is_shuffle: 79 | np.random.shuffle(per) 80 | 81 | cur_idx = 0 82 | token_batch = [] 83 | label_batch = [] 84 | length_batch = [] 85 | while cur_idx < len(tokens): 86 | token_batch.append(tokens[per[cur_idx]]) 87 | label_batch.append(labels[per[cur_idx]]) 88 | length_batch.append(length[per[cur_idx]]) 89 | if len(token_batch) == batch_size or cur_idx == len(tokens) - 1: 90 | yield token_batch, label_batch, length_batch 91 | token_batch = [] 92 | label_batch = [] 93 | length_batch = [] 94 | cur_idx += 1 95 | 96 | def load_sentences(file_path, skip_invalid): 97 | sens = [] 98 | invalid_num = 0 99 | max_len = 0 100 | for raw_l in codecs.open(file_path, 'r', 'utf-8'): # load as utf-8 encoding. 101 | if raw_l.strip() == "": 102 | continue 103 | file_s = raw_l.rstrip().split('\t') 104 | assert len(file_s) == 2 105 | tokens = file_s[1].split() # discard empty strings 106 | for token in tokens: 107 | assert token != "" 108 | label = file_s[0] 109 | if skip_invalid: 110 | if label.find(',') >= 0 or label.find('NONE') >= 0: 111 | invalid_num += 1 112 | continue 113 | if len(tokens) > max_len: 114 | max_len = len(tokens) 115 | sens.append(Sentence(tokens, label)) 116 | tf.logging.info("invalid sen num : " + str(invalid_num)) 117 | tf.logging.info("valid sen num : " + str(len(sens))) 118 | tf.logging.info("max_len : " + str(max_len)) 119 | return sens 120 | 121 | def load_vocab(sens, vocab_file): 122 | label2id = {} 123 | id2label = {} 124 | for sen in sens: 125 | if sen.raw_label not in label2id: 126 | label2id[sen.raw_label] = len(label2id) 127 | id2label[len(id2label)] = sen.raw_label 128 | 129 | index = 0 130 | word2id = collections.OrderedDict() 131 | id2word = collections.OrderedDict() 132 | for l_raw in codecs.open(vocab_file, 'r', 'utf-8'): 133 | token = convert_to_unicode(l_raw) 134 | # if not token: 135 | # break 136 | token = token.strip() 137 | word2id[token] = index 138 | # id2word[index] = token 139 | index += 1 140 | 141 | for k, value in word2id.items(): 142 | id2word[value] = k 143 | 144 | assert len(word2id) == len(id2word) 145 | tf.logging.info("token num : " + str(len(word2id))) 146 | tf.logging.info("label num : " + str(len(label2id))) 147 | tf.logging.info("labels: " + str(id2label)) 148 | return word2id, id2word, label2id, id2label 149 | 150 | def evaluate(sess, full_tensors, args, model): 151 | total_num = 0 152 | right_num = 0 153 | for batch_data in gen_batchs(full_tensors, args.batch_size, is_shuffle=False): 154 | softmax_re = sess.run(model.softmax_op, 155 | feed_dict={model.ph_dropout_rate: 0, 156 | model.ph_tokens: batch_data[0], 157 | model.ph_labels: batch_data[1], 158 | model.ph_length: batch_data[2]}) 159 | pred_re = np.argmax(softmax_re, axis=1) 160 | total_num += len(pred_re) 161 | right_num += np.sum(pred_re == batch_data[1]) 162 | acc = 1.0 * right_num / (total_num + 1e-5) 163 | 164 | tf.logging.info("dev total 
num: " + str(total_num) + ", right num: " + str(right_num) + ", acc: " + str(acc)) 165 | return acc 166 | 167 | def load_spec_centers(path): 168 | raw_f = open(path, "r", encoding="utf-8") 169 | f_lines = raw_f.readlines() 170 | 171 | res = [] 172 | for line in f_lines: 173 | vec = [float(i) for i in line.strip().split(" ")] 174 | res.append(vec) 175 | return tf.convert_to_tensor(res), len(res) 176 | 177 | def write_file(out_path, out_str): 178 | exists = os.path.isfile(out_path) 179 | if exists: 180 | os.remove(out_path) 181 | tf.logging.info("File Removed!") 182 | 183 | raw_f = open(out_path, "w", encoding="utf-8") 184 | raw_f.write(out_str) 185 | raw_f.close() 186 | 187 | def load_vocab_file(vocab_file): 188 | word2id = {} 189 | id2word = {} 190 | for raw_l in codecs.open(vocab_file, 'r', 'utf8'): 191 | raw_l = raw_l.strip() 192 | assert raw_l != "" 193 | assert raw_l not in word2id 194 | word2id[raw_l] = len(word2id) 195 | id2word[len(id2word)] = raw_l 196 | tf.logging.info("uniq token num : " + str(len(word2id)) + "\n") 197 | return word2id, id2word 198 | -------------------------------------------------------------------------------- /dec_mining/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | dataset class 4 | """ 5 | import random 6 | import math 7 | import numpy as np 8 | import data_utils 9 | 10 | 11 | class Dataset(): 12 | def __init__(self, train_x=None, train_y=None, test_x=None, 13 | test_y=None, train_length_x=None, test_length_x=None): 14 | self.train_x = train_x 15 | self.train_length_x = train_length_x 16 | self.test_length_x = test_length_x 17 | self.train_y = train_y 18 | self.test_x = test_x 19 | self.test_y = test_y 20 | 21 | def gen_next_batch(self, batch_size, is_train_set, epoch=None, iteration=None): 22 | if is_train_set is True: 23 | raw_x = self.train_x 24 | x_length = self.train_length_x 25 | raw_y = self.train_y 26 | else: 27 | raw_x = self.test_x 28 | x_length = self.test_length_x 29 | raw_y = self.test_y 30 | 31 | assert len(raw_x) >= batch_size,\ 32 | "batch size must be smaller than data size {}.".format(len(raw_x)) 33 | 34 | if epoch is not None: 35 | until = math.ceil(float(epoch * len(raw_x)) / float(batch_size)) 36 | elif iteration is not None: 37 | until = iteration 38 | else: 39 | assert False, "epoch or iteration must be set." 
40 | 41 | iter_ = 0 42 | index_list = list(range(len(raw_x))) 43 | while iter_ <= until: 44 | idxs = random.sample(index_list, batch_size) 45 | iter_ += 1 46 | yield (raw_x[idxs], raw_y[idxs], idxs, x_length[idxs]) 47 | 48 | 49 | class ExpDataset(Dataset): 50 | def __init__(self, args): 51 | super().__init__() 52 | 53 | train_file = args.train_file 54 | vocab_file = args.vocab_file 55 | 56 | train_sens = data_utils.load_sentences(train_file, skip_invalid=True) 57 | word2id, id2word, label2id, id2label = data_utils.load_vocab(train_sens, vocab_file) 58 | 59 | data_utils.gen_ids(train_sens, word2id, label2id, 100) 60 | train_full_tensors = data_utils.make_full_tensors(train_sens) 61 | 62 | raw_x = train_full_tensors[0] 63 | x_length = train_full_tensors[1] 64 | x_labels = train_full_tensors[2] 65 | 66 | raw_f = lambda t: id2label[t] 67 | x_labels_true = np.array(list(map(raw_f, x_labels))) 68 | 69 | n_train = int(len(raw_x) * 1) 70 | self.train_x, self.test_x = raw_x[:n_train], raw_x[n_train:] 71 | self.train_length_x, self.test_length_x = x_length[:n_train], x_length[n_train:] 72 | self.train_y, self.test_y = x_labels[:n_train], x_labels[n_train:] 73 | self.gt_label = x_labels_true 74 | self.raw_q = ["".join(i.raw_tokens) for i in train_sens] 75 | -------------------------------------------------------------------------------- /dec_mining/dec_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | modified DEC model 4 | """ 5 | import collections 6 | import re 7 | import sys 8 | import tensorflow as tf 9 | import numpy as np 10 | from sklearn.cluster import KMeans 11 | from sklearn.utils.linear_assignment_ import linear_assignment 12 | sys.path.append("..") 13 | from sptm import models 14 | import data_utils 15 | 16 | 17 | """ modified DEC model 18 | References 19 | ---------- 20 | https://github.com/HaebinShin/dec-tensorflow/blob/master/dec/model.py 21 | """ 22 | class DEC: 23 | def __init__(self, params, other_params): 24 | # pass n_clusters or external_cluster_center 25 | if params.external_cluster_center == "": 26 | self.n_cluster = params.n_clusters 27 | self.kmeans = KMeans(n_clusters=params.n_clusters, n_init=10) 28 | else: 29 | # get cluster_center_embedding, n_cluster from external file 30 | self.external_cluster_center_vec, self.n_cluster = \ 31 | data_utils.load_spec_centers(params.external_cluster_center) 32 | 33 | self.embedding_dim = params.embedding_dim 34 | # load SPTM pretrained model 35 | model = models.create_finetune_classification_model(params, other_params) 36 | self.pretrained_model = model 37 | self.alpha = params.alpha 38 | 39 | # mu: cluster center 40 | self.mu = tf.Variable(tf.zeros(shape=(self.n_cluster, self.embedding_dim)), 41 | name="mu") # [n_class, emb_dim] 42 | 43 | self.z = model.max_pool_output # [batch, emb_dim] 44 | with tf.name_scope("distribution"): 45 | self.q = self._soft_assignment(self.z, self.mu) # [, n_class] 46 | self.p = tf.placeholder(tf.float32, shape=(None, self.n_cluster)) # [, n_class] 47 | self.pred = tf.argmax(self.q, axis=1) 48 | self.pred_prob = tf.reduce_max(self.q, axis=1) 49 | 50 | with tf.name_scope("dec-train"): 51 | self.loss = self._kl_divergence(self.p, self.q) 52 | self.global_step_op = tf.train.get_or_create_global_step() 53 | self.lr = params.learning_rate 54 | warmup_steps = params.warmup_steps 55 | warmup_lr = (self.lr * tf.cast(self.global_step_op, tf.float32) 56 | / tf.cast(warmup_steps, tf.float32)) 57 | self.warmup_learning_rate_op = \ 58 | 
tf.cond(self.global_step_op < warmup_steps, lambda: warmup_lr, lambda: self.lr) 59 | self.optimizer = tf.train.AdamOptimizer(self.warmup_learning_rate_op) 60 | self.trainer = self.optimizer.minimize(self.loss, global_step=self.global_step_op) 61 | 62 | def get_assign_cluster_centers_op(self, features): 63 | # init mu 64 | tf.logging.info("Kmeans train start.") 65 | kmeans = self.kmeans.fit(features) 66 | tf.logging.info("Kmeans train end.") 67 | return tf.assign(self.mu, kmeans.cluster_centers_) 68 | 69 | # emb [batch, emb_dim] centroid [n_class, emb_dim] 70 | def _soft_assignment(self, embeddings, cluster_centers): 71 | """Implemented a soft assignment as the probability of assigning sample i to cluster j. 72 | 73 | Args: 74 | embeddings: (num_points, dim) 75 | cluster_centers: (num_cluster, dim) 76 | 77 | Return: 78 | q_i_j: (num_points, num_cluster) 79 | """ 80 | def _pairwise_euclidean_distance(a, b): 81 | # p1 [batch, n_class] 82 | p1 = tf.matmul( 83 | tf.expand_dims(tf.reduce_sum(tf.square(a), 1), 1), 84 | tf.ones(shape=(1, self.n_cluster)) 85 | ) 86 | # p2 [batch, n_class] 87 | p2 = tf.transpose(tf.matmul( 88 | tf.reshape(tf.reduce_sum(tf.square(b), 1), shape=[-1, 1]), 89 | tf.ones(shape=(tf.shape(a)[0], 1)), 90 | transpose_b=True 91 | )) 92 | # [batch, n_class] 93 | res = tf.sqrt( 94 | tf.abs(tf.add(p1, p2) - 2 * tf.matmul(a, b, transpose_b=True))) 95 | 96 | return res 97 | 98 | dist = _pairwise_euclidean_distance(embeddings, cluster_centers) 99 | q = 1.0 / (1.0 + dist ** 2 / self.alpha) ** ((self.alpha + 1.0) / 2.0) 100 | q = (q / tf.reduce_sum(q, axis=1, keepdims=True)) 101 | return q 102 | 103 | def target_distribution(self, q): 104 | p = q ** 2 / q.sum(axis=0) 105 | p = p / p.sum(axis=1, keepdims=True) 106 | return p 107 | 108 | def _kl_divergence(self, target, pred): 109 | return tf.reduce_mean(tf.reduce_sum(target * tf.log(target / (pred)), axis=1)) 110 | 111 | def cluster_acc(self, y_true, y_pred): 112 | """ 113 | Calculate clustering accuracy. 
Require scikit-learn installed 114 | # Arguments 115 | y: true labels, numpy.array with shape `(n_samples,)` 116 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 117 | # Return 118 | accuracy, in [0,1] 119 | """ 120 | y_true = y_true.astype(np.int64) 121 | assert y_pred.size == y_true.size 122 | D = max(y_pred.max(), y_true.max()) + 1 123 | w = np.zeros((D, D), dtype=np.int64) 124 | for i in range(y_pred.size): 125 | w[y_pred[i], y_true[i]] += 1 126 | ind = linear_assignment(w.max() - w) 127 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size 128 | -------------------------------------------------------------------------------- /dec_mining/images/modify_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/dec_mining/images/modify_1.gif -------------------------------------------------------------------------------- /dec_mining/images/modify_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/dec_mining/images/modify_2.gif -------------------------------------------------------------------------------- /dec_mining/images/冷启动流程图.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/dec_mining/images/冷启动流程图.png -------------------------------------------------------------------------------- /dec_mining/images/轮廓系数公式.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/dec_mining/images/轮廓系数公式.png -------------------------------------------------------------------------------- /dec_mining/images/迭代挖掘流程图.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/dec_mining/images/迭代挖掘流程图.png -------------------------------------------------------------------------------- /dec_mining/inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | inference cluster labels 4 | """ 5 | import argparse 6 | from datetime import datetime 7 | import numpy as np 8 | import tensorflow as tf 9 | from sklearn.metrics import silhouette_score, silhouette_samples 10 | import dataset 11 | import dec_model 12 | import data_utils 13 | 14 | 15 | def write_results(z, y_true, raw_q, out_name, prob): 16 | assert len(z) == len(raw_q) 17 | out_str = "" 18 | label_map = {} # sort samples order by y_pred 19 | for (y_pred, gt_label, q, pro) in zip(z, y_true, raw_q, prob): 20 | prob = -np.sort(-prob) 21 | if y_pred in label_map: 22 | label_map[y_pred].append("__label__" + str(y_pred) + "\t" + q + 23 | ": ground truth label is " + str(gt_label) + str(pro)) 24 | else: 25 | label_map[y_pred] = [] 26 | label_map[y_pred].append("__label__" + str(y_pred) + "\t" + q + 27 | ": ground truth label is" + str(gt_label) + str(pro)) 28 | 29 | for _, lines in label_map.items(): 30 | for line in lines: 31 | out_str += line + "\n" 32 | data_utils.write_file(out_name, out_str) 33 | 34 | def print_metrics(x, labels): 35 | sil_avg = silhouette_score(x, labels) # avg silhouette score 36 | sils = silhouette_samples(x, labels) # silhouette score of each 
sample 37 | tf.logging.info("avg silhouette:" + str(sil_avg)) 38 | 39 | def inference(data, model, params): 40 | config = tf.ConfigProto() 41 | config.gpu_options.allow_growth = True 42 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=None) 43 | batch_size = params.batch_size 44 | 45 | with tf.Session(config=config) as sess: 46 | sess.run(tf.global_variables_initializer()) 47 | saver.restore(sess, params.model_path) 48 | 49 | train_size = len(data.train_x) 50 | step_by_batch = train_size // batch_size + 1 51 | tf.logging.info("step by batch " + str(step_by_batch)) 52 | z_total = [] # z: transformed representation 53 | prob_total = [] # predict cluster probability 54 | pred_total = [] # predict cluster label 55 | 56 | for idx in range(step_by_batch): 57 | if idx == step_by_batch - 1: 58 | tf.logging.info("start/ end idx " + str(idx * batch_size) + " " + str(idx * batch_size + batch_size)) 59 | cur_pred, cur_prob, cur_z = sess.run( 60 | [model.pred, model.pred_prob, model.z], feed_dict={ 61 | model.pretrained_model.ph_tokens: data.train_x[idx * batch_size:], 62 | model.pretrained_model.ph_length: data.train_length_x[idx * batch_size:], 63 | model.pretrained_model.ph_dropout_rate: 0 64 | }) 65 | else: 66 | cur_pred, cur_prob, cur_z = sess.run( 67 | [model.pred, model.pred_prob, model.z], feed_dict={ 68 | model.pretrained_model.ph_tokens: 69 | data.train_x[idx * batch_size: idx * batch_size + batch_size], 70 | model.pretrained_model.ph_length: 71 | data.train_length_x[idx * batch_size: idx * batch_size + batch_size], 72 | model.pretrained_model.ph_dropout_rate: 0 73 | }) 74 | 75 | now = datetime.now() 76 | # tf.logging.info("sess run index " + str(idx) + " " + str(len(cur_pred)) + now.strftime("%H:%M:%S")) 77 | pred_total.extend(cur_pred) 78 | prob_total.extend(cur_prob) 79 | z_total.extend(cur_z) 80 | tf.logging.info("pred total " + str(len(pred_total)) + " , sample total " + str(len(data.train_x))) 81 | assert len(pred_total) == len(data.train_x) 82 | clust_label = np.array(pred_total) 83 | prob = np.array(prob_total) 84 | print_metrics(z_total, clust_label) 85 | # write inference result file 86 | write_results(clust_label, data.gt_label, data.raw_q, params.pred_score_path, prob) 87 | return clust_label 88 | 89 | if __name__=="__main__": 90 | tf.logging.set_verbosity(tf.logging.INFO) 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("--model_path", type=str, default="") 93 | parser.add_argument("--init_checkpoint", type=str, default="") 94 | parser.add_argument("--train_file", type=str, default="") 95 | parser.add_argument("--batch_size", type=int, default=32) 96 | parser.add_argument("--lstm_dim", type=int, default=500) 97 | parser.add_argument("--embedding_dim", type=int, default=1000) 98 | parser.add_argument("--vocab_file", type=str, default="./vocab") 99 | parser.add_argument("--external_cluster_center", type=str, default="") 100 | parser.add_argument("--n_clusters", type=int, default=20) 101 | parser.add_argument("--alpha", type=int, default=1) 102 | parser.add_argument("--layer_num", type=int, default=1) 103 | parser.add_argument("--token_num", type=int, default=7820) 104 | parser.add_argument("--learning_rate", type=float, default=0.01) 105 | parser.add_argument("--warmup_steps", type=int, default=1000) 106 | parser.add_argument("--epochs", type=int, default=5) 107 | parser.add_argument("--pred_score_path", type=str, default='') 108 | args = parser.parse_args() 109 | 110 | word2id, id2word = data_utils.load_vocab_file(args.vocab_file) 111 | 
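    # Vocabulary size and training-set size are forwarded to the SPTM graph builder via
    # other_arg_dict below; --model_path should point at a finetune.ckpt-* checkpoint
    # saved by train.py, and --external_cluster_center / --n_clusters must match the
    # values used during that training run.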
TRAINSET_SIZE = len(data_utils.load_sentences(args.train_file, skip_invalid=True)) 112 | other_arg_dict = {} 113 | other_arg_dict['token_num'] = len(word2id) 114 | other_arg_dict['trainset_size'] = TRAINSET_SIZE 115 | 116 | exp_data = dataset.ExpDataset(args) 117 | dec_model = dec_model.DEC(args, other_arg_dict) 118 | inference(exp_data, dec_model, args) 119 | -------------------------------------------------------------------------------- /dec_mining/print_sen_embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | print sentence embedding 4 | """ 5 | import os 6 | import sys 7 | import argparse 8 | import codecs 9 | import tensorflow as tf 10 | import numpy as np 11 | import data_utils 12 | 13 | # get graph output in different way: max mean concat 14 | def get_output(g, embedding_way): 15 | if embedding_way == "concat": # here may have problem, this is just for 4 layers of biLM ! 16 | t = g.get_tensor_by_name("concat_4:0") 17 | elif embedding_way == "max": 18 | t = g.get_tensor_by_name("Max:0") 19 | elif embedding_way == 'mean': 20 | t = g.get_tensor_by_name("Mean:0") 21 | else: 22 | assert False 23 | return {"sen_embedding": t} 24 | 25 | # get graph input 26 | def get_input(g): 27 | return {"tokens": g.get_tensor_by_name("ph_tokens:0"), 28 | "length": g.get_tensor_by_name("ph_length:0"), 29 | "dropout_rate": g.get_tensor_by_name("ph_dropout_rate:0")} 30 | 31 | def gen_test_data(input_file, word2id, max_seq_len): 32 | sens = [] 33 | center_size = [] 34 | for line in codecs.open(input_file, 'r', 'utf-8'): 35 | # separated by slash 36 | ls = line.strip().split("/") 37 | center_size.append(len(ls)) 38 | for l in ls: 39 | l = l.replace(" ", "") 40 | l = l.replace("", " ") 41 | fs = l.rstrip().split() 42 | if len(fs) > max_seq_len: 43 | continue 44 | sen = [] 45 | for f in fs: 46 | if f in word2id: 47 | sen.append(word2id[f]) 48 | else: 49 | sen.append(word2id['']) 50 | sens.append(sen) 51 | return sens, center_size 52 | 53 | if __name__=="__main__": 54 | 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--input_file", type=str, default="") 57 | parser.add_argument("--vocab_file", type=str, default="") 58 | parser.add_argument("--model_path", type=str, default="") 59 | parser.add_argument("--batch_size", type=int, default=256) 60 | parser.add_argument("--max_seq_len", type=int, default=100) 61 | parser.add_argument("--output_file", type=str, default="") 62 | # sentence representation output way : max mean concat 63 | parser.add_argument("--embedding_way", type=str, default="concat") 64 | args = parser.parse_args() 65 | 66 | word2id, id2word = data_utils.load_vocab_file(args.vocab_file) 67 | sys.stderr.write("vocab num : " + str(len(word2id)) + "\n") 68 | sens, center_size = gen_test_data(args.input_file, word2id, args.max_seq_len) 69 | sys.stderr.write("sens num : " + str(len(sens)) + "\n") 70 | tf.logging.info("embedding_way : ", args.embedding_way) 71 | 72 | # limit cpu resource 73 | cpu_num = int(os.environ.get('CPU_NUM', 15)) 74 | config = tf.ConfigProto(device_count={"CPU": cpu_num}, 75 | inter_op_parallelism_threads = cpu_num, 76 | intra_op_parallelism_threads = cpu_num, 77 | log_device_placement=True) 78 | config.gpu_options.allow_growth = True 79 | with tf.Session(config=config) as sess: 80 | saver = tf.train.import_meta_graph("{}.meta".format(args.model_path)) 81 | saver.restore(sess, args.model_path) 82 | 83 | graph = tf.get_default_graph() 84 | input_dict = get_input(graph) 85 | 
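        # get_output() resolves the pooled sentence-embedding tensor by graph name
        # ("Max:0", "Mean:0" or "concat_4:0"), so --embedding_way must match how the
        # restored SPTM checkpoint pools its output (the README demo uses max).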
output_dict = get_output(graph, args.embedding_way) 86 | 87 | caches = [] 88 | idx = 0 89 | while idx < len(sens): 90 | batch_sens = sens[idx:idx + args.batch_size] 91 | batch_tokens = [] 92 | batch_length = [] 93 | for sen in batch_sens: 94 | batch_tokens.append(sen) 95 | batch_length.append(len(sen)) 96 | 97 | real_max_len = max([len(b) for b in batch_tokens]) 98 | for b in batch_tokens: 99 | b.extend([0] * (real_max_len - len(b))) 100 | 101 | re = sess.run(output_dict['sen_embedding'], 102 | feed_dict={input_dict['tokens']: batch_tokens, 103 | input_dict['length']: batch_length, 104 | input_dict["dropout_rate"]: 0.0}) 105 | if len(caches) % 200 == 0: 106 | tf.logging.info(len(caches)) 107 | caches.append(re) 108 | idx += len(batch_sens) 109 | 110 | sen_embeddings = np.concatenate(caches, 0) 111 | # calculate average embedding 112 | avg_centers = [] 113 | 114 | idx = 0 115 | for size in center_size: 116 | avg_center_emb = np.average(sen_embeddings[idx: idx + size], axis=0) 117 | avg_centers.append(avg_center_emb) 118 | idx = idx + size 119 | 120 | np.savetxt(args.output_file, avg_centers, fmt='%.3e') 121 | -------------------------------------------------------------------------------- /dec_mining/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | finetune on pretrained model with dataset to be clustered 4 | """ 5 | 6 | import argparse 7 | import tensorflow as tf 8 | import numpy as np 9 | import dataset 10 | import dec_model 11 | import data_utils 12 | 13 | 14 | def train(data, model, args): 15 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=2) 16 | best_acc = 0 17 | config = tf.ConfigProto(allow_soft_placement=True) 18 | config.gpu_options.allow_growth = True 19 | with tf.Session(config=config) as sess: 20 | sess.run(tf.global_variables_initializer()) 21 | 22 | train_size = len(data.train_x) 23 | batch_size = args.batch_size 24 | 25 | steps_in_epoch = train_size // args.batch_size + 1 26 | tf.logging.info("step by batch " + str(steps_in_epoch)) 27 | z_total = [] 28 | 29 | # z: transformed representation of x 30 | for idx in range(steps_in_epoch): 31 | if idx == steps_in_epoch - 1: # last batch 32 | tf.logging.info("start/ end idx " + str(idx * batch_size) + " " + str(idx * batch_size + batch_size)) 33 | cur_z = sess.run(model.z, feed_dict={ 34 | model.pretrained_model.ph_tokens: data.train_x[idx * batch_size:], 35 | model.pretrained_model.ph_length: data.train_length_x[idx * batch_size:], 36 | model.pretrained_model.ph_dropout_rate: 0 37 | }) 38 | else: 39 | cur_z = sess.run(model.z, feed_dict={ 40 | model.pretrained_model.ph_tokens: data.train_x[ 41 | idx * batch_size: idx * batch_size + batch_size], 42 | model.pretrained_model.ph_length: data.train_length_x[ 43 | idx * batch_size: idx * batch_size + batch_size], 44 | model.pretrained_model.ph_dropout_rate: 0 45 | }) 46 | z_total.extend(cur_z) 47 | 48 | tf.logging.info("z total size : " + str(len(z_total))) # sample size 49 | assert len(z_total) == len(data.train_x) 50 | z = np.array(z_total) 51 | 52 | # Customize the cluster center 53 | if args.external_cluster_center != "": 54 | # load external centers file 55 | external_center = model.external_cluster_center_vec 56 | assign_mu_op = tf.assign(model.mu, external_center) 57 | else: 58 | # kmeans init centers 59 | assign_mu_op = model.get_assign_cluster_centers_op(z) 60 | 61 | amu = sess.run(assign_mu_op) # get cluster center 62 | 63 | for cur_epoch in range(args.epochs): 64 | q_list 
= [] 65 | 66 | for idx in range(steps_in_epoch): 67 | start_idx = idx * batch_size 68 | end_idx = idx * batch_size + batch_size 69 | if idx == steps_in_epoch - 1: 70 | q_batch = sess.run( 71 | model.q, feed_dict={ 72 | model.pretrained_model.ph_tokens: data.train_x[start_idx:], 73 | model.pretrained_model.ph_length: data.train_length_x[start_idx:], 74 | model.pretrained_model.ph_dropout_rate: 0 75 | }) 76 | else: 77 | q_batch = sess.run( 78 | model.q, feed_dict={ 79 | model.pretrained_model.ph_tokens: data.train_x[start_idx: end_idx], 80 | model.pretrained_model.ph_length: data.train_length_x[start_idx: end_idx], 81 | model.pretrained_model.ph_dropout_rate: 0 82 | }) 83 | 84 | q_list.extend(q_batch) 85 | 86 | q = np.array(q_list) 87 | p = model.target_distribution(q) 88 | 89 | for iter_, (batch_x, batch_y, batch_idxs, batch_x_lengths) in enumerate( 90 | data.gen_next_batch(batch_size=batch_size, \ 91 | is_train_set=True, epoch=1)): 92 | batch_p = p[batch_idxs] 93 | _, loss, pred, global_step, lr = sess.run([model.trainer, model.loss, model.pred, model.global_step_op, model.optimizer._lr], \ 94 | feed_dict={model.pretrained_model.ph_tokens: batch_x, \ 95 | model.pretrained_model.ph_length: batch_x_lengths, \ 96 | model.p: batch_p, \ 97 | model.pretrained_model.ph_dropout_rate: 0 98 | }) 99 | # NOTE: acc 只用于有监督数据查看聚类效果,ground truth label不会参与到train,如果是无监督数据,此acc 无用 100 | acc = model.cluster_acc(batch_y, pred) 101 | tf.logging.info("[DEC] epoch: {}\tloss: {}\tacc: {}\t lr {} \t global_step {} ".format(cur_epoch, loss, acc, lr, global_step)) 102 | if acc > best_acc: 103 | best_acc = acc 104 | tf.logging.info("!!!!!!!!!!!!! best acc got " + str(best_acc)) 105 | # save model each epoch 106 | saver.save(sess, args.model_save_dir + '/finetune.ckpt', global_step=global_step) 107 | 108 | if __name__ == "__main__": 109 | tf.logging.set_verbosity(tf.logging.INFO) 110 | # pass params 111 | parser = argparse.ArgumentParser() 112 | # sptm pretrain model path 113 | parser.add_argument("--init_checkpoint", type=str, default='') 114 | parser.add_argument("--train_file", type=str, default="") 115 | parser.add_argument("--batch_size", type=int, default=32) 116 | # customized cluster centers file path, pass either of params 'external_cluster_center' or 'n_clusters' 117 | parser.add_argument("--external_cluster_center", type=str, default="") 118 | # number of clusters (init with kmeans) 119 | parser.add_argument("--n_clusters", type=int, default=20) 120 | parser.add_argument("--epochs", type=int, default=50) 121 | parser.add_argument("--warmup_steps", type=int, default=1000) 122 | parser.add_argument("--learning_rate", type=float, default=0.01) 123 | # DEC model q distribution param, alpha=1 in paper 124 | parser.add_argument("--alpha", type=int, default=1) 125 | parser.add_argument("--layer_num", type=int, default=1) 126 | parser.add_argument("--token_num", type=int, default=7820) 127 | parser.add_argument("--lstm_dim", type=int, default=500) 128 | parser.add_argument("--embedding_dim", type=int, default=1000) 129 | parser.add_argument("--vocab_file", type=str, default="./vocab") 130 | parser.add_argument("--model_save_dir", type=str, default="./saved_model") 131 | args = parser.parse_args() 132 | 133 | word2id, id2word = data_utils.load_vocab_file(args.vocab_file) 134 | trainset_size = len(data_utils.load_sentences(args.train_file, skip_invalid=True)) 135 | other_arg_dict = {} 136 | other_arg_dict['token_num'] = len(word2id) 137 | other_arg_dict['trainset_size'] = trainset_size 138 | 139 | exp_data = 
dataset.ExpDataset(args) 140 | dec_model = dec_model.DEC(args, other_arg_dict) 141 | train(exp_data, dec_model, args) 142 | -------------------------------------------------------------------------------- /docs/RUNDEMO.md: -------------------------------------------------------------------------------- 1 | # 运行说明 2 | 本示例给出了支持一层和两层结构知识库问答运行demo、评测指标及在测试集的效果。 3 | 4 | ## 数据介绍 5 | [data_demo](../data_demo)所给出的预训练集(pre_train_data),训练集(train_data),验证集(valid_data),预测集(test_data) 取自58智能问答生产环境下的真实数据,这里仅为了跑通模型,因此只取了少部分数据,其中预训练集9W+(真实场景下数量较大),训练集9W+,验证集和测试集均3000+,具体数据格式可见[README](../README.md)中的数据介绍部分。 6 | 7 | ## 基于一层结构知识库的自动问答运行示例 8 | 9 | ### 基于DSSM模型的自动问答 10 | 11 | 使用DSSM意图匹配模型时,对于一层结构知识库只需要先训练意图匹配模型,然后用训练好的模型对测试集进行预测,最后对意图匹配的结果按照打分阈值高低给出回答类别,当意图匹配打分高于某个阈值时给出唯一回答,当打分低于某个阈值时给出拒识回答,当打分处于这两个阈值之间时给出列表回答。 12 | 13 | #### 1.训练DSSM意图匹配模型 14 | 15 | ```bash 16 | mkdir model && python run_dssm.py --train_path=./data_demo/train_data --valid_path=./data_demo/valid_data --map_file_path=./data_demo/std_data --model_path=./model/model_min/ --softmax_r=45 --embedding_size=256 --learning_rate=0.001 --keep_prob=0.8 --batch_size=250 --num_epoches=30 --negative_size=200 --eval_every=10 --num_units=256 --use_same_cell=True --label2id_path=./model/model_min/min_label2id --vocab2id_path=./model/model_min/min_vocab2id 17 | ``` 18 | 19 | 参数说明: 20 | 21 | train_path: 训练集 22 | 23 | valid_path: 验证集 24 | 25 | map_file_path: 领域意图映射文件 26 | 27 | model_path: 模型存储路径 28 | 29 | softmax_r: 余弦相似度滑动参数 30 | 31 | embedding_size: embedding层向量大小 32 | 33 | learning_rate: 学习率 34 | 35 | keep_prob: dropout过程中keep神经元的概率 36 | 37 | batch_size: batch 大小 38 | 39 | num_epoches: epcho个数 40 | 41 | negative_size: 负样本数量 42 | 43 | eval_every: 每隔多少steps在验证集上检验训练过程中的模型效果 44 | 45 | num_units: lstm cell 的单元个数 46 | 47 | use_same_cell: 前向后向lstm是否需要用相同的cell(共享一套参数) 48 | 49 | label2id_path: <意图,id>映射文件 50 | 51 | vocab2id_path: 根据训练数据生成的字典映射文件 52 | 53 | #### 2.用意图匹配模型对测试集进行预测 54 | 55 | ```bash 56 | python dssm_predict.py --map_file_path=./data_demo/std_data --model_path=./model/model_min/ --export_model_dir=./model/model_min/dssm_tf_serving/ --test_data_path=./data_demo/test_data --test_result_path=./model/model_min/result_min_test --softmax_r=45 --batch_size=250 --label2id_file=./model/model_min/min_label2id --vocab2id_file=./model/model_min/min_vocab2id 57 | ``` 58 | 59 | #### 3.意图匹配的结果按照打分阈值高低给出回答类别 60 | 61 | ```bash 62 | python merge_classifier_match_label.py none ./model/model_min/result_min_test ./data_demo/merge_result_1_level none 63 | ``` 64 | 65 | ### 基于轻量级预训练语言模型(SPTM)的自动问答 66 | 67 | 使用SPTM进行意图匹配时,对于一层结构知识库需要先预训练语言模型,然后基于预训练语言模型与训练集微调意图匹配模型;最终在用训练好的模型对测试集进行预测后,对意图匹配的结果按照打分阈值高低给出回答类别。其阈值判别的方式与无预训练场景相同。 68 | 69 | #### 1.预训练语言模型 70 | 71 | ##### 基于Bi-LSTM block的预训练 72 | ```bash 73 | cd sptm && mkdir -p model/pretrain && python run_pretraining.py --train_file="../data_demo/pre_train_data" --vocab_file="../data_demo/vocab" --model_save_dir="./model/pretrain" --batch_size=256 --print_step=100 --weight_decay=0 --embedding_dim=1000 --lstm_dim=500 --layer_num=1 --train_step=100000 --warmup_step=1000 --learning_rate=5e-5 --dropout_rate=0.1 --max_predictions_per_seq=10 --clip_norm=1.0 --max_seq_len=100 --use_queue=0 --representation_type=lstm 74 | ``` 75 | 参数说明: 76 | 77 | vocab:词典文件(需要包含 ``) 78 | 79 | train_file/valid_data:训练/验证集 80 | 81 | lstm_dim:lstm的门控单元数 82 | 83 | embedding_dim:词嵌入维度 84 | 85 | dropout_rate:节点被dropout的比例 86 | 87 | layer_num:LSTM的层数 88 | 89 | weight_decay:adam的衰减系数 90 | 91 | max_predictions_per_seq:每个句子中,最多会mask的词数 92 | 93 | clip_norm:梯度裁剪阈值 94 | 
95 | use_queue:是否使用队列生成预训练数据 96 | 97 | representation_type:使用何种结构训练模型,可选择lstm或transformer 98 | 99 | ##### 基于参数共享Transformer block的预训练 100 | ```bash 101 | cd sptm && mkdir -p model/pretrain && python run_pretraining.py --train_file="../data_demo/pre_train_data" --vocab_file="../data_demo/vocab" --model_save_dir="./model/pretrain" --batch_size=64 --print_step=100 --embedding_dim=1000 --train_step=100000 --warmup_step=5000 --learning_rate=1e-5 --max_predictions_per_seq=10 --clip_norm=1.0 --max_len=100 --use_queue=0 --representation_type=transformer --initializer_range=0.02 --max_position_embeddings=140 --hidden_size=768 --num_hidden_layers=12 --num_attention_heads=12 --intermediate_size=1024 102 | ``` 103 | 参数说明: 104 | 105 | learning_rate:学习率 106 | 107 | initializer_range:模型参数正态分布初始化的标准差 108 | 109 | max_position_embeddings: position的最大位置 110 | 111 | hidden_size:隐层大小 112 | 113 | num_hidden_layers:隐层数 114 | 115 | num_attention_heads:多头数 116 | 117 | intermediate_size:ffn层大小 118 | 119 | representation_type:使用何种结构训练模型,可选择lstm或transformer 120 | 121 | 122 | #### 2.微调意图匹配模型 123 | 124 | 注意此处的```init_checkpoint```需要根据预训练的结果进行选取,如没有预训练模型,也可以不填写: 125 | 126 | ##### 基于Bi-LSTM block的微调 127 | ```bash 128 | cd sptm && python run_classifier.py --output_id2label_file="model/id2label.has_init" --vocab_file="../data_demo/vocab" --train_file="../data_demo/train_data" --dev_file="../data_demo/valid_data" --model_save_dir="model/finetune" --lstm_dim=500 --embedding_dim=1000 --opt_type=adam --batch_size=256 --epoch=20 --learning_rate=1e-4 --seed=1 --max_len=100 --print_step=10 --dropout_rate=0.1 --layer_num=1 --init_checkpoint="model/pretrain/lm_pretrain.ckpt-1400" --representation_type=lstm 129 | ``` 130 | 参数说明: 131 | 132 | output_id2label_file:(id,标签)映射文件,最后预测的时侯使用 133 | 134 | opt_type:优化器类型,有sgd/adagrad/adam几种可选 135 | 136 | seed:随机种子的值,使用相同的随机种子保证微调模型结果一致 137 | 138 | init_checkpoint:预训练模型保存的checkpoint 139 | 140 | ##### 基于参数共享Transformer block的微调 141 | ```bash 142 | cd sptm && python run_classifier.py --output_id2label_file="model/id2label.has_init" --vocab_file="../data_demo/vocab" --train_file="../data_demo/train_data" --dev_file="../data_demo/valid_data" --model_save_dir="model/finetune" --embedding_dim=1000 --opt_type=adam --batch_size=64 --epoch=10 --learning_rate=1e-4 --seed=1 --max_len=100 --print_step=100 --dropout_rate=0.1 --use_queue=0 --representation_type=transformer --initializer_range=0.02 --max_position_embeddings=140 --hidden_size=768 --num_hidden_layers=12 --num_attention_heads=12 --intermediate_size=1024 --init_checkpoint="model/pretrain/lm_pretrain.ckpt-100800" --representation_type=transformer 143 | ``` 144 | 参数说明: 145 | 146 | opt_type:优化器类型 147 | 148 | representation_type:使用何种结构训练模型,可选择lstm或transformer 149 | 150 | 151 | #### 3.用意图匹配模型对测试集进行预测 152 | 153 | ```bash 154 | cd sptm && python run_prediction.py --input_file="../data_demo/test_data" --vocab_file="../data_demo/vocab" --id2label_file="model/id2label.has_init" --model_dir="model/finetune" > "../data_demo/result_pretrain_raw" 155 | ``` 156 | 157 | #### 4.预测结果格式化,按照打分阈值高低给出回答类别 158 | 159 | ```bash 160 | python sptm/format_result.py ./data_demo/test_data ./data_demo/result_pretrain_raw ./data_demo/result_pretrain_test 161 | ``` 162 | 参数说明: 163 | 164 | argv[1]: 测试集 165 | 166 | argv[2]: SPTM打分文件 167 | 168 | argv[3]: 格式化SPTM打分文件 169 | 170 | ``` 171 | python merge_classifier_match_label.py none ./data_demo/result_pretrain_test ./data_demo/merge_result_pretrain none 172 | ``` 173 | 174 | ## 基于两层结构知识库的自动问答运行示例 175 | 
对于两层结构知识库需要先训练领域分类模型和意图匹配模型,然后用训练好的模型对测试集进行预测,最后对领域分类和意图匹配的结果进行融合,给出回答类别,具体融合策略参考README中两层结构知识库问答融合示意图。 176 | 177 | ### 1.训练LSTM领域分类模型 178 | 179 | ```bash 180 | mkdir model && python run_bi_lstm.py --train_path=./data_demo/train_data --valid_path=./data_demo/valid_data --map_file_path=./data_demo/std_data --model_path=./model/model_max --vocab_file=./model/model_max/vocab_max --label_file=./model/model_max/label_max --embedding_size=256 --num_units=256 --batch_size=200 --seq_length=40 --num_epcho=30 --check_every=20 --lstm_layers=2 --lr=0.01 --dropout_keep_prob=0.8 181 | ``` 182 | 参数说明: 183 | 184 | train_path: 训练集 185 | 186 | valid_path: 验证集 187 | 188 | map_file_path: 领域意图映射文件 189 | 190 | model_path: 模型存储路径 191 | 192 | vocab_file:根据训练数据生成的字典映射文件 193 | 194 | label_file:根据训练过程生成的<领域,id>映射文件 195 | 196 | embedding_size: embedding层向量大小 197 | 198 | num_units: lstm cell 的单元个数 199 | 200 | batch_size: batch 大小 201 | 202 | seq_length: 参与训练的最大序列长度 203 | 204 | num_epcho: epcho个数 205 | 206 | check_every: 每隔多少steps在验证集上检验训练过程中的模型效果 207 | 208 | lstm_layers: lstm 层数 209 | 210 | lr: 学习率 211 | 212 | dropout_keep_prob: dropout过程中keep神经元的概率 213 | 214 | ### 2.用领域分类模型对测试集进行预测 215 | 216 | ```bash 217 | python lstm_predict.py --map_file_path=./data_demo/std_data --model_path=./model/model_max --test_data_path=./data_demo/test_data --test_result_path=./model/model_max/result_max_test --batch_size=250 --seq_length=40 --label2id_file=./model/model_max/label_max --vocab2id_file=./model/model_max/vocab_max 218 | ``` 219 | 220 | 参数说明: 221 | 222 | map_file_path: 领域意图映射文件 223 | 224 | model_path: 模型路径 225 | 226 | test_data_path: 测试集 227 | 228 | test_result_path: 测试打分结果文件 229 | 230 | batch_size: batch 大小 231 | 232 | seq_length: 参与训练的最大序列长度(要和训练过程保持一致) 233 | 234 | label2id_file:<领域,id>映射文件 235 | 236 | vocab2id_file: 根据训练数据生成的字典映射文件 237 | 238 | ### 3.训练DSSM意图匹配模型 239 | ```bash 240 | python run_dssm.py --train_path=./data_demo/train_data --valid_path=./data_demo/valid_data --map_file_path=./data_demo/std_data --model_path=./model/model_min/ --result_file_path=./data/result_min --softmax_r=45 --embedding_size=256 --learning_rate=0.001 --keep_prob=0.8 --batch_size=250 --num_epoches=30 --negative_size=200 --eval_every=10 --num_units=256 --use_same_cell=False --label2id_path=./model/model_min/min_label2id --vocab2id_path=./model/model_min/min_vocab2id 241 | ``` 242 | 243 | ### 4.用意图匹配模型对测试集进行预测 244 | 245 | ```bash 246 | python dssm_predict.py --map_file_path=./data_demo/std_data --model_path=./model/model_min/ --export_model_dir=./model/model_min/dssm_tf_serving/ --test_data_path=./data_demo/test_data --test_result_path=./model/model_min/result_min_test --softmax_r=45 --batch_size=250 --label2id_file=./model/model_min/min_label2id --vocab2id_file=./model/model_min/min_vocab2id 247 | ``` 248 | 249 | ### 5.领域分类结果&意图匹配结果进行融合 250 | 251 | ```bash 252 | python merge_classifier_match_label.py ./model/model_max/result_max_test ./model/model_min/result_min_test ./data_demo/merge_result_2_level ./data_demo/std_data 253 | ``` 254 | 255 | 参数说明: 256 | 257 | argv[1]: 领域分类打分文件 258 | 259 | argv[2]: 意图识别打分文件 260 | 261 | argv[3]: 模型融合文件 262 | 263 | argv[4]: 领域意图映射文件 264 | 265 | ## 模型评测指标及测试集效果 266 | 267 | 目前qa_match的问答效果评测是基于分类模型的评测,主要看在模型各种回答类型(唯一回答,列表回答,拒绝回答)占比接近真实应回答类型占比下各种回答的类型的准确率、召回率、F1值,具体定义如下: 268 | 269 | ![评测指标](measurement.png) 270 | 271 | 对上述一层结构知识库和二层结构知识库示例(数据集具体见[data_demo](../data_demo))运行后评测效果如下(使用通用深度学习推理服务[dl_inference](https://github.com/wuba/dl_inference)开源项目部署模型来评测推理耗时): 272 | 273 | | 数据集 | 模型 | **唯一回答准确率** | **唯一回答召回率** | 
**唯一回答**F1 | **CPU**机器上推理耗时 | 274 | | ---------------- | ------------------------------------------------------------ | ------------------ | ------------------ | -------------- | --------------------- | 275 | | 一级知识库数据集 | DSSM[[下载](http://wos.58cdn.com.cn/nOlKjIhGntU/qamatch/model_min.zip)] | 0.8398 | 0.8326 | 0.8362 | **3ms** | 276 | | 一级知识库数据集 | SPTM(LSTM)[[下载](http://wos.58cdn.com.cn/nOlKjIhGntU/qamatch/model_pretrain.zip)] | 0.8841 | 0.9002 | 0.8921 | 16ms | 277 | | 一级知识库数据集 | SPTM(Transformer 12 Layers,12 Heads)[[下载](http://wos.58cdn.com.cn/nOlKjIhGntU/qamatch/model_pretrain_transformer.zip)] | 0.9275 | 0.9298 | **0.9287** | 17ms | 278 | | 一级知识库数据集 | SPTM(Transformer 1 Layers,12 Heads)[[下载](http://wos.58cdn.com.cn/nOlKjIhGntU/qamatch/transorformer_1_layer_12_heads.zip)] | 0.9122 | 0.9105 | 0.9122 | 13ms | 279 | | 二级知识库数据集 | LSTM+DSSM融合模型[[下载](http://wos.58cdn.com.cn/nOlKjIhGntU/qamatch/model_merge.zip)] | 0.8957 | 0.9027 | 0.8992 | 18ms | 280 | 281 | 说明:由于示例数据中列表回答真实占比较小,这里我们主要看唯一回答的准确率、召回率和F1值。对于二级知识库数据集,我们也可以使用预训练模型来完成自动问答,这里不做过多描述。 282 | -------------------------------------------------------------------------------- /docs/dssm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/dssm.png -------------------------------------------------------------------------------- /docs/kg_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/kg_demo.png -------------------------------------------------------------------------------- /docs/lstm_dssm_bagging.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/lstm_dssm_bagging.png -------------------------------------------------------------------------------- /docs/measurement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/measurement.png -------------------------------------------------------------------------------- /docs/pretrain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/pretrain.png -------------------------------------------------------------------------------- /docs/sptm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/docs/sptm.png -------------------------------------------------------------------------------- /dssm_predict.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import shutil 4 | import os 5 | from utils.match_utils import DataHelper 6 | 7 | flags = tf.app.flags 8 | FLAGS = flags.FLAGS 9 | 10 | flags.DEFINE_string('map_file_path', None, 'standard data path') 11 | flags.DEFINE_string("model_path", None, "checkpoint dir from predicting") 12 | flags.DEFINE_string("export_model_dir", None, "export model dir") 13 | flags.DEFINE_string('test_data_path', None, 'test data path') 14 | flags.DEFINE_string('test_result_path', None, 'test data result path') 15 | 
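# Prediction flow: restore the latest DSSM checkpoint, encode all standard questions
# once (qs_y_raw), then score every test query against them with a softmax over the
# smoothed cosine similarity and write the candidates sorted by score.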
flags.DEFINE_integer('softmax_r', 45, 'Smooth parameter for osine similarity') # must be similar as train 16 | flags.DEFINE_integer('batch_size', 100, 'batch_size for train') 17 | flags.DEFINE_string('label2id_file', None, 'label2id file path') 18 | flags.DEFINE_string('vocab2id_file', None, 'vocab2id_file path') 19 | 20 | dh = DataHelper(None, None, FLAGS.test_data_path, FLAGS.map_file_path, FLAGS.batch_size, None, FLAGS.label2id_file, 21 | FLAGS.vocab2id_file, False) 22 | 23 | config = tf.ConfigProto() 24 | config.gpu_options.allow_growth = True 25 | 26 | with tf.Session(config=config) as sess: 27 | model_file = tf.train.latest_checkpoint(FLAGS.model_path) 28 | saver = tf.train.import_meta_graph("{}.meta".format(model_file)) 29 | saver.restore(sess, model_file) 30 | graph = tf.get_default_graph() 31 | input_x = graph.get_tensor_by_name("input_x:0") 32 | length_x = graph.get_tensor_by_name("length_x:0") 33 | input_y = graph.get_tensor_by_name("input_y:0") 34 | length_y = graph.get_tensor_by_name("length_y:0") 35 | keep_prob = graph.get_tensor_by_name("keep_prob:0") 36 | q_y_raw = graph.get_tensor_by_name("representation/q_y_raw:0") 37 | qs_y_raw = graph.get_tensor_by_name("representation/qs_y_raw:0") 38 | # first get std tensor value 39 | length_y_value = [y[0] for y in dh.std_batch] 40 | input_y_value = [y[1] for y in dh.std_batch] 41 | # print("input_y_value: " + str(input_y_value)) 42 | # print("input_y_value.shape: " + str(np.array(input_y_value, dtype=np.int32).shape)) 43 | # print("length_y_value.shape: " + str(np.array(length_y_value, dtype=np.int32).shape)) 44 | qs_y_raw_out = sess.run(qs_y_raw, feed_dict={input_y: np.array(input_y_value, dtype=np.int32), 45 | length_y: np.array(length_y_value, dtype=np.int32), keep_prob: 1.0}) 46 | 47 | with tf.name_scope('cosine_similarity_pre'): 48 | # Cosine similarity 49 | q_norm_pre = tf.sqrt(tf.reduce_sum(tf.square(q_y_raw), 1, True)) # b*1 50 | qs_norm_pre = tf.transpose(tf.sqrt(tf.reduce_sum(tf.square(qs_y_raw_out), 1, True))) # 1*sb 51 | prod_nu_pre = tf.matmul(q_y_raw, tf.transpose(qs_y_raw_out)) # b*sb 52 | norm_prod_de = tf.matmul(q_norm_pre, qs_norm_pre) # b*sb 53 | cos_sim_pre = tf.truediv(prod_nu_pre, norm_prod_de) * FLAGS.softmax_r # b*sb 54 | 55 | with tf.name_scope('prob_pre'): 56 | prob_pre = tf.nn.softmax(cos_sim_pre) # b*sb 57 | 58 | test_batches = dh.test_batch_iterator() 59 | test_result_file = open(FLAGS.test_result_path, 'w', encoding='utf-8') 60 | # print(dh.predict_label_seq) 61 | for _, test_batch_q in enumerate(test_batches): 62 | # test_batch_q:[(l1, ws1, label1, line1), (l2, ws2, label2, line2), ...] 
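        # The query batch is fed through the restored encoder (q_y_raw); prob_pre then
        # softmaxes the scaled cosine similarities against every standard-question
        # embedding computed above, and for each query the labels are sorted by score
        # and written out as "real_label<TAB>query<TAB>label:score label:score ...".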
63 | length_x_value = [x[0] for x in test_batch_q] 64 | input_x_value = [x[1] for x in test_batch_q] 65 | test_prob = sess.run(prob_pre, feed_dict={input_x: np.array(input_x_value, dtype=np.int32), 66 | length_x: np.array(length_x_value, dtype=np.int32), 67 | keep_prob: 1.0}) # b*sb 68 | # print("test_prob: " + str(test_prob)) 69 | for index, example in enumerate(test_batch_q): 70 | (_, _, real_label, words) = example 71 | result_str = str(real_label) + '\t' + str(words) + '\t' 72 | label_scores = {} 73 | # print(test_prob[index]) 74 | sample_scores = test_prob[index] 75 | for score_index, real_label_score in enumerate(sample_scores): 76 | label_scores[dh.predict_label_seq[score_index]] = real_label_score 77 | sorted_list = sorted(label_scores.items(), key=lambda x: x[1], reverse=True) 78 | # print(str(sorted_list)) 79 | for label, score in sorted_list: 80 | result_str = result_str + str(label) + ":" + str(score) + " " 81 | # write result 82 | test_result_file.write(result_str + '\n') 83 | test_result_file.close() 84 | # export model 85 | if os.path.isdir(FLAGS.export_model_dir): 86 | shutil.rmtree(FLAGS.export_model_dir) 87 | builder = tf.saved_model.builder.SavedModelBuilder(FLAGS.export_model_dir) 88 | pred_x = tf.saved_model.utils.build_tensor_info(input_x) 89 | pred_len_x = tf.saved_model.utils.build_tensor_info(length_x) 90 | drop_keep_prob = tf.saved_model.utils.build_tensor_info(keep_prob) 91 | probs = tf.saved_model.utils.build_tensor_info(prob_pre) 92 | # 定义方法名和输入输出 93 | signature_def_map = { 94 | "predict": tf.saved_model.signature_def_utils.build_signature_def( 95 | inputs={"input": pred_x, "length": pred_len_x, "keep_prob": drop_keep_prob}, 96 | outputs={ 97 | "probs": probs 98 | }, 99 | method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME 100 | ) 101 | } 102 | builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], 103 | signature_def_map=signature_def_map) 104 | builder.save() 105 | -------------------------------------------------------------------------------- /lstm_predict.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils.classifier_utils import TextLoader 3 | 4 | flags = tf.app.flags 5 | FLAGS = flags.FLAGS 6 | 7 | flags.DEFINE_string('map_file_path', None, 'standard data path') 8 | flags.DEFINE_string("model_path", None, "checkpoint dir from predicting") 9 | flags.DEFINE_string('test_data_path', None, 'test data path') 10 | flags.DEFINE_string('test_result_path', None, 'test data result path') 11 | flags.DEFINE_integer('batch_size', 100, 'batch_size for train') 12 | flags.DEFINE_integer('seq_length', 40, 'seq_length') 13 | flags.DEFINE_string('label2id_file', None, 'label2id file path') 14 | flags.DEFINE_string('vocab2id_file', None, 'vocab2id_file path') 15 | 16 | # load vocab and label mapping 17 | vocab_id = {} 18 | vocab_file = open(FLAGS.vocab2id_file, 'r', encoding='utf-8') 19 | for line in vocab_file: 20 | word_ids = line.strip().split('\t') 21 | vocab_id[word_ids[0]] = word_ids[1] 22 | vocab_file.close() 23 | label_id = {} 24 | id_label = {} 25 | label_file = open(FLAGS.label2id_file, 'r', encoding='utf-8') 26 | for line in label_file: 27 | std_label_ids = line.strip().split('\t') 28 | label_id[std_label_ids[0]] = std_label_ids[1] 29 | id_label[std_label_ids[1]] = std_label_ids[0] 30 | # print("id_label: " + str(id_label)) 31 | 32 | label_file.close() 33 | std_label_map = {} 34 | std_label_map_file = open(FLAGS.map_file_path, 'r', 
encoding='utf-8') 35 | for line in std_label_map_file: 36 | tokens = line.strip().split('\t') 37 | label = tokens[0] 38 | std_id = tokens[1] 39 | std_label_map[std_id] = label 40 | 41 | std_label_map_file.close() 42 | 43 | test_data_loader = TextLoader(False, FLAGS.test_data_path, FLAGS.map_file_path, FLAGS.batch_size, FLAGS.seq_length, 44 | vocab_id, label_id, std_label_map, 'utf8', False) 45 | 46 | config = tf.ConfigProto() 47 | config.gpu_options.allow_growth = True 48 | 49 | with tf.Session(config=config) as sess: 50 | model_file = tf.train.latest_checkpoint(FLAGS.model_path) 51 | saver = tf.train.import_meta_graph("{}.meta".format(model_file)) 52 | saver.restore(sess, model_file) 53 | graph = tf.get_default_graph() 54 | input_x = graph.get_tensor_by_name("input_x:0") 55 | length_x = graph.get_tensor_by_name("x_len:0") 56 | keep_prob = graph.get_tensor_by_name("dropout_keep_prob:0") 57 | test_data_loader.reset_batch_pointer() 58 | prediction = graph.get_tensor_by_name("acc/prediction_softmax:0") # [batchsize, label_size] 59 | test_result_file = open(FLAGS.test_result_path, 'w', encoding='utf-8') 60 | for n in range(test_data_loader.num_batches): 61 | input_x_test, input_y_test, x_len_test, raw_lines = test_data_loader.next_batch() 62 | prediction_result = sess.run(prediction, 63 | feed_dict={input_x: input_x_test, length_x: x_len_test, keep_prob: 1.0}) 64 | # print("n: " + str(n)) 65 | # print("len(input_x_test): " + str(len(input_x_test))) 66 | # print("len(input_y_test): " + str(len(input_y_test))) 67 | # print("len(raw_lines): " + str(len(raw_lines))) 68 | assert len(input_y_test) == len(raw_lines) 69 | for i in range(len(raw_lines)): 70 | raw_line = raw_lines[i] 71 | # print("input_y_test[i]: " + str(input_y_test[i])) 72 | real_label = id_label.get(str(input_y_test[i])) 73 | label_scores = {} 74 | for j in range(len(prediction_result[i])): 75 | label = id_label.get(str(j)) 76 | score = prediction_result[i][j] 77 | label_scores[label] = score 78 | sorted_list = sorted(label_scores.items(), key=lambda x: x[1], reverse=True) 79 | # print("real_label: " + str(type(real_label))) 80 | # print("raw_lines: " + str(raw_lines)) 81 | result_str = str(real_label) + "\t" + str(raw_line) + "\t"; 82 | for label, score in sorted_list: 83 | result_str = result_str + str(label) + ":" + str(score) + " " 84 | test_result_file.write(result_str + '\n') 85 | test_result_file.close() 86 | -------------------------------------------------------------------------------- /merge_classifier_match_label.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # encoding=utf-8 3 | # author wangyong 4 | 5 | """ 6 | merge result for domain identification and intent recognition 7 | """ 8 | 9 | import sys 10 | import logging 11 | logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s', 12 | level=logging.DEBUG) 13 | 14 | 15 | # none 16 | MAXCLASS_NONE_HIGHSCORE = 0.93 # 0.6 17 | MAXCLASS_NONE_MIDSCORE = 0.9 # 0.4 18 | 19 | #attention: MAXCLASS_NONE_HIGHSCORE must >= MAXCLASS_NONE_MIDSCORE 20 | assert MAXCLASS_NONE_HIGHSCORE >= MAXCLASS_NONE_MIDSCORE 21 | 22 | MINCLASS_NONE_HIGHSCORE1 = 0.8 23 | MINCLASS_NONE_HIGHSCORE2 = 0.75 24 | MINCLASS_NONE_MIDSCORE1 = 0.65 25 | MINCLASS_NONE_MIDSCORE2 = 0.65 26 | 27 | assert MINCLASS_NONE_HIGHSCORE1 >= MINCLASS_NONE_MIDSCORE1 28 | assert MINCLASS_NONE_HIGHSCORE2 >= MINCLASS_NONE_MIDSCORE2 29 | # list 30 | MAXCLASS_LIST_HIGHSCORE = 0.9 31 | MAXCLASS_LIST_MIDSCORE = 0.65 32 | assert MAXCLASS_LIST_HIGHSCORE >= MAXCLASS_LIST_MIDSCORE 33 | 34 | MINCLASS_LIST_HIGHSCORE1 = 0.9 35 | MINCLASS_LIST_HIGHSCORE2 = 0.9 36 | MINCLASS_LIST_MIDSCORE1 = 0.4 37 | MINCLASS_LIST_MIDSCORE2 = 0.5 38 | assert MINCLASS_LIST_HIGHSCORE1 >= MINCLASS_LIST_MIDSCORE1 39 | assert MINCLASS_LIST_HIGHSCORE2 >= MINCLASS_LIST_MIDSCORE2 40 | 41 | # only 42 | MAXCLASS_ONLY_HIGHSCORE = 0.9 43 | MINCLASS_ONLY_HIGHSCORE1 = 0.01 44 | MINCLASS_ONLY_HIGHSCORE2 = 0.01 45 | MINCLASS_ONLY_MIDSCORE1 = 0.01 46 | MINCLASS_ONLY_MIDSCORE2 = 0.01 47 | assert MINCLASS_ONLY_HIGHSCORE1 >= MINCLASS_ONLY_MIDSCORE1 48 | assert MINCLASS_ONLY_HIGHSCORE2 >= MINCLASS_ONLY_MIDSCORE2 49 | 50 | MINCLASS_ZERO_SCORE = 0.00001 51 | 52 | MINCLASS_HIGH_SCORE_ONLY = 0.65 53 | MINCLASS_MID_SCORE_ONLY = 0.55 54 | assert MINCLASS_HIGH_SCORE_ONLY >= MINCLASS_MID_SCORE_ONLY 55 | 56 | 57 | MINCLASS_NUM = 3 58 | 59 | 60 | class LabelScore: 61 | def __init__(self): 62 | self.label = '' 63 | self.score = 0 64 | 65 | 66 | class MergeObj(object): 67 | def __init__(self): 68 | self.real_max_label = '' 69 | self.real_min_label = '' 70 | self.pre_max_top_label = '' 71 | self.pre_max_top_score = 0 72 | self.pre_min_top_label = '' 73 | self.pre_min_top_score = 0 74 | self.pre_min_label_scores = [0] 75 | self.merge_result = '' 76 | 77 | 78 | def get_max2min_label(max_min_class_file_dir): 79 | min_max_m = {} 80 | max_min_class_file = open(max_min_class_file_dir, 'r', encoding='utf-8') 81 | for line in max_min_class_file.readlines(): 82 | mems = line.split("\t") 83 | max_label = mems[0] 84 | min_label = mems[1] 85 | min_max_m[min_label] = max_label 86 | max_min_class_file.close() 87 | return min_max_m 88 | 89 | 90 | def get_pre_label_scores(max_pre_file_d, min_pre_file_d): 91 | merge_items = [] 92 | max_pre_file = open(max_pre_file_d, 'r', encoding='utf-8') 93 | for line in max_pre_file.readlines(): 94 | line_items = line.split("\t") 95 | real_top_max_label = line_items[0] 96 | pre_top_max_label = line_items[2].split(' ')[0].split(":")[0] 97 | pre_top_max_score = float(line_items[2].split(' ')[0].split(":")[1]) 98 | mer_obj = MergeObj() 99 | mer_obj.pre_max_top_label = pre_top_max_label 100 | mer_obj.pre_max_top_score = float(pre_top_max_score) 101 | mer_obj.real_max_label = real_top_max_label 102 | merge_items.append(mer_obj) 103 | max_pre_file.close() 104 | min_pre_file = open(min_pre_file_d, 'r', encoding='utf-8') 105 | index = 0 106 | for line in min_pre_file.readlines(): 107 | mer_obj = merge_items[index] 108 | index = index + 1 109 | line_items = line.split("\t") 110 | real_min_label = line_items[0] 111 | label_scores_list = 
line_items[2].split(" ") 112 | pre_top_min_label = label_scores_list[0].split(":")[0] 113 | pre_top_min_score = float(label_scores_list[0].split(":")[1]) 114 | mer_obj.real_min_label = real_min_label 115 | mer_obj.pre_min_top_label = pre_top_min_label 116 | mer_obj.pre_min_top_score = pre_top_min_score 117 | mer_obj.pre_min_label_scores = [] 118 | scores_list = mer_obj.pre_min_label_scores 119 | 120 | for i in range(len(label_scores_list)): 121 | label_score = LabelScore() 122 | temp_labels = label_scores_list[i].split(":") 123 | if len(temp_labels) < 2: 124 | continue 125 | label_score.label = temp_labels[0] 126 | label_score.score = float(temp_labels[1]) 127 | scores_list.append(label_score) 128 | min_pre_file.close() 129 | return merge_items 130 | 131 | 132 | def get_merge_result_each(str_type, merge_item): 133 | assert str_type in ('__label__none', '__label__only', '__label__list') 134 | if str_type == "__label__none": 135 | merge_item.merge_result = "__label__none" 136 | elif str_type == "__label__only": 137 | merge_item.merge_result = merge_item.pre_min_top_label + ":" + str(merge_item.pre_min_top_score) 138 | elif str_type == "__label__list": 139 | merge_item.merge_result = "" 140 | for i in range(len(merge_item.pre_min_label_scores)): 141 | if i == MINCLASS_NUM: 142 | break 143 | label = merge_item.pre_min_label_scores[i].label 144 | score = merge_item.pre_min_label_scores[i].score 145 | if score < MINCLASS_ZERO_SCORE: 146 | break 147 | merge_item.merge_result = merge_item.merge_result + label + ":" + str(score) + "," 148 | 149 | 150 | def get_only_list_none_result(high_score, low_score, merge_item): 151 | if merge_item.pre_min_top_score >= high_score: # one answer 152 | get_merge_result_each("__label__only", merge_item) 153 | elif merge_item.pre_min_top_score < high_score and merge_item.pre_min_top_score >= low_score: # list answer 154 | get_merge_result_each("__label__list", merge_item) 155 | else: # refuse to answer 156 | get_merge_result_each("__label__none", merge_item) 157 | 158 | 159 | def get_merge_result(merge_items, min_max_m): 160 | for merge_item in merge_items: 161 | if merge_item.pre_max_top_label == "__label__none": # none 162 | if merge_item.pre_max_top_score >= MAXCLASS_NONE_HIGHSCORE: # direct rejection 163 | get_merge_result_each("__label__none", merge_item) 164 | elif merge_item.pre_max_top_score >= MAXCLASS_NONE_MIDSCORE and merge_item.pre_max_top_score < MAXCLASS_NONE_HIGHSCORE: # tendency to reject 165 | get_only_list_none_result(MINCLASS_NONE_HIGHSCORE1, MINCLASS_NONE_MIDSCORE1, merge_item) 166 | else: # not tendency to reject 167 | get_only_list_none_result(MINCLASS_NONE_HIGHSCORE2, MINCLASS_NONE_MIDSCORE2, merge_item) 168 | elif merge_item.pre_max_top_label == "__label__list": # list 169 | if merge_item.pre_max_top_score >= MAXCLASS_LIST_HIGHSCORE: # direct answer a list 170 | get_merge_result_each("__label__list", merge_item) 171 | elif merge_item.pre_max_top_score >= MAXCLASS_LIST_MIDSCORE and merge_item.pre_max_top_score < MAXCLASS_LIST_HIGHSCORE: # tendency to answer list 172 | get_only_list_none_result(MINCLASS_LIST_HIGHSCORE1, MINCLASS_LIST_MIDSCORE1, merge_item) 173 | else: # not tendency to answer list 174 | get_only_list_none_result(MINCLASS_LIST_HIGHSCORE2, MINCLASS_LIST_MIDSCORE2, merge_item) 175 | else: # only 176 | filter_pre_min_label_scores = [] 177 | for label_score in merge_item.pre_min_label_scores: 178 | max_label = min_max_m[label_score.label] 179 | if max_label != merge_item.pre_max_top_label: 180 | continue 181 | 
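                # min_max_m maps each intent label to its domain; candidates whose
                # domain does not match the predicted top domain were skipped by the
                # `continue` above, so only in-domain intents reach the thresholds below.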
filter_pre_min_label_scores.append(label_score) 182 | merge_item.pre_min_label_scores = filter_pre_min_label_scores 183 | if len(filter_pre_min_label_scores) == 0: # direct rejection 184 | get_merge_result_each("__label__none", merge_item) 185 | else: 186 | merge_item.pre_min_top_label = filter_pre_min_label_scores[0].label 187 | merge_item.pre_min_top_score = filter_pre_min_label_scores[0].score 188 | if merge_item.pre_max_top_score >= MAXCLASS_ONLY_HIGHSCORE: # not tendency to reject 189 | get_only_list_none_result(MINCLASS_ONLY_HIGHSCORE1, MINCLASS_ONLY_MIDSCORE1, merge_item) 190 | else: # not tendency to one answer 191 | get_only_list_none_result(MINCLASS_ONLY_HIGHSCORE2, MINCLASS_ONLY_MIDSCORE2, merge_item) 192 | 193 | 194 | def write_result(merge_items, result_file_d): 195 | min_pre_file = open(result_file_d, 'w', encoding='utf-8') 196 | for merge_item in merge_items: 197 | min_pre_file.write(merge_item.real_max_label + "\t" + merge_item.real_min_label 198 | + "\t" + merge_item.merge_result + "\n") 199 | min_pre_file.close() 200 | 201 | 202 | def get_result_by_min(min_pre_file_dir, result_file_dir): 203 | with open(min_pre_file_dir, 'r', encoding='utf-8') as f_pre_min: 204 | with open(result_file_dir, 'w', encoding='utf-8') as f_res: 205 | for line in f_pre_min: 206 | lines = line.strip().split('\t') 207 | real_label = lines[0] 208 | model_label_scores = lines[2].split(' ') 209 | temp_label_score_list = [] 210 | write_str = '__label__0\t' + str(real_label) + '\t' 211 | for label_score in model_label_scores: 212 | label_scores = label_score.split(':') 213 | temp_label_score = LabelScore() 214 | temp_label_score.label = label_scores[0] 215 | temp_label_score.score = (float)(label_scores[1]) 216 | temp_label_score_list.append(temp_label_score) 217 | if temp_label_score_list[0].score < MINCLASS_MID_SCORE_ONLY: # refuse answer 218 | write_str += '__label__none' 219 | elif temp_label_score_list[0].score >= MINCLASS_MID_SCORE_ONLY and temp_label_score_list[ 220 | 0].score < MINCLASS_HIGH_SCORE_ONLY: # list answer 221 | for i in range(len(temp_label_score_list)): 222 | if i == MINCLASS_NUM: 223 | break 224 | write_str += str(temp_label_score_list[i].label) + ':' + str( 225 | temp_label_score_list[i].score) + ',' 226 | else: # only answer 227 | write_str += str(temp_label_score_list[0].label) + ':' + str(temp_label_score_list[0].score) 228 | f_res.write(write_str + "\n") 229 | 230 | 231 | def get_acc_recall_f1(result_file_dir): 232 | only_real_num = 0 233 | only_model_num = 0 234 | only_right_num = 0 235 | list_real_num = 0 236 | list_model_num = 0 237 | list_right_num = 0 238 | none_real_num = 0 239 | none_model_num = 0 240 | none_right_num = 0 241 | num = 0 242 | with open(result_file_dir, 'r', encoding='utf-8') as f_pre: 243 | for line in f_pre: 244 | num = num + 1 245 | lines = line.strip().split('\t') 246 | if lines[1] == '0': 247 | none_real_num = none_real_num + 1 248 | elif ',' in lines[1]: 249 | list_real_num = list_real_num + 1 250 | else: 251 | only_real_num = only_real_num + 1 252 | model_label_scores = lines[2].split(',') 253 | if lines[2] == '__label__none': 254 | none_model_num = none_model_num + 1 255 | elif len(model_label_scores) == 1: 256 | only_model_num = only_model_num + 1 257 | else: 258 | list_model_num = list_model_num + 1 259 | real_labels_set = set(lines[1].split(',')) 260 | if lines[1] == '0' and lines[2] == '__label__none': 261 | none_right_num = none_right_num + 1 262 | if len(real_labels_set) == 1 and len(model_label_scores) == 1 and lines[1] == 
lines[2].split(':')[0]: 263 | only_right_num = only_right_num + 1 264 | if len(real_labels_set) > 1 and len(model_label_scores) > 1: 265 | for i in range(len(model_label_scores)): 266 | label_scores = model_label_scores[i].split(":") 267 | if label_scores[0] in real_labels_set: 268 | list_right_num = list_right_num + 1 269 | break 270 | logging.info('none_right_num: ' + str(none_right_num) + ', list_right_num: ' + str(list_right_num) 271 | + ', only_right_num: ' + str(only_right_num)) 272 | logging.info('none_real_num: ' + str(none_real_num) + ', list_real_num: ' + str(list_real_num) 273 | + ', only_real_num: ' + str(only_real_num)) 274 | logging.info('none_model_num: ' + str(none_model_num) + ', list_model_num: ' + str(list_model_num) 275 | + ', only_model_num: ' + str(only_model_num)) 276 | all_right_num = list_right_num + only_right_num 277 | all_real_num = list_real_num + only_real_num 278 | all_model_num = list_model_num + only_model_num 279 | logging.info('all_right_num: ' + str(all_right_num) + ', all_real_num: ' + str(all_real_num) 280 | + ', all_model_num: ' + str(all_model_num)) 281 | all_acc = all_right_num / all_model_num 282 | all_recall = all_right_num / all_real_num 283 | all_f1 = 2 * all_acc * all_recall / (all_acc + all_recall) 284 | logging.info("all_acc: " + str(all_acc) + ", all_recall: " + str(all_recall) + ", all_f1: " + str(all_f1)) 285 | only_acc = only_right_num / only_model_num 286 | only_recall = only_right_num / only_real_num 287 | only_f1 = 2 * only_acc * only_recall / (only_acc + only_recall) 288 | logging.info("only_acc: " + str(only_acc) + ", only_recall: " + str(only_recall) + ", only_f1: " + str( 289 | only_f1) + ", only_real_prop: " + str(only_real_num / num) + ", only_model_prop: " + str(only_model_num / num)) 290 | list_acc = list_right_num / list_model_num 291 | list_recall = list_right_num / list_real_num 292 | list_f1 = 2 * list_acc * list_recall / (list_acc + list_recall) 293 | logging.info("list_acc: " + str(list_acc) + ", list_recall: " + str(list_recall) + ", list_f1: " + str( 294 | list_f1) + ", list_real_prop: " + str(list_real_num / num) + ", list_model_prop: " + str(list_model_num / num)) 295 | none_acc = none_right_num / none_model_num 296 | none_recall = none_right_num / none_real_num 297 | none_f1 = 2 * none_acc * none_recall / (none_acc + none_recall) 298 | logging.info("none_acc: " + str(none_acc) + ", none_recall: " + str(none_recall) + ", none_f1: " + str( 299 | none_f1) + ", none_real_prop: " + str(none_real_num / num) + ", none_model_prop: " + str(none_model_num / num)) 300 | 301 | 302 | if __name__ == "__main__": 303 | max_pre_file_dir = sys.argv[1] 304 | min_pre_file_dir = sys.argv[2] 305 | result_file_dir = sys.argv[3] 306 | std_label_ques = sys.argv[4] 307 | if max_pre_file_dir == 'none' or std_label_ques == 'none': # only use min_pre result 308 | get_result_by_min(min_pre_file_dir, result_file_dir) 309 | else: # merge max_pre result and min_pre result 310 | merge_items_list = get_pre_label_scores(max_pre_file_dir, min_pre_file_dir) 311 | min_max_map = get_max2min_label(std_label_ques) 312 | get_merge_result(merge_items_list, min_max_map) 313 | write_result(merge_items_list, result_file_dir) 314 | # get acc recall f1 315 | get_acc_recall_f1(result_file_dir) 316 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/models/__init__.py -------------------------------------------------------------------------------- /models/bilstm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | a bi-lstm implementation for short text classification using tensroflow library 5 | 6 | """ 7 | 8 | from __future__ import print_function 9 | 10 | import tensorflow as tf 11 | from tensorflow.contrib import rnn 12 | 13 | 14 | class BiLSTM(object): 15 | 16 | def __init__(self, FLAGS): 17 | """Constructor for BiLSTM 18 | 19 | Args: 20 | FLAGS: tf.app.flags, you can see the FLAGS of run_bi_lstm.py 21 | """ 22 | self.input_x = tf.placeholder(tf.int64, [None, FLAGS.seq_length], name="input_x") 23 | self.input_y = tf.placeholder(tf.int64, [None, ], name="input_y") 24 | self.x_len = tf.placeholder(tf.int64, [None, ], name="x_len") 25 | self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob") 26 | 27 | with tf.variable_scope("embedding", initializer=tf.orthogonal_initializer()): 28 | with tf.device('/cpu:0'): 29 | # word embedding table 30 | self.vocab = tf.get_variable('w', [FLAGS.vocab_size, FLAGS.embedding_size]) 31 | embedded = tf.nn.embedding_lookup(self.vocab, self.input_x) # [batch_size, seq_length, embedding_size] 32 | inputs = tf.split(embedded, FLAGS.seq_length, 33 | 1) # [[batch_size, 1, embedding_size], [batch_size, 1, embedding_size], number is seq_length] 34 | inputs = [tf.squeeze(input_, [1]) for input_ in 35 | inputs] # [[batch_size, embedding_size], [batch_size, embedding_size], number is seq_length] 36 | 37 | with tf.variable_scope("encoder", initializer=tf.orthogonal_initializer()): 38 | lstm_fw_cell = rnn.BasicLSTMCell(FLAGS.num_units) 39 | lstm_bw_cell = rnn.BasicLSTMCell(FLAGS.num_units) 40 | lstm_fw_cell_stack = rnn.MultiRNNCell([lstm_fw_cell] * FLAGS.lstm_layers, state_is_tuple=True) 41 | lstm_bw_cell_stack = rnn.MultiRNNCell([lstm_bw_cell] * FLAGS.lstm_layers, state_is_tuple=True) 42 | lstm_fw_cell_stack = rnn.DropoutWrapper(lstm_fw_cell_stack, input_keep_prob=self.dropout_keep_prob, 43 | output_keep_prob=self.dropout_keep_prob) 44 | lstm_bw_cell_stack = rnn.DropoutWrapper(lstm_bw_cell_stack, input_keep_prob=self.dropout_keep_prob, 45 | output_keep_prob=self.dropout_keep_prob) 46 | self.outputs, self.fw_st, self.bw_st = rnn.static_bidirectional_rnn(lstm_fw_cell_stack, lstm_bw_cell_stack, 47 | inputs, sequence_length=self.x_len, 48 | dtype=tf.float32) # multi-layer 49 | # only use the last layer 50 | last_layer_no = FLAGS.lstm_layers - 1 51 | self.states = tf.concat([self.fw_st[last_layer_no].h, self.bw_st[last_layer_no].h], 52 | 1) # [batchsize, (num_units * 2)] 53 | 54 | attention_size = 2 * FLAGS.num_units 55 | with tf.variable_scope('attention'): 56 | attention_w = tf.Variable(tf.truncated_normal([2 * FLAGS.num_units, attention_size], stddev=0.1), 57 | name='attention_w') # [num_units * 2, num_units * 2] 58 | attention_b = tf.get_variable("attention_b", initializer=tf.zeros([attention_size])) # [num_units * 2] 59 | u_list = [] 60 | for index in range(FLAGS.seq_length): 61 | u_t = tf.tanh(tf.matmul(self.outputs[index], attention_w) + attention_b) # [batchsize, num_units * 2] 62 | u_list.append(u_t) # seq_length * [batchsize, num_units * 2] 63 | u_w = tf.Variable(tf.truncated_normal([attention_size, 1], stddev=0.1), 64 | name='attention_uw') # [num_units * 2, 1] 65 | attn_z = [] 66 | for index in range(FLAGS.seq_length): 67 | z_t = 
tf.matmul(u_list[index], u_w) 68 | attn_z.append(z_t) # seq_length * [batchsize, 1] 69 | # transform to batch_size * sequence_length 70 | attn_zconcat = tf.concat(attn_z, axis=1) # [batchsize, seq_length] 71 | alpha = tf.nn.softmax(attn_zconcat) # [batchsize, seq_length] 72 | # transform to sequence_length * batch_size * 1 , same rank as outputs 73 | alpha_trans = tf.reshape(tf.transpose(alpha, [1, 0]), 74 | [FLAGS.seq_length, -1, 1]) # [seq_length, batchsize, 1] 75 | self.final_output = tf.reduce_sum(self.outputs * alpha_trans, 0) # [batchsize, num_units * 2] 76 | 77 | with tf.variable_scope("output_layer"): 78 | weights = tf.get_variable("weights", [2 * FLAGS.num_units, FLAGS.label_size]) 79 | biases = tf.get_variable("biases", initializer=tf.zeros([FLAGS.label_size])) 80 | 81 | with tf.variable_scope("acc"): 82 | # use attention 83 | self.logits = tf.matmul(self.final_output, weights) + biases # [batchsize, label_size] 84 | # not use attention 85 | # self.logits = tf.matmul(self.states, weights) + biases 86 | self.prediction = tf.nn.softmax(self.logits, name="prediction_softmax") # [batchsize, label_size] 87 | self.loss = tf.reduce_mean( 88 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)) 89 | self.global_step = tf.train.get_or_create_global_step() 90 | self.correct = tf.equal(tf.argmax(self.prediction, 1), self.input_y) 91 | self.acc = tf.reduce_mean(tf.cast(self.correct, tf.float32)) 92 | _, self.arg_index = tf.nn.top_k(self.prediction, k=FLAGS.label_size) # [batch_size, label_size] 93 | 94 | with tf.variable_scope('training'): 95 | # optimizer 96 | self.learning_rate = tf.train.exponential_decay(FLAGS.lr, self.global_step, 200, 0.96, staircase=True) 97 | self.train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss, 98 | global_step=self.global_step) 99 | 100 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=2) 101 | 102 | def export_model(self, export_path, sess): 103 | builder = tf.saved_model.builder.SavedModelBuilder(export_path) 104 | tensor_info_x = tf.saved_model.utils.build_tensor_info(self.input_x) 105 | tensor_info_y = tf.saved_model.utils.build_tensor_info(self.prediction) 106 | tensor_info_len = tf.saved_model.utils.build_tensor_info(self.x_len) 107 | tensor_dropout_keep_prob = tf.saved_model.utils.build_tensor_info(self.dropout_keep_prob) # 1.0 for inference 108 | prediction_signature = ( 109 | tf.saved_model.signature_def_utils.build_signature_def( 110 | inputs={'input': tensor_info_x, 'sen_len': tensor_info_len, 111 | 'dropout_keep_prob': tensor_dropout_keep_prob}, 112 | outputs={'output': tensor_info_y}, 113 | method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) 114 | legacy_init_op = None 115 | builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], 116 | signature_def_map={'prediction': prediction_signature, }, 117 | legacy_init_op=legacy_init_op, clear_devices=True, saver=self.saver) 118 | builder.save() 119 | -------------------------------------------------------------------------------- /models/dssm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | a lstm + dssm implementation for short text match using tensroflow library 5 | 6 | """ 7 | 8 | import random 9 | import tensorflow as tf 10 | from tensorflow.contrib import rnn 11 | 12 | 13 | class Dssm(object): 14 | def __init__(self, num_lstm_units, batch_size, negtive_size, SOFTMAX_R, learning_rate, vocab_size, 15 | 
embedding_size=100, use_same_cell=False): 16 | """Constructor for Dssm 17 | 18 | Args: 19 | num_lstm_units: int, The number of units in the LSTM cell. 20 | batch_size: int, The number of examples in each batch 21 | negtive_size: int, The number of negative example. 22 | SOFTMAX_R: float, A regulatory factor for cosine similarity 23 | learning_rate: float, learning rate 24 | vocab_size: int, The number of vocabulary 25 | embedding_size: int the size of vocab embedding 26 | use_same_cell: (optional) bool whether to use same cell for fw, bw lstm, default is false 27 | """ 28 | self.global_step = tf.Variable(0, trainable=False) 29 | self.keep_prob = tf.placeholder(tf.float32, name="keep_prob") 30 | self.input_x = tf.placeholder(tf.int32, [None, None], name="input_x") # [batch_size, seq_len] 31 | self.length_x = tf.placeholder(tf.int32, [None, ], name="length_x") # [batch_size, ] 32 | self.input_y = tf.placeholder(tf.int32, [None, None], name="input_y") # [batch_size, seq_len] 33 | self.length_y = tf.placeholder(tf.int32, [None, ], name="length_y") # [batch_size, ] 34 | self.lstm_fw_cell = rnn.BasicLSTMCell(num_lstm_units) 35 | if use_same_cell: 36 | self.lstm_bw_cell = self.lstm_fw_cell 37 | else: 38 | self.lstm_bw_cell = rnn.BasicLSTMCell(num_lstm_units) 39 | with tf.name_scope("keep_prob"): 40 | self.lstm_fw_cell = rnn.DropoutWrapper(self.lstm_fw_cell, input_keep_prob=self.keep_prob, 41 | output_keep_prob=self.keep_prob) 42 | self.lstm_bw_cell = rnn.DropoutWrapper(self.lstm_bw_cell, input_keep_prob=self.keep_prob, 43 | output_keep_prob=self.keep_prob) 44 | 45 | with tf.device('/cpu:0'), tf.name_scope("embedding"): 46 | # one_gram 47 | self.vocab = tf.get_variable('w', [vocab_size, embedding_size]) 48 | self.lstm_input_embedding_x = tf.nn.embedding_lookup(self.vocab, 49 | self.input_x) # [batch_size, seq_len, embedding_size] 50 | self.lstm_input_embedding_y = tf.nn.embedding_lookup(self.vocab, 51 | self.input_y) # [batch_size, seq_len, embedding_size] 52 | 53 | with tf.name_scope('representation'): 54 | self.states_x = \ 55 | tf.nn.bidirectional_dynamic_rnn(self.lstm_fw_cell, self.lstm_bw_cell, self.lstm_input_embedding_x, 56 | self.length_x, 57 | dtype=tf.float32)[1] 58 | self.output_x = tf.concat([self.states_x[0][1], self.states_x[1][1]], 1) # [batch_size, 2*num_lstm_units] 59 | self.states_y = \ 60 | tf.nn.bidirectional_dynamic_rnn(self.lstm_fw_cell, self.lstm_bw_cell, self.lstm_input_embedding_y, 61 | self.length_y, 62 | dtype=tf.float32)[1] 63 | self.output_y = tf.concat([self.states_y[0][1], self.states_y[1][1]], 1) # [batch_size, 2*num_lstm_units] 64 | self.q_y_raw = tf.nn.relu(self.output_x, name="q_y_raw") # [batch_size, num_lstm_units*2] 65 | print("self.q_y_raw: " + str(self.q_y_raw)) 66 | self.qs_y_raw = tf.nn.relu(self.output_y, name="qs_y_raw") # [batch_size, num_lstm_units*2] 67 | print("self.qs_y_raw: " + str(self.qs_y_raw)) 68 | 69 | with tf.name_scope('rotate'): 70 | temp = tf.tile(self.qs_y_raw, [1, 1]) # [batch_size, num_lstm_units*2] 71 | self.qs_y = tf.tile(self.qs_y_raw, [1, 1]) # [batch_size, num_lstm_units*2] 72 | for i in range(negtive_size): 73 | rand = int((random.random() + i) * batch_size / negtive_size) 74 | if rand == 0: 75 | rand = rand + 1 76 | rand_qs_y1 = tf.slice(temp, [rand, 0], [batch_size - rand, -1]) # [batch_size - rand, num_lstm_units*2] 77 | rand_qs_y2 = tf.slice(temp, [0, 0], [rand, -1]) # [rand, num_lstm_units*2] 78 | self.qs_y = tf.concat(axis=0, values=[self.qs_y, rand_qs_y1, 79 | rand_qs_y2]) # [batch_size*(negtive_size+1), num_lstm_units*2] 
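            # In-batch negative sampling by rotation: each of the negtive_size iterations
            # above picks a random offset `rand`, rotates the batch of standard-question
            # vectors by that offset (two slices + concat), and appends the rotated copy
            # to qs_y. After the loop qs_y stacks (negtive_size + 1) batches: the first
            # is each query's own positive standard question, the rest come from other
            # rows of the same batch and act as negatives in the softmax loss below.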
80 | 81 | with tf.name_scope('sim'): 82 | # cosine similarity 83 | q_norm = tf.tile(tf.sqrt(tf.reduce_sum(tf.square(self.q_y_raw), 1, True)), 84 | [negtive_size + 1, 1]) # [(negtive_size + 1) * batch_size, 1] 85 | qs_norm = tf.sqrt(tf.reduce_sum(tf.square(self.qs_y), 1, True)) # [batch_size*(negtive_size+1), 1] 86 | prod = tf.reduce_sum(tf.multiply(tf.tile(self.q_y_raw, [negtive_size + 1, 1]), self.qs_y), 1, 87 | True) # [batch_size*(negtive_size + 1), 1] 88 | norm_prod = tf.multiply(q_norm, qs_norm) # [batch_size*(negtive_size + 1), 1] 89 | sim_raw = tf.truediv(prod, norm_prod) # [batch_size*(negtive_size + 1), 1] 90 | self.cos_sim = tf.transpose(tf.reshape(tf.transpose(sim_raw), [negtive_size + 1, 91 | batch_size])) * SOFTMAX_R # [batch_size, negtive_size + 1] 92 | 93 | with tf.name_scope('loss'): 94 | # train Loss 95 | self.prob = tf.nn.softmax(self.cos_sim) # [batch_size, negtive_size + 1] 96 | self.hit_prob = tf.slice(self.prob, [0, 0], [-1, 1]) # [batch_size, 1] #positive 97 | raw_loss = -tf.reduce_sum(tf.log(self.hit_prob)) / batch_size 98 | self.loss = raw_loss 99 | 100 | with tf.name_scope('training'): 101 | # optimizer 102 | self.learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 1000, 0.96, staircase=True) 103 | self.train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss, 104 | global_step=self.global_step) 105 | 106 | # acc for test data 107 | with tf.name_scope('cosine_similarity_pre'): 108 | # Cosine similarity 109 | self.q_norm_pre = tf.sqrt(tf.reduce_sum(tf.square(self.q_y_raw), 1, True)) # b*1 110 | self.qs_norm_pre = tf.transpose(tf.sqrt(tf.reduce_sum(tf.square(self.qs_y_raw), 1, True))) # 1*sb 111 | self.prod_nu_pre = tf.matmul(self.q_y_raw, tf.transpose(self.qs_y_raw)) # b*sb 112 | self.norm_prod_de = tf.matmul(self.q_norm_pre, self.qs_norm_pre) # b*sb 113 | self.cos_sim_pre = tf.truediv(self.prod_nu_pre, self.norm_prod_de) * SOFTMAX_R # b*sb 114 | 115 | with tf.name_scope('prob_pre'): 116 | self.prob_pre = tf.nn.softmax(self.cos_sim_pre) # b*sb 117 | # self.hit_prob_pre = tf.slice(self.prob_pre, [0, 0], [-1, 1]) # [batch_size, 1] #positive 118 | # self.test_loss = -tf.reduce_sum(tf.log(self.hit_prob_pre)) / batch_size 119 | -------------------------------------------------------------------------------- /run_bi_lstm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | running bi-lstm for short text classification 5 | 6 | """ 7 | 8 | import os 9 | import tensorflow as tf 10 | import shutil 11 | from utils.classifier_utils import TextLoader 12 | from models.bilstm import BiLSTM 13 | 14 | flags = tf.flags 15 | FLAGS = flags.FLAGS 16 | 17 | flags.DEFINE_string("train_path", None, "dir for train data") 18 | flags.DEFINE_string("valid_path", None, "dir for valid data") 19 | flags.DEFINE_string("map_file_path", None, "dir for label std question mapping") 20 | flags.DEFINE_string("model_path", None, "dir for save checkpoint data") 21 | # flags.DEFINE_string("result_file", None, "file for valid result") 22 | flags.DEFINE_string("vocab_file", None, "file for vocab") 23 | flags.DEFINE_string("label_file", None, "file for label") 24 | flags.DEFINE_integer("embedding_size", 256, "size of word embedding") 25 | flags.DEFINE_integer("num_units", 256, "The number of units in the LSTM cell") 26 | flags.DEFINE_integer("vocab_size", 256, "The size of vocab") 27 | flags.DEFINE_integer("label_size", 256, "The num of label") 28 | flags.DEFINE_integer("batch_size", 128, "batch_size 
of train data") 29 | flags.DEFINE_integer("seq_length", 50, "the length of sequence") 30 | flags.DEFINE_integer("num_epcho", 30, "the epcho num") 31 | flags.DEFINE_integer("check_every", 100, "the epcho num") 32 | flags.DEFINE_integer("lstm_layers", 2, "the layers of lstm") 33 | flags.DEFINE_float("lr", 0.001, "learning rate") 34 | flags.DEFINE_float("dropout_keep_prob", 0.8, "drop_out keep prob") 35 | 36 | 37 | def main(_): 38 | tf.logging.set_verbosity(tf.logging.INFO) 39 | data_loader = TextLoader(True, FLAGS.train_path, FLAGS.map_file_path, FLAGS.batch_size, FLAGS.seq_length, None, 40 | None, None, 'utf8', False) 41 | valid_data_loader = TextLoader(False, FLAGS.valid_path, FLAGS.map_file_path, FLAGS.batch_size, FLAGS.seq_length, 42 | data_loader.vocab, 43 | data_loader.labels, data_loader.std_label_map, 'utf8', False) 44 | tf.logging.info("vocab_size: " + str(data_loader.vocab_size)) 45 | FLAGS.vocab_size = data_loader.vocab_size 46 | tf.logging.info("label_size: " + str(data_loader.label_size)) 47 | FLAGS.label_size = data_loader.label_size 48 | bilstm = BiLSTM(FLAGS) 49 | init = tf.global_variables_initializer() 50 | config = tf.ConfigProto(allow_soft_placement=True) 51 | config.gpu_options.allow_growth = True 52 | with tf.Session(config=config) as sess: 53 | sess.run(init) 54 | idx = 0 55 | test_best_acc = 0 56 | for epcho in range(FLAGS.num_epcho): # for each epoch 57 | data_loader.reset_batch_pointer() 58 | for train_batch_num in range(data_loader.num_batches): # for each batch 59 | input_x, input_y, x_len, _ = data_loader.next_batch() 60 | feed = {bilstm.input_x: input_x, bilstm.input_y: input_y, bilstm.x_len: x_len, 61 | bilstm.dropout_keep_prob: FLAGS.dropout_keep_prob} 62 | _, global_step_op, train_loss, train_acc = sess.run( 63 | [bilstm.train_step, bilstm.global_step, bilstm.loss, bilstm.acc], feed_dict=feed) 64 | tf.logging.info("training...........global_step = {}, epoch = {}, current_batch = {}, " 65 | "train_loss = {:.4f}, accuracy = {:.4f}".format(global_step_op, epcho, train_batch_num, 66 | train_loss, train_acc)) 67 | idx += 1 68 | if idx % FLAGS.check_every == 0: 69 | all_num = 0 70 | acc_num = 0 71 | valid_data_loader.reset_batch_pointer() 72 | write_result = [] 73 | for _ in range(valid_data_loader.num_batches): 74 | input_x_valid, input_y_valid, x_len_valid, _ = valid_data_loader.next_batch() 75 | feed = {bilstm.input_x: input_x_valid, bilstm.input_y: input_y_valid, bilstm.x_len: x_len_valid, 76 | bilstm.dropout_keep_prob: 1.0} 77 | prediction, arg_index = sess.run([bilstm.prediction, bilstm.arg_index], feed_dict=feed) 78 | all_num = all_num + len(input_y_valid) 79 | # write_str = "" 80 | for i, indexs in enumerate(arg_index): 81 | pre_label_id = indexs[0] 82 | real_label_id = input_y_valid[i] 83 | if pre_label_id == real_label_id: 84 | acc_num = acc_num + 1 85 | # if real_label_id in valid_data_loader.id_2_label: 86 | # write_str = valid_data_loader.id_2_label.get(real_label_id) 87 | # else: 88 | # write_str = "__label__unknown" 89 | # for index in indexs: 90 | # cur_label = valid_data_loader.id_2_label.get(index) 91 | # cur_score = prediction[i][index] 92 | # write_str = write_str + " " + cur_label + ":" + str(cur_score) 93 | # write_str = write_str + "\n" 94 | # write_result.append(write_str) 95 | test_acc = acc_num * 1.0 / all_num 96 | tf.logging.info( 97 | "testing...........global_step = {}, epoch = {}, accuracy = {:.4f}, cur_best_acc = {}".format( 98 | global_step_op, epcho, test_acc, test_best_acc)) 99 | if test_best_acc < test_acc: 100 | test_best_acc = 
test_acc 101 | # save_model 102 | if not os.path.exists(FLAGS.model_path): 103 | os.makedirs(FLAGS.model_path) 104 | checkpoint_path = os.path.join(FLAGS.model_path, 'lstm_model') 105 | bilstm.saver.save(sess, checkpoint_path, global_step=global_step_op) 106 | # export model 107 | export_path = os.path.join(FLAGS.model_path, 'lstm_tf_serving') 108 | if os.path.isdir(export_path): 109 | shutil.rmtree(export_path) 110 | bilstm.export_model(export_path, sess) 111 | # resultfile = open(FLAGS.result_file, 'w', encoding='utf-8') 112 | # for pre_sen in write_result: 113 | # resultfile.write(pre_sen) 114 | tf.logging.info( 115 | "has saved model and write.result...................................................................") 116 | # resultfile.close() 117 | # save label and vocab 118 | vocabfile = open(FLAGS.vocab_file, 'w', encoding='utf-8') 119 | for key, value in data_loader.vocab.items(): 120 | vocabfile.write(str(key) + "\t" + str(value) + '\n') 121 | vocabfile.close() 122 | labelfile = open(FLAGS.label_file, 'w', encoding='utf-8') 123 | for key, value in data_loader.labels.items(): 124 | labelfile.write(str(key) + "\t" + str(value) + '\n') 125 | labelfile.close() 126 | # break 127 | 128 | 129 | if __name__ == "__main__": 130 | tf.app.run() 131 | -------------------------------------------------------------------------------- /run_dssm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | running lstm + dssm for short text matching 5 | 6 | """ 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from utils.match_utils import DataHelper 11 | from models.dssm import Dssm 12 | 13 | flags = tf.app.flags 14 | FLAGS = flags.FLAGS 15 | 16 | # data parameters 17 | flags.DEFINE_string('train_path', None, 'dir for train data') 18 | flags.DEFINE_string('valid_path', None, 'dir for valid data') 19 | flags.DEFINE_string('map_file_path', None, 'dir for label std question mapping') 20 | flags.DEFINE_string('model_path', None, 'Model path') 21 | flags.DEFINE_string('label2id_path', None, 'label2id file path') 22 | flags.DEFINE_string('vocab2id_path', None, 'vocab2id file path') 23 | 24 | # training parameters 25 | flags.DEFINE_integer('softmax_r', 45, 'Smooth parameter for osine similarity') 26 | flags.DEFINE_integer('embedding_size', 200, 'max_sequence_len') 27 | flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') 28 | 29 | flags.DEFINE_float('keep_prob', 0.8, 'Dropout keep prob.') 30 | flags.DEFINE_integer('num_epoches', 10, "Number of epochs.") 31 | flags.DEFINE_integer('batch_size', 50, "Size of one batch.") 32 | flags.DEFINE_integer('negative_size', 5, "Size of negtive sample.") 33 | flags.DEFINE_integer('eval_every', 50, "Record summaries every n steps.") 34 | flags.DEFINE_integer('num_units', 100, "Number of units of lstm(default: 100)") 35 | flags.DEFINE_bool('use_same_cell', True, "whether to use sam cell") 36 | 37 | 38 | def feed_dict_builder(batch, keep_prob, dssm): 39 | # batch: ([(q1_len, [q1_w1, q1_w2,...]), (q2_len, [q2_w1, q2_w2,...]), ...], [(std1_len, [std1_w1, std1_w2,...]), (std2_len, [std2_w1, std2_w2,...]), ...]) 40 | length_x = [x[0] for x in batch[0]] 41 | input_x = [x[1] for x in batch[0]] 42 | length_y = [y[0] for y in batch[1]] 43 | input_y = [y[1] for y in batch[1]] 44 | feed_dict = { 45 | dssm.input_x: np.array(input_x, dtype=np.int32), 46 | dssm.length_x: np.array(length_x, dtype=np.int32), 47 | dssm.input_y: np.array(input_y, dtype=np.int32), 48 | dssm.length_y: np.array(length_y, 
dtype=np.int32), 49 | dssm.keep_prob: keep_prob 50 | } 51 | return feed_dict 52 | 53 | 54 | def cal_predict_acc_num(predict_prob, test_batch_q, predict_label_seq): 55 | # calculate acc 56 | assert (len(test_batch_q) == len(predict_prob)) 57 | real_labels = [] 58 | for ques in test_batch_q: 59 | label = ques[2] 60 | real_labels.append(label) 61 | acc_num = 0 62 | sorted_scores = [] 63 | for i, scores in enumerate(predict_prob): 64 | label_scores = {} 65 | for index, score in enumerate(scores): 66 | label_scores[predict_label_seq[index]] = score 67 | # sort 68 | label_scores = sorted(label_scores.items(), key=lambda x: x[1], reverse=True) 69 | sorted_scores.append(label_scores) 70 | top_label = label_scores[0][0] 71 | if top_label == real_labels[i]: 72 | acc_num = acc_num + 1 73 | return acc_num, real_labels, sorted_scores 74 | 75 | 76 | def main(_): 77 | tf.logging.set_verbosity(tf.logging.INFO) 78 | data_help = DataHelper(FLAGS.train_path, FLAGS.valid_path, None, FLAGS.map_file_path, FLAGS.batch_size, 79 | FLAGS.num_epoches, None, None, True) 80 | dssm = Dssm(FLAGS.num_units, FLAGS.batch_size, FLAGS.negative_size, FLAGS.softmax_r, FLAGS.learning_rate, 81 | data_help.vocab_size, FLAGS.embedding_size, use_same_cell=False) 82 | config = tf.ConfigProto(allow_soft_placement=True) 83 | config.gpu_options.allow_growth = True 84 | saver = tf.train.Saver(max_to_keep=1) 85 | train_batches = data_help.train_batch_iterator(data_help.train_id_ques, data_help.std_id_ques) 86 | best_valid_acc = 0 87 | # run_num = 0 88 | with tf.Session(config=config) as sess: 89 | sess.run(tf.global_variables_initializer()) 90 | for train_batch_step, train_batch in enumerate(train_batches): 91 | _, step, train_lr, train_loss = sess.run([dssm.train_step, dssm.global_step, dssm.learning_rate, dssm.loss], 92 | feed_dict=feed_dict_builder(train_batch, FLAGS.keep_prob, dssm)) 93 | tf.logging.info("Training...... global_step {}, epcho {}, train_batch_step {}, learning rate {} " 94 | "loss {}".format(step, round(step * FLAGS.batch_size / data_help.train_num, 2), 95 | train_batch_step, round(train_lr, 4), train_loss)) 96 | if (train_batch_step + 1) % FLAGS.eval_every == 0: 97 | # run_num = run_num + 1 98 | # if run_num % 2 == 0: 99 | # break 100 | all_valid_acc_num = 0 101 | all_valid_num = 0 102 | valid_batches = data_help.valid_batch_iterator() 103 | for _, valid_batch_q in enumerate(valid_batches): 104 | all_valid_num = all_valid_num + len(valid_batch_q) 105 | valid_batch = (valid_batch_q, data_help.std_batch) 106 | valid_prob = sess.run([dssm.prob_pre], feed_dict=feed_dict_builder(valid_batch, 1.0, dssm)) 107 | valid_acc_num, real_labels, _ = cal_predict_acc_num(valid_prob[0], valid_batch_q, 108 | data_help.id2label) 109 | all_valid_acc_num = all_valid_acc_num + valid_acc_num 110 | current_acc = all_valid_acc_num * 1.0 / all_valid_num 111 | tf.logging.info( 112 | "validing...... global_step {}, valid_acc {}, current_best_acc {}".format(step, current_acc, 113 | best_valid_acc)) 114 | if current_acc > best_valid_acc: 115 | tf.logging.info( 116 | "validing...... 
get the best acc {} and saving model and result".format(current_acc)) 117 | saver.save(sess, FLAGS.model_path + "dssm_{}".format(train_batch_step)) 118 | best_valid_acc = current_acc 119 | # save label2id, vocab2id 120 | vocabfile = open(FLAGS.vocab2id_path, 'w', encoding='utf-8') 121 | for key, value in data_help.vocab2id.items(): 122 | vocabfile.write(str(key) + "\t" + str(value) + '\n') 123 | vocabfile.close() 124 | labelfile = open(FLAGS.label2id_path, 'w', encoding='utf-8') 125 | for key, value in data_help.label2id.items(): 126 | labelfile.write(str(key) + "\t" + str(value) + '\n') 127 | labelfile.close() 128 | # break 129 | 130 | 131 | if __name__ == "__main__": 132 | tf.app.run() 133 | -------------------------------------------------------------------------------- /sptm/format_result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | format result to qa_match standard format 5 | USAGE: 6 | python format_result.py [test_file] [model_result] [standard_format_output] 7 | """ 8 | import codecs 9 | import sys 10 | 11 | if __name__ == "__main__": 12 | # real result: label\tidx\tsentence 13 | real_result_file = sys.argv[1] 14 | # model result: __label__1|score_1,... \tsentence 15 | model_result_file = sys.argv[2] 16 | result_file = sys.argv[3] 17 | 18 | real_labels = [] 19 | for line in codecs.open(real_result_file, encoding='utf-8'): 20 | if len(line.split('\t')) == 3: 21 | real_labels.append(line.split('\t')[0]) 22 | 23 | fout = codecs.open(result_file, encoding='utf-8', mode='w+') 24 | with codecs.open(model_result_file, encoding='utf-8') as f: 25 | for idx, line in enumerate(f): 26 | line = line.strip() 27 | s_line = line.split('\t') 28 | if len(s_line) >= 2: 29 | model_res = s_line[0] 30 | sentence = s_line[1] 31 | 32 | model_res = model_res.replace("__label__", "") \ 33 | .replace('|', ":").replace(",", " ") 34 | # real_label\tsentence\tmodel_labels 35 | fout.write("{}\t{}\t{}\n".format(real_labels[idx], 36 | sentence, model_res)) 37 | f.close() 38 | fout.close() 39 | -------------------------------------------------------------------------------- /sptm/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | models implementation for runing pretrain and finetune language models using tensroflow library 4 | 5 | """ 6 | import tensorflow as tf 7 | import collections 8 | import re 9 | import math 10 | 11 | 12 | class BiDirectionalLmModel(object): 13 | """Constructor for BiDirectionalLmModel 14 | Args: 15 | lstm_dim: int, The number of units in the LSTM cell. 16 | embedding_dim: int, The size of vocab embedding 17 | layer_num: int, The number of LSTM layer. 
18 | token_num: int, The number of Token 19 | input_arg: dict, Args of inputs 20 | """ 21 | def __init__(self, input_arg, other_arg_dict): 22 | self.lstm_dim = input_arg.lstm_dim 23 | self.embedding_dim = input_arg.embedding_dim 24 | self.layer_num = input_arg.layer_num 25 | self.token_num = other_arg_dict["token_num"] 26 | self.input_arg = input_arg 27 | if 2 * self.lstm_dim != self.embedding_dim: 28 | tf.logging.info('please set the 2 * lstm_dim == embedding_dim') 29 | assert False 30 | 31 | #Build graph for SPTM 32 | def build(self): 33 | assert self.input_arg.representation_type in ["lstm", "transformer"] 34 | self.ph_tokens = tf.placeholder(dtype=tf.int32, shape=[None, None], name="ph_tokens") # [batch_size, seq_length] 35 | self.ph_length = tf.placeholder(dtype=tf.int32, shape=[None], name="ph_length") # [batch_size] 36 | self.ph_dropout_rate = tf.placeholder(dtype=tf.float32, shape=None, name="ph_dropout_rate") 37 | self.ph_input_mask = tf.placeholder(dtype=tf.int32, shape=[None, None], name="ph_input_mask") #[batch_size, seq_length] 38 | 39 | self.v_token_embedding = tf.get_variable(name="v_token_embedding", 40 | shape=[self.token_num, self.embedding_dim], 41 | dtype=tf.float32, 42 | initializer=tf.contrib.layers.xavier_initializer()) #[token_num, embedding_dim] 43 | seq_embedding = tf.nn.embedding_lookup(self.v_token_embedding, 44 | self.ph_tokens) # [batch_size, seq_length, embedding_dim] 45 | 46 | if self.input_arg.representation_type == "lstm": 47 | tf.logging.info("representation using lstm ...........................") 48 | with tf.variable_scope(tf.get_variable_scope(), reuse=False): 49 | seq_embedding = tf.nn.dropout(seq_embedding, keep_prob=1 - self.ph_dropout_rate) 50 | last_output = seq_embedding 51 | cur_state = None 52 | for layer in range(1, self.layer_num + 1): 53 | fw_cell = tf.nn.rnn_cell.LSTMCell(self.lstm_dim, name="fw_layer_" + str(layer)) 54 | bw_cell = tf.nn.rnn_cell.LSTMCell(self.lstm_dim, name="bw_layer_" + str(layer)) 55 | cur_output, cur_state = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, last_output, self.ph_length, 56 | dtype=tf.float32) # [batch, length, dim] 57 | cur_output = tf.concat(cur_output, -1) # [batch, length, 2 * dim] 58 | cur_output = tf.nn.dropout(cur_output, keep_prob=1 - self.ph_dropout_rate) 59 | last_output = tf.contrib.layers.layer_norm(last_output + cur_output, begin_norm_axis=-1) # add and norm 60 | 61 | output = tf.layers.dense(last_output, self.embedding_dim, activation=tf.tanh, 62 | kernel_initializer=tf.contrib.layers.xavier_initializer()) # [batch, length, 2 * dim] 63 | output = tf.nn.dropout(output, keep_prob=1 - self.ph_dropout_rate) 64 | 65 | # sequence output 66 | self.output = tf.contrib.layers.layer_norm(last_output + output, 67 | begin_norm_axis=-1) # add and norm [batch, length, 2 * dim] 68 | 69 | # max pool output 70 | seq_len = tf.shape(self.ph_tokens)[1] 71 | mask = tf.expand_dims(tf.cast(tf.sequence_mask(self.ph_length, maxlen=seq_len), tf.float32), 72 | axis=2) # [batch, len, 1] 73 | mask = (1 - mask) * -1e5 74 | self.max_pool_output = tf.reduce_max(self.output + mask, axis=1, keepdims=False) # [batch, 2 * dim] 75 | elif self.input_arg.representation_type == "transformer": 76 | tf.logging.info("representation using transformer ...........................") 77 | input_shape = self.get_shape_list(seq_embedding) # [batch_size, seq_length, embedding_size] 78 | batch_size = input_shape[0] 79 | seq_length = input_shape[1] 80 | embedding_size = input_shape[2] 81 | 82 | with tf.variable_scope("pos_embeddings"): 83 | 
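                # Learned absolute position embeddings: a [max_position_embeddings,
                # embedding_size] table is created, sliced to the current sequence length,
                # broadcast-added to the token embeddings, and the sum is layer-normalized
                # before it enters the transformer encoder layers.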
all_position_embeddings = tf.get_variable( 84 | name="position_embeddings", 85 | shape=[self.input_arg.max_position_embeddings, embedding_size], 86 | initializer=tf.truncated_normal_initializer(stddev=self.input_arg.initializer_range)) # [max_position_embeddings, embedding_size] 87 | position_embeddings = tf.slice(all_position_embeddings, [0, 0], [seq_length, -1]) # [seq_length, embedding_size] 88 | position_embeddings = tf.reshape(position_embeddings, [1, seq_length, embedding_size]) # [1, seq_length, embedding_size] 89 | seq_embedding += position_embeddings # [batch_size, seq_length, embedding_size] 90 | seq_embedding = tf.contrib.layers.layer_norm(seq_embedding, begin_norm_axis=-1, begin_params_axis=-1) #[batch_size, seq_length, embedding_size] 91 | 92 | with tf.variable_scope("encoder"): 93 | assert self.input_arg.hidden_size % self.input_arg.num_attention_heads == 0 94 | attention_head_size = self.input_arg.hidden_size // self.input_arg.num_attention_heads 95 | self.all_layer_outputs = [] 96 | if embedding_size != self.input_arg.hidden_size: 97 | hidden_output = self.dense_layer_2d(seq_embedding, self.input_arg.hidden_size, None, name="embedding_to_hidden") 98 | else: 99 | hidden_output = seq_embedding # [batch_size, seq_length, hidden_size] 100 | 101 | with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE): 102 | for layer_id in range(self.input_arg.num_hidden_layers): 103 | with tf.name_scope("layer_%d" % layer_id): 104 | with tf.variable_scope("self_attention"): 105 | q = self.dense_layer_3d(hidden_output, self.input_arg.num_attention_heads, attention_head_size, None, "query") # [B, F, N, D] 106 | k = self.dense_layer_3d(hidden_output, self.input_arg.num_attention_heads, attention_head_size, None, "key") # [B, F, N, D] 107 | v = self.dense_layer_3d(hidden_output, self.input_arg.num_attention_heads, attention_head_size, None, "value") # [B, F, N, D] 108 | q = tf.transpose(q, [0, 2, 1, 3]) # [B, N, F, D] 109 | k = tf.transpose(k, [0, 2, 1, 3]) # [B, N, F, D] 110 | v = tf.transpose(v, [0, 2, 1, 3]) # [B, N, F, D] 111 | attention_mask = tf.reshape(self.ph_input_mask, [batch_size, 1, seq_length, 1]) # [B, 1, F, 1] 112 | logits = tf.matmul(q, k, transpose_b=True) # q*k => [B, N, F, F] 113 | logits = tf.multiply(logits, 1.0 / math.sqrt(float(self.get_shape_list(q)[-1]))) # q*k/sqrt(Dk) => [B, N, F, F] 114 | from_shape = self.get_shape_list(q) # [B, N, F, D] 115 | broadcast_ones = tf.ones([from_shape[0], 1, from_shape[2], 1], tf.float32) # [B, 1, F, 1] 116 | attention_mask = tf.matmul(broadcast_ones, tf.cast(attention_mask, tf.float32), transpose_b=True) # [B, 1, F, 1] * [B, 1, F, 1] => [B, 1, F, F] 117 | adder = (1.0 - attention_mask) * -10000.0 # [B, 1, F, F] 118 | logits += adder # [B, N, F, F] 119 | attention_probs = tf.nn.softmax(logits, name="attention_probs") # softmax(q*k/sqrt(Dk)), [B, N, F, F] 120 | attention_output = tf.matmul(attention_probs, v) # softmax(q*k/sqrt(Dk))*v , [B, N, F, F] * [B, N, F, D] => [B, N, F, D] 121 | attention_output = tf.transpose(attention_output, [0, 2, 1, 3]) #[B, F, N, D] 122 | attention_output = self.dense_layer_3d_proj(attention_output, self.input_arg.hidden_size, attention_head_size, None, name="dense") # [B, F, H] 123 | attention_output = tf.contrib.layers.layer_norm(inputs=attention_output + hidden_output, begin_norm_axis=-1, begin_params_axis=-1) # [B, F, H] 124 | 125 | with tf.variable_scope("ffn"): 126 | intermediate_output = self.dense_layer_2d(attention_output, self.input_arg.intermediate_size, tf.nn.relu, name="dense") # [B, F, 
intermediate_size] 127 | hidden_output = self.dense_layer_2d(intermediate_output, self.input_arg.hidden_size, None, name="output_dense") # [B, F, hidden_size] 128 | hidden_output = tf.contrib.layers.layer_norm(inputs=hidden_output + attention_output, begin_norm_axis=-1, begin_params_axis=-1) # [B, F, H] 129 | layer_output = self.dense_layer_2d(hidden_output, embedding_size, None, name="layer_output_dense") # [B, F, embedding_size] 130 | self.all_layer_outputs.append(layer_output) 131 | self.output = self.all_layer_outputs[-1] # [B, F, embedding_size] 132 | # max pool output 133 | mask = tf.expand_dims(tf.cast(tf.sequence_mask(self.ph_length, maxlen=seq_length), tf.float32), axis=2) # [B, F, 1] 134 | mask = (1 - mask) * -1e5 135 | self.max_pool_output = tf.reduce_max(self.output + mask, axis=1, keepdims=False) # [B, embedding_size] 136 | 137 | #make Matrix from 4D to 3D 138 | def dense_layer_3d_proj(self, input_tensor, hidden_size, head_size, activation, name=None): 139 | input_shape = self.get_shape_list(input_tensor) # [B,F,N,D] 140 | num_attention_heads = input_shape[2] 141 | with tf.variable_scope(name): 142 | w = tf.get_variable(name="kernel", shape=[num_attention_heads * head_size, hidden_size], initializer=tf.truncated_normal_initializer(stddev=self.input_arg.initializer_range)) 143 | w = tf.reshape(w, [num_attention_heads, head_size, hidden_size]) 144 | b = tf.get_variable(name="bias", shape=[hidden_size], initializer=tf.zeros_initializer) 145 | output = tf.einsum("BFND,NDH->BFH", input_tensor, w) # [B, F, H] 146 | output += b 147 | if activation is not None: 148 | return activation(output) 149 | else: 150 | return output 151 | 152 | #make Matrix for 3D transformation in the last index 153 | def dense_layer_2d(self, input_tensor, output_size, activation, name=None): 154 | input_shape = self.get_shape_list(input_tensor) # [B, F, H] 155 | hidden_size = input_shape[2] 156 | with tf.variable_scope(name): 157 | w = tf.get_variable(name="kernel", shape=[hidden_size, output_size], initializer=tf.truncated_normal_initializer(stddev=self.input_arg.initializer_range)) 158 | b = tf.get_variable(name="bias", shape=[output_size], initializer=tf.zeros_initializer) 159 | output = tf.einsum("BFH,HO->BFO", input_tensor, w) # [B, F, O] 160 | output += b 161 | if activation is not None: 162 | return activation(output) 163 | else: 164 | return output 165 | 166 | # make Matrix from 3D to 4D 167 | def dense_layer_3d(self, input_tensor, num_attention_heads, head_size, activation, name=None): 168 | input_shape = self.get_shape_list(input_tensor) # [B, F, H] 169 | hidden_size = input_shape[2] 170 | with tf.variable_scope(name): 171 | w = tf.get_variable(name="kernel", shape=[hidden_size, num_attention_heads * head_size], initializer=tf.truncated_normal_initializer(stddev=self.input_arg.initializer_range)) 172 | w = tf.reshape(w, [hidden_size, num_attention_heads, head_size]) 173 | b = tf.get_variable(name="bias", shape=[num_attention_heads * head_size], initializer=tf.zeros_initializer) 174 | b = tf.reshape(b, [num_attention_heads, head_size]) 175 | output = tf.einsum("BFH,HND->BFND", input_tensor, w) #[B, F, N, D] 176 | output += b 177 | if activation is not None: 178 | return activation(output) 179 | else: 180 | return output 181 | 182 | def get_shape_list(self, tensor): 183 | """Returns a list of the shape of tensor, preferring static dimensions. 
184 | """ 185 | tensor_shape = tensor.shape.as_list() 186 | none_indexes = [] 187 | for (index, dim) in enumerate(tensor_shape): 188 | if dim is None: 189 | none_indexes.append(index) 190 | if not none_indexes: 191 | return tensor_shape 192 | dynamic_shape = tf.shape(tensor) 193 | for index in none_indexes: 194 | tensor_shape[index] = dynamic_shape[index] 195 | return tensor_shape 196 | 197 | 198 | 199 | def create_bidirectional_lm_training_op(input_arg, other_arg_dict): 200 | loss_op, model = create_bidirectional_lm_model(input_arg, other_arg_dict) 201 | train_op, learning_rate_op = create_optimizer(loss_op, input_arg.learning_rate, input_arg.train_step, 202 | input_arg.warmup_step, input_arg.clip_norm, input_arg.weight_decay) 203 | model.loss_op = loss_op 204 | model.train_op = train_op 205 | model.learning_rate_op = learning_rate_op 206 | return model 207 | 208 | 209 | def create_bidirectional_lm_model(input_arg, other_arg_dict): 210 | model = BiDirectionalLmModel(input_arg, other_arg_dict) 211 | model.build() 212 | max_predictions_per_seq = input_arg.max_predictions_per_seq 213 | 214 | model.global_step = tf.train.get_or_create_global_step() 215 | model.ph_labels = tf.placeholder(dtype=tf.int32, shape=[None, max_predictions_per_seq], 216 | name="ph_labels") # [batch, max_predictions_per_seq] 217 | model.ph_positions = tf.placeholder(dtype=tf.int32, shape=[None, max_predictions_per_seq], 218 | name="ph_positions") # [batch, max_predictions_per_seq] 219 | model.ph_weights = tf.placeholder(dtype=tf.float32, shape=[None, max_predictions_per_seq], 220 | name="ph_weights") # [batch, max_predictions_per_seq] 221 | 222 | real_output = gather_indexes(model.output, model.ph_positions) # [batch * max_predictions_per_seq, embedding_dim] 223 | 224 | bias = tf.get_variable("bias", shape=[model.token_num], initializer=tf.zeros_initializer()) 225 | logits = tf.matmul(real_output, model.v_token_embedding, transpose_b=True) #[batch * max_predictions_per_seq, token_num] 226 | logits = tf.nn.bias_add(logits, bias) 227 | 228 | log_probs = tf.nn.log_softmax(logits, axis=-1) #[batch * max_predictions_per_seq, token_num] 229 | one_hot_labels = tf.one_hot(tf.reshape(model.ph_labels, [-1]), depth=model.token_num, dtype=tf.float32) #[batch * max_predictions_per_seq, token_num] 230 | 231 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch * max_predictions_per_seq] 232 | weights = tf.reshape(model.ph_weights, [-1]) # [batch * max_predictions_per_seq] 233 | loss = (tf.reduce_sum(weights * per_example_loss)) / (tf.reduce_sum(weights) + 1e-5) 234 | 235 | return loss, model 236 | 237 | 238 | def create_optimizer(loss, init_lr=5e-5, num_train_steps=1000000, num_warmup_steps=20000, clip_nom=1.0, 239 | weight_decay=0.01): 240 | global_step = tf.train.get_or_create_global_step() 241 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 242 | learning_rate = tf.train.polynomial_decay(learning_rate, global_step, num_train_steps, end_learning_rate=0.0, 243 | power=1.0, cycle=False) # linear warmup 244 | 245 | if num_warmup_steps: 246 | global_steps_int = tf.cast(global_step, tf.int32) 247 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 248 | warmup_learning_rate = init_lr * tf.cast(global_steps_int, tf.float32) / tf.cast(warmup_steps_int, tf.float32) 249 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 250 | learning_rate = ((1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 251 | 252 | optimizer = 
tf.contrib.opt.AdamWOptimizer(weight_decay=weight_decay, learning_rate=learning_rate) 253 | tvars = tf.trainable_variables() 254 | grads = tf.gradients(loss, tvars) 255 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_nom) 256 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step) 257 | return train_op, learning_rate 258 | 259 | 260 | def gather_indexes(seq_output, positions): 261 | #seq_output:[batch, length, 2 * dim] 262 | #positions:[batch, max_predictions_per_seq] 263 | 264 | batch_size = tf.shape(seq_output)[0] 265 | length = tf.shape(seq_output)[1] 266 | dim = tf.shape(seq_output)[2] 267 | 268 | flat_offsets = tf.reshape(tf.range(0, batch_size, dtype=tf.int32) * length, [-1, 1]) #[batch_size, 1] 269 | output_tensor = tf.gather(tf.reshape(seq_output, [-1, dim]), tf.reshape(positions + flat_offsets, [-1])) 270 | return output_tensor # [batch * max_predictions_per_seq, dim] 271 | 272 | 273 | # finetune model utils 274 | def create_finetune_classification_training_op(input_arg, other_arg_dict): 275 | model = create_finetune_classification_model(input_arg, other_arg_dict) 276 | repre = model.max_pool_output # [batch, 2 * dim] 277 | 278 | model.ph_labels = tf.placeholder(dtype=tf.int32, shape=[None], name="ph_labels") # [batch] 279 | logits = tf.layers.dense(repre, other_arg_dict["label_num"], 280 | kernel_initializer=tf.contrib.layers.xavier_initializer(), name="logits") 281 | model.softmax_op = tf.nn.softmax(logits, -1, name="softmax_pre") 282 | model.loss_op = tf.reduce_mean( 283 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=model.ph_labels), -1) 284 | model.global_step_op = tf.train.get_or_create_global_step() 285 | 286 | tf.logging.info("learning_rate : {}".format(input_arg.learning_rate)) 287 | if input_arg.opt_type == "sgd": 288 | tf.logging.info("use sgd") 289 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=input_arg.learning_rate) 290 | elif input_arg.opt_type == "adagrad": 291 | tf.logging.info("use adagrad") 292 | optimizer = tf.train.AdagradOptimizer(learning_rate=input_arg.learning_rate) 293 | elif input_arg.opt_type == "adam": 294 | tf.logging.info("use adam") 295 | optimizer = tf.train.AdamOptimizer(learning_rate=input_arg.learning_rate) 296 | else: 297 | assert False 298 | 299 | list_g_v_pair = optimizer.compute_gradients(model.loss_op) 300 | model.train_op = optimizer.apply_gradients(list_g_v_pair, global_step=model.global_step_op) 301 | 302 | return model 303 | 304 | #create finetune model for classification 305 | def create_finetune_classification_model(input_arg, other_arg_dict): 306 | model = BiDirectionalLmModel(input_arg, other_arg_dict) 307 | 308 | model.build() 309 | 310 | tvars = tf.trainable_variables() 311 | initialized_variable_names = {} 312 | if input_arg.init_checkpoint: 313 | tf.logging.info("init from checkpoint!") 314 | assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, 315 | input_arg.init_checkpoint) 316 | tf.train.init_from_checkpoint(input_arg.init_checkpoint, assignment_map) 317 | 318 | tf.logging.info("**** Trainable Variables ****") 319 | for var in tvars: 320 | init_string = "" 321 | if var.name in initialized_variable_names: 322 | init_string = ", *INIT_FROM_CKPT*" 323 | tf.logging.info("name = {}, shape = {}{}".format(var.name, var.shape, init_string)) 324 | 325 | return model 326 | 327 | 328 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 329 | name_to_variable = collections.OrderedDict() # trainable variables 330 
| for var in tvars: 331 | name = var.name 332 | m = re.match("^(.*):\\d+$", name) 333 | if m is not None: 334 | name = m.group(1) 335 | name_to_variable[name] = var 336 | 337 | assignment_map = collections.OrderedDict() 338 | initialized_variable_names = {} 339 | init_vars = tf.train.list_variables(init_checkpoint) # variables in checkpoint 340 | for x in init_vars: 341 | (name, var) = (x[0], x[1]) 342 | if name not in name_to_variable: 343 | continue 344 | assignment_map[name] = name 345 | initialized_variable_names[name] = 1 346 | initialized_variable_names[name + ":0"] = 1 347 | 348 | return assignment_map, initialized_variable_names 349 | -------------------------------------------------------------------------------- /sptm/run_classifier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | finetune on pretrained model with trainset and devset 4 | """ 5 | 6 | import sys 7 | import os 8 | import tensorflow as tf 9 | import numpy as np 10 | import argparse 11 | import models 12 | import utils 13 | 14 | 15 | def evaluate(sess, full_tensors, args, model): 16 | total_num = 0 17 | right_num = 0 18 | for batch_data in utils.gen_batchs(full_tensors, args.batch_size, is_shuffle=False): 19 | softmax_re = sess.run(model.softmax_op, 20 | feed_dict={model.ph_dropout_rate: 0, 21 | model.ph_tokens: batch_data[0], 22 | model.ph_labels: batch_data[1], 23 | model.ph_length: batch_data[2], 24 | model.ph_input_mask: batch_data[3]}) 25 | pred_re = np.argmax(softmax_re, axis=1) 26 | total_num += len(pred_re) 27 | right_num += np.sum(pred_re == batch_data[1]) 28 | acc = 1.0 * right_num / (total_num + 1e-5) 29 | 30 | tf.logging.info("dev total num: " + str(total_num) + ", right num: " + str(right_num) + ", acc: " + str(acc)) 31 | return acc 32 | 33 | 34 | def main(_): 35 | tf.logging.set_verbosity(tf.logging.INFO) 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--train_file", type=str, default="", help="Input train file.") 38 | parser.add_argument("--dev_file", type=str, default="", help="Input dev file.") 39 | parser.add_argument("--vocab_file", type=str, default="", help="Input vocab file.") 40 | parser.add_argument("--output_id2label_file", type=str, default="./id2label", 41 | help="File containing (id, class label) map.") 42 | parser.add_argument("--model_save_dir", type=str, default="", 43 | help="Specified the directory in which the model should stored.") 44 | parser.add_argument("--lstm_dim", type=int, default=500, help="Dimension of LSTM cell.") 45 | parser.add_argument("--embedding_dim", type=int, default=1000, help="Dimension of word embedding.") 46 | parser.add_argument("--opt_type", type=str, default='adam', help="Type of optimizer.") 47 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size.") 48 | parser.add_argument("--epoch", type=int, default=20, help="Epoch.") 49 | parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") 50 | parser.add_argument("--dropout_rate", type=float, default=0.1, help="Dropout rate") 51 | parser.add_argument("--seed", type=int, default=1, help="Random seed value.") 52 | parser.add_argument("--print_step", type=int, default=1000, help="Print log every x step.") 53 | parser.add_argument("--init_checkpoint", type=str, default='', 54 | help="Initial checkpoint (usually from a pre-trained model).") 55 | parser.add_argument("--max_len", type=int, default=100, help="Max seqence length.") 56 | parser.add_argument("--layer_num", 
type=int, default=2, help="LSTM layer num.") 57 | 58 | parser.add_argument("--representation_type", type=str, default="lstm", 59 | help="representation type include:lstm, transformer") 60 | 61 | # transformer args 62 | parser.add_argument("--initializer_range", type=float, default="0.02", help="Embedding initialization range") 63 | parser.add_argument("--max_position_embeddings", type=int, default=512, help="max position num") 64 | parser.add_argument("--hidden_size", type=int, default=768, help="hidden size") 65 | parser.add_argument("--num_hidden_layers", type=int, default=12, help="num hidden layer") 66 | parser.add_argument("--num_attention_heads", type=int, default=12, help="num attention heads") 67 | parser.add_argument("--intermediate_size", type=int, default=3072, help="intermediate_size") 68 | 69 | args = parser.parse_args() 70 | 71 | np.random.seed(args.seed) 72 | tf.set_random_seed(args.seed) 73 | tf.logging.info(str(args)) 74 | if not os.path.exists(args.model_save_dir): 75 | os.mkdir(args.model_save_dir) 76 | 77 | tf.logging.info("load training sens") 78 | train_sens = utils.load_training_data(args.train_file, skip_invalid=True) 79 | tf.logging.info("\nload dev sens") 80 | dev_sens = utils.load_training_data(args.dev_file, skip_invalid=True) 81 | 82 | word2id, id2word, label2id, id2label = utils.load_vocab(train_sens + dev_sens, args.vocab_file) 83 | fw = open(args.output_id2label_file, 'w+') 84 | for k, v in id2label.items(): 85 | fw.write(str(k) + "\t" + v + "\n") 86 | fw.close() 87 | 88 | utils.gen_ids(train_sens, word2id, label2id, args.max_len) 89 | utils.gen_ids(dev_sens, word2id, label2id, args.max_len) 90 | 91 | train_full_tensors = utils.make_full_tensors(train_sens) 92 | dev_full_tensors = utils.make_full_tensors(dev_sens) 93 | 94 | other_arg_dict = {} 95 | other_arg_dict['token_num'] = len(word2id) 96 | other_arg_dict['label_num'] = len(label2id) 97 | model = models.create_finetune_classification_training_op(args, other_arg_dict) 98 | 99 | steps_in_epoch = int(len(train_sens) // args.batch_size) 100 | tf.logging.info("batch size: " + str(args.batch_size) + ", training sample num : " + str( 101 | len(train_sens)) + ", print step : " + str(args.print_step)) 102 | tf.logging.info( 103 | "steps_in_epoch : " + str(steps_in_epoch) + ", epoch num :" + str(args.epoch) + ", total steps : " + str( 104 | args.epoch * steps_in_epoch)) 105 | print_step = min(args.print_step, steps_in_epoch) 106 | tf.logging.info("eval dev every {} step".format(print_step)) 107 | 108 | save_vars = [v for v in tf.global_variables() if 109 | v.name.find('adam') < 0 and v.name.find('Adam') < 0 and v.name.find('ADAM') < 0] 110 | tf.logging.info(str(save_vars)) 111 | tf.logging.info(str(tf.all_variables())) 112 | 113 | saver = tf.train.Saver(max_to_keep=2) 114 | config = tf.ConfigProto(allow_soft_placement=True) 115 | config.gpu_options.allow_growth = True 116 | with tf.Session(config=config) as sess: 117 | sess.run(tf.global_variables_initializer()) 118 | total_loss = 0 119 | dev_best_so_far = 0 120 | for epoch in range(1, args.epoch + 1): 121 | tf.logging.info("\n" + "*" * 20 + "epoch num :" + str(epoch) + "*" * 20) 122 | for batch_data in utils.gen_batchs(train_full_tensors, args.batch_size, is_shuffle=True): 123 | _, global_step, loss = sess.run([model.train_op, model.global_step_op, model.loss_op], 124 | feed_dict={model.ph_dropout_rate: args.dropout_rate, 125 | model.ph_tokens: batch_data[0], 126 | model.ph_labels: batch_data[1], 127 | model.ph_length: batch_data[2], 128 | model.ph_input_mask: 
batch_data[3]}) 129 | total_loss += loss 130 | if global_step % print_step == 0: 131 | tf.logging.info( 132 | "\nglobal step : " + str(global_step) + ", avg loss so far : " + str(total_loss / global_step)) 133 | tf.logging.info("begin to eval dev set: ") 134 | acc = evaluate(sess, dev_full_tensors, args, model) 135 | if acc > dev_best_so_far: 136 | dev_best_so_far = acc 137 | tf.logging.info("!" * 20 + "best got : " + str(acc)) 138 | # constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["scores"]) 139 | saver.save(sess, args.model_save_dir + '/finetune.ckpt', global_step=global_step) 140 | 141 | tf.logging.info("\n----------------------eval after one epoch: ") 142 | tf.logging.info( 143 | "global step : " + str(global_step) + ", avg loss so far : " + str(total_loss / global_step)) 144 | tf.logging.info("begin to eval dev set: ") 145 | sys.stdout.flush() 146 | acc = evaluate(sess, dev_full_tensors, args, model) 147 | if acc > dev_best_so_far: 148 | dev_best_so_far = acc 149 | tf.logging.info("!" * 20 + "best got : " + str(acc)) 150 | saver.save(sess, args.model_save_dir + '/finetune.ckpt', global_step=global_step) 151 | 152 | 153 | if __name__ == "__main__": 154 | tf.app.run() 155 | -------------------------------------------------------------------------------- /sptm/run_prediction.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | predict with finetuned model with testset 4 | """ 5 | 6 | import sys 7 | import tensorflow as tf 8 | import numpy as np 9 | import argparse 10 | import utils 11 | 12 | 13 | def get_output(g): 14 | return {"softmax": g.get_tensor_by_name("softmax_pre:0")} 15 | 16 | 17 | def get_input(g): 18 | return {"tokens": g.get_tensor_by_name("ph_tokens:0"), 19 | "length": g.get_tensor_by_name("ph_length:0"), 20 | "dropout_rate": g.get_tensor_by_name("ph_dropout_rate:0"), 21 | "input_mask": g.get_tensor_by_name("ph_input_mask:0")} 22 | 23 | 24 | def main(_): 25 | tf.logging.set_verbosity(tf.logging.INFO) 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--input_file", type=str, default="", help="Input file for prediction.") 28 | parser.add_argument("--vocab_file", type=str, default="", help="Input train file.") 29 | parser.add_argument("--model_path", type=str, default="", help="Path to model file.") 30 | parser.add_argument("--model_dir", type=str, default="", help="Directory which contains model.") 31 | parser.add_argument("--id2label_file", type=str, default="./id2label", 32 | help="File containing (id, class label) map.") 33 | args = parser.parse_args() 34 | 35 | word2id, id2word = utils.load_vocab_file(args.vocab_file) 36 | sys.stderr.write("vocab num : " + str(len(word2id)) + "\n") 37 | 38 | sens = utils.gen_test_data(args.input_file, word2id) 39 | sys.stderr.write("sens num : " + str(len(sens)) + "\n") 40 | 41 | id2label = utils.load_id2label_file(args.id2label_file) 42 | sys.stderr.write('label num : ' + str(len(id2label)) + "\n") 43 | 44 | # use latest checkpoint 45 | if "" == args.model_path: 46 | args.model_path = tf.train.latest_checkpoint(checkpoint_dir=args.model_dir) 47 | 48 | config = tf.ConfigProto() 49 | config.gpu_options.allow_growth = True 50 | with tf.Session(config=config) as sess: 51 | saver = tf.train.import_meta_graph("{}.meta".format(args.model_path)) 52 | saver.restore(sess, args.model_path) 53 | 54 | graph = tf.get_default_graph() 55 | input_dict = get_input(graph) 56 | output_dict = get_output(graph) 57 | 58 | for sen in sens: 
59 | re = sess.run(output_dict['softmax'], feed_dict={input_dict['tokens']: [sen[0]], 60 | input_dict['input_mask']: [sen[1]], 61 | input_dict['length']: [len(sen[0])], 62 | input_dict["dropout_rate"]: 0.0}) 63 | sorted_idx = np.argsort(-1 * re[0]) # sort by desc 64 | s = "" 65 | for i in sorted_idx[:3]: 66 | s += id2label[i] + "|" + str(re[0][i]) + "," 67 | print(s + "\t" + " ".join([id2word[t] for t in sen[0]])) 68 | 69 | 70 | if __name__ == "__main__": 71 | tf.app.run() 72 | -------------------------------------------------------------------------------- /sptm/run_pretraining.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | pretrain a specified language model(modified bi-lstm as default) 4 | """ 5 | 6 | from __future__ import print_function 7 | 8 | import os 9 | import tensorflow as tf 10 | import numpy as np 11 | import argparse 12 | import models 13 | import utils 14 | import gc 15 | import time 16 | 17 | 18 | def main(_): 19 | tf.logging.set_verbosity(tf.logging.INFO) 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--train_file", type=str, default="", help="Input train file.") 22 | parser.add_argument("--vocab_file", default="", help="Input vocab file.") 23 | parser.add_argument("--model_save_dir", type=str, default="", 24 | help="Specified the directory in which the model should stored.") 25 | parser.add_argument("--lstm_dim", type=int, default=500, help="Dimension of LSTM cell.") 26 | parser.add_argument("--embedding_dim", type=int, default=1000, help="Dimension of word embedding.") 27 | parser.add_argument("--layer_num", type=int, default=2, help="LSTM layer num.") 28 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size.") 29 | parser.add_argument("--train_step", type=int, default=10000, help="Number of training steps.") 30 | parser.add_argument("--warmup_step", type=int, default=1000, help="Number of warmup steps.") 31 | parser.add_argument("--learning_rate", type=float, default=0.001, help="The initial learning rate") 32 | parser.add_argument("--dropout_rate", type=float, default=0.5, help="Dropout rate") 33 | parser.add_argument("--seed", type=int, default=0, help="Random seed value.") 34 | parser.add_argument("--print_step", type=int, default=1000, help="Print log every x step.") 35 | parser.add_argument("--max_predictions_per_seq", type=int, default=10, 36 | help="For each sequence, predict x words at most.") 37 | parser.add_argument("--weight_decay", type=float, default=0, help="Weight decay rate") 38 | parser.add_argument("--clip_norm", type=float, default=1, help='Clip normalization rate.') 39 | parser.add_argument("--max_seq_len", type=int, default=100, help="Max seqence length.") 40 | parser.add_argument("--use_queue", type=int, default=0, help="Whether or not using a queue for input.") 41 | parser.add_argument("--init_checkpoint", type=str, default="", help="Initial checkpoint") 42 | parser.add_argument("--enqueue_thread_num", type=int, default=5, help="Enqueue thread count.") 43 | parser.add_argument("--representation_type", type=str, default="lstm", help="representation type include:lstm, transformer") 44 | 45 | #transformer args 46 | parser.add_argument("--initializer_range", type=float, default="0.02", help="Embedding initialization range") 47 | parser.add_argument("--max_position_embeddings", type=int, default=512, help="max position num") 48 | parser.add_argument("--hidden_size", type=int, default=768, help="hidden size") 49 | 
parser.add_argument("--num_hidden_layers", type=int, default=12, help ="num hidden layer") 50 | parser.add_argument("--num_attention_heads", type=int, default=12, help="num attention heads") 51 | parser.add_argument("--intermediate_size", type=int, default=3072, help="intermediate_size") 52 | 53 | args = parser.parse_args() 54 | 55 | np.random.seed(args.seed) 56 | tf.set_random_seed(args.seed) 57 | tf.logging.info(args) 58 | if not os.path.exists(args.model_save_dir): 59 | os.mkdir(args.model_save_dir) 60 | 61 | # load data 62 | word2id, id2word = utils.load_vocab_file(args.vocab_file) 63 | training_sens = utils.load_pretraining_data(args.train_file, args.max_seq_len) 64 | 65 | if not args.use_queue: 66 | utils.to_ids(training_sens, word2id, args, id2word) 67 | 68 | other_arg_dict = {} 69 | other_arg_dict['token_num'] = len(word2id) 70 | 71 | # load model 72 | model = models.create_bidirectional_lm_training_op(args, other_arg_dict) 73 | 74 | gc.collect() 75 | saver = tf.train.Saver(max_to_keep=2) 76 | config = tf.ConfigProto(allow_soft_placement=True) 77 | config.gpu_options.allow_growth = True 78 | with tf.Session(config=config) as sess: 79 | sess.run(tf.global_variables_initializer()) 80 | 81 | if args.init_checkpoint: 82 | tf.logging.info('restore the checkpoint : ' + str(args.init_checkpoint)) 83 | saver.restore(sess, args.init_checkpoint) 84 | 85 | total_loss = 0 86 | num = 0 87 | global_step = 0 88 | while global_step < args.train_step: 89 | if not args.use_queue: 90 | iterator = utils.gen_batches(training_sens, args.batch_size) 91 | else: 92 | iterator = utils.queue_gen_batches(training_sens, args, word2id, id2word) 93 | assert iterator is not None 94 | for batch_data in iterator: 95 | feed_dict = {model.ph_tokens: batch_data[0], 96 | model.ph_length: batch_data[1], 97 | model.ph_labels: batch_data[2], 98 | model.ph_positions: batch_data[3], 99 | model.ph_weights: batch_data[4], 100 | model.ph_input_mask: batch_data[5], 101 | model.ph_dropout_rate: args.dropout_rate} 102 | _, global_step, loss, learning_rate = sess.run([model.train_op, \ 103 | model.global_step, model.loss_op, 104 | model.learning_rate_op], feed_dict=feed_dict) 105 | 106 | total_loss += loss 107 | num += 1 108 | if global_step % args.print_step == 0: 109 | tf.logging.info("\nglobal step : " + str(global_step) + 110 | ", avg loss so far : " + str(total_loss / num) + 111 | ", instant loss : " + str(loss) + 112 | ", learning_rate : " + str(learning_rate) + 113 | ", time :" + str(time.strftime('%Y-%m-%d %H:%M:%S'))) 114 | tf.logging.info("save model ...") 115 | saver.save(sess, args.model_save_dir + '/lm_pretrain.ckpt', global_step=global_step) 116 | gc.collect() 117 | 118 | if not args.use_queue: 119 | utils.to_ids(training_sens, word2id, args, id2word) # MUST run this for randomization for each sentence 120 | gc.collect() 121 | 122 | 123 | if __name__ == "__main__": 124 | tf.app.run() 125 | -------------------------------------------------------------------------------- /sptm/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | tools for runing pretrain and finetune language models 4 | 5 | """ 6 | import sys 7 | import numpy as np 8 | import codecs 9 | from collections import namedtuple 10 | import queue 11 | import threading 12 | import tensorflow as tf 13 | 14 | 15 | # sample representation 16 | class Sentence(object): 17 | def __init__(self, raw_tokens, raw_label=None): 18 | self.raw_tokens = raw_tokens 19 | self.raw_label = raw_label 20 
| self.label_id = None 21 | self.token_ids = [] 22 | 23 | # for pretrain 24 | def to_id(self, word2id, args): 25 | # for each epoch, this should be rerun to get random results for each sentence. 26 | self.positions = [] 27 | self.labels = [] 28 | self.weights = [] 29 | self.fw_labels = [] 30 | self.bw_labels = [] 31 | self.token_ids = [] 32 | self.input_masks = [] 33 | 34 | for t in self.raw_tokens: 35 | self.token_ids.append(word2id[t]) 36 | self.input_masks.append(1) 37 | 38 | for ta in self.bidirectional_targets: 39 | # predict itself 40 | self.labels.append(self.token_ids[ta.position]) 41 | # in-place modify the target token in the sentence 42 | self.token_ids[ta.position] = word2id[ta.replace_token] 43 | self.positions.append(ta.position) 44 | self.weights.append(1.0) 45 | 46 | # fix to tensors for the predictions of LM 47 | cur_len = len(self.labels) 48 | self.labels = self.labels + [0] * (args.max_predictions_per_seq - cur_len) 49 | self.positions = self.positions + [0] * (args.max_predictions_per_seq - cur_len) 50 | self.weights = self.weights + [0] * (args.max_predictions_per_seq - cur_len) 51 | 52 | # for finetune 53 | def to_ids(self, word2id, label2id, max_len): 54 | self.label_id = label2id[self.raw_label] 55 | self.raw_tokens = self.raw_tokens[:max_len] # cut off to the max length 56 | self.input_masks = [] 57 | all_unk = True 58 | for t in self.raw_tokens: 59 | if t in word2id: 60 | self.token_ids.append(word2id[t]) 61 | all_unk = False 62 | else: 63 | self.token_ids.append(word2id[""]) 64 | self.input_masks.append(1) 65 | assert not all_unk 66 | 67 | self.token_ids = self.token_ids + [0] * (max_len - len(self.token_ids)) 68 | self.input_masks = self.input_masks + [0] * (max_len - len(self.input_masks)) 69 | 70 | # file utils 71 | 72 | 73 | def load_vocab_file(vocab_file): 74 | word2id = {} 75 | id2word = {} 76 | for l in codecs.open(vocab_file, 'r', 'utf-8'): 77 | l = l.strip() 78 | assert l != "" 79 | assert l not in word2id 80 | word2id[l] = len(word2id) 81 | id2word[len(id2word)] = l 82 | sys.stderr.write("uniq token num : " + str(len(word2id)) + "\n") 83 | return word2id, id2word 84 | 85 | 86 | def load_vocab(sens, vocab_file): 87 | label2id = {} 88 | id2label = {} 89 | for sen in sens: 90 | if sen.raw_label not in label2id: 91 | label2id[sen.raw_label] = len(label2id) 92 | id2label[len(id2label)] = sen.raw_label 93 | 94 | word2id, id2word = load_vocab_file(vocab_file) 95 | assert len(word2id) == len(id2word) 96 | tf.logging.info("\ntoken num : " + str(len(word2id))) 97 | tf.logging.info(", label num : " + str(len(label2id))) 98 | tf.logging.info(", labels: " + str(id2label)) 99 | return word2id, id2word, label2id, id2label 100 | 101 | 102 | def load_id2label_file(id2label_file): 103 | di = {} 104 | for l in open(id2label_file, 'r'): 105 | fs = l.rstrip().split('\t') 106 | assert len(fs) == 2 107 | di[int(fs[0])] = fs[1] 108 | return di 109 | 110 | 111 | def gen_test_data(test_file, word2id): 112 | """ 113 | read and encode test file. 
114 | """ 115 | sens = [] 116 | for l in codecs.open(test_file, 'r', 'utf-8'): 117 | fs = l.rstrip().split('\t')[-1].split() 118 | sen = [] 119 | mask = [] 120 | for f in fs: 121 | if f in word2id: 122 | sen.append(word2id[f]) 123 | else: 124 | sen.append(word2id['']) 125 | mask.append(1) 126 | sens.append((sen, mask)) 127 | return sens 128 | 129 | 130 | def load_pretraining_data(train_file, max_seq_len): 131 | sens = [] 132 | for l in codecs.open(train_file, 'r', 'utf-8'): 133 | sen = Sentence(l.rstrip().split("\t")[-1].split()[:max_seq_len]) 134 | if len(sen.raw_tokens) == 0: 135 | continue 136 | sens.append(sen) 137 | if len(sens) % 2000000 == 0: 138 | tf.logging.info("load sens :" + str(len(sens))) 139 | tf.logging.info("training sens num :" + str(len(sens))) 140 | return sens 141 | 142 | 143 | def load_training_data(file_path, skip_invalid=True): 144 | sens = [] 145 | invalid_num = 0 146 | max_len = 0 147 | for l in codecs.open(file_path, 'r', 'utf-8'): # load as utf-8 encoding. 148 | if l.strip() == "": 149 | continue 150 | fs = l.rstrip().split('\t') 151 | assert len(fs) == 3 152 | tokens = fs[2].split() # discard empty strings 153 | for t in tokens: 154 | assert t != "" 155 | label = "__label__{}".format(fs[0]) 156 | if skip_invalid: 157 | if label.find(',') >= 0 or label.find('NONE') >= 0: 158 | invalid_num += 1 159 | continue 160 | if len(tokens) > max_len: 161 | max_len = len(tokens) 162 | sens.append(Sentence(tokens, label)) 163 | tf.logging.info("invalid sen num : " + str(invalid_num)) 164 | tf.logging.info("valid sen num : " + str(len(sens))) 165 | tf.logging.info("max_len : " + str(max_len)) 166 | return sens 167 | 168 | 169 | # pretrain utils 170 | BiReplacement = namedtuple("BiReplacement", ["position", "replace_token"]) 171 | 172 | 173 | def gen_pretrain_targets(raw_tokens, id2word, max_predictions_per_seq): 174 | assert max_predictions_per_seq > 0 175 | assert len(raw_tokens) > 0 176 | pred_num = min(max_predictions_per_seq, max(1, int(round(len(raw_tokens) * 0.15)))) 177 | 178 | re = [] 179 | covered_pos_set = set() 180 | for _ in range(pred_num): 181 | cur_pos = np.random.randint(0, len(raw_tokens)) 182 | if cur_pos in covered_pos_set: 183 | continue 184 | covered_pos_set.add(cur_pos) 185 | 186 | prob = np.random.uniform() 187 | if prob < 0.8: 188 | replace_token = '' 189 | elif prob < 0.9: 190 | replace_token = raw_tokens[cur_pos] # itself 191 | else: 192 | while True: 193 | fake_pos = np.random.randint(0, len(id2word)) # random one 194 | replace_token = id2word[fake_pos] 195 | if raw_tokens[cur_pos] != replace_token: 196 | break 197 | re.append(BiReplacement(position=cur_pos, replace_token=replace_token)) 198 | return re 199 | 200 | 201 | def gen_ids(sens, word2id, label2id, max_len): 202 | for sen in sens: 203 | sen.to_ids(word2id, label2id, max_len) 204 | 205 | 206 | def to_ids(sens, word2id, args, id2word): 207 | num = 0 208 | for sen in sens: 209 | if num % 2000000 == 0: 210 | tf.logging.info("to_ids handling num : " + str(num)) 211 | num += 1 212 | sen.bidirectional_targets = gen_pretrain_targets(sen.raw_tokens, id2word, args.max_predictions_per_seq) 213 | sen.to_id(word2id, args) 214 | 215 | 216 | def gen_batches(sens, batch_size): 217 | per = np.array([i for i in range(len(sens))]) 218 | np.random.shuffle(per) 219 | 220 | cur_idx = 0 221 | token_batch = [] 222 | input_mask_batch = [] 223 | length_batch = [] 224 | 225 | position_batch = [] 226 | label_batch = [] 227 | weight_batch = [] 228 | 229 | while cur_idx < len(sens): 230 | 
token_batch.append(sens[per[cur_idx]].token_ids) 231 | length_batch.append(len(sens[per[cur_idx]].token_ids)) 232 | input_mask_batch.append(sens[per[cur_idx]].input_masks) 233 | 234 | label_batch.append(sens[per[cur_idx]].labels) 235 | position_batch.append(sens[per[cur_idx]].positions) 236 | weight_batch.append(sens[per[cur_idx]].weights) 237 | if len(token_batch) == batch_size or cur_idx == len(sens) - 1: 238 | max_len = max(length_batch) 239 | for ts in token_batch: ts.extend([0] * (max(length_batch) - len(ts))) 240 | for im in input_mask_batch: im.extend([0] * (max_len - len(im))) 241 | 242 | yield token_batch, length_batch, label_batch, position_batch, weight_batch, input_mask_batch 243 | 244 | del token_batch 245 | del input_mask_batch 246 | del length_batch 247 | del label_batch 248 | del position_batch 249 | del weight_batch 250 | token_batch = [] 251 | input_mask_batch = [] 252 | length_batch = [] 253 | label_batch = [] 254 | position_batch = [] 255 | weight_batch = [] 256 | cur_idx += 1 257 | 258 | 259 | def queue_gen_batches(sens, args, word2id, id2word): 260 | def enqueue(sens, q): 261 | permu = np.arange(len(sens)) 262 | np.random.shuffle(permu) 263 | idx = 0 264 | tf.logging.info("thread started!") 265 | while True: 266 | sen = sens[permu[idx]] 267 | sen.bidirectional_targets = gen_pretrain_targets(sen.raw_tokens, id2word, 268 | args.max_predictions_per_seq) 269 | sen.to_id(word2id, args) 270 | q.put(sen) 271 | idx += 1 272 | if idx >= len(sens): 273 | np.random.shuffle(permu) 274 | idx = idx % len(sens) 275 | 276 | q = queue.Queue(maxsize=50000) 277 | 278 | for i in range(args.enqueue_thread_num): 279 | tf.logging.info("enqueue thread started : " + str(i)) 280 | enqeue_thread = threading.Thread(target=enqueue, args=(sens, q)) 281 | enqeue_thread.setDaemon(True) 282 | enqeue_thread.start() 283 | 284 | qu_sens = [] 285 | while True: 286 | cur_sen = q.get() 287 | qu_sens.append(cur_sen) 288 | if len(qu_sens) >= args.batch_size: 289 | for data in gen_batches(qu_sens, args.batch_size): 290 | yield data 291 | qu_sens = [] 292 | 293 | 294 | def make_full_tensors(sens): 295 | tokens = np.zeros((len(sens), len(sens[0].token_ids)), dtype=np.int32) 296 | masks = np.zeros((len(sens), len(sens[0].input_masks)), dtype=np.int32) 297 | labels = np.zeros((len(sens)), dtype=np.int32) 298 | length = np.zeros((len(sens)), dtype=np.int32) 299 | for idx, sen in enumerate(sens): 300 | tokens[idx] = sen.token_ids 301 | masks[idx] = sen.input_masks 302 | labels[idx] = sen.label_id 303 | length[idx] = len(sen.raw_tokens) 304 | return tokens, labels, length, masks 305 | 306 | 307 | def gen_batchs(full_tensors, batch_size, is_shuffle): 308 | tokens, labels, length, masks = full_tensors 309 | per = np.array([i for i in range(len(tokens))]) 310 | if is_shuffle: 311 | np.random.shuffle(per) 312 | 313 | cur_idx = 0 314 | token_batch = [] 315 | mask_batch = [] 316 | label_batch = [] 317 | length_batch = [] 318 | while cur_idx < len(tokens): 319 | token_batch.append(tokens[per[cur_idx]]) 320 | mask_batch.append(masks[per[cur_idx]]) 321 | label_batch.append(labels[per[cur_idx]]) 322 | length_batch.append(length[per[cur_idx]]) 323 | 324 | if len(token_batch) == batch_size or cur_idx == len(tokens) - 1: 325 | # make the tokens to real max length 326 | real_max_len = max(length_batch) 327 | for i in range(len(token_batch)): 328 | token_batch[i] = token_batch[i][:real_max_len] 329 | mask_batch[i] = mask_batch[i][:real_max_len] 330 | 331 | yield token_batch, label_batch, length_batch, mask_batch 332 | token_batch 
= [] 333 | label_batch = [] 334 | length_batch = [] 335 | mask_batch = [] 336 | cur_idx += 1 337 | 338 | 339 | if __name__ == "__main__": 340 | pass 341 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuba/qa_match/f74ffeb4a66589eb383a6c251b0a7413e0be7f20/utils/__init__.py -------------------------------------------------------------------------------- /utils/classifier_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | tools for run bi-lstm short text classification 5 | """ 6 | 7 | import numpy as np 8 | import math 9 | 10 | 11 | class TextLoader(object): 12 | def __init__(self, is_training, data_path, map_file_path, batch_size, seq_length, vocab, labels, std_label_map, 13 | encoding='utf8', is_reverse=False): 14 | self.data_path = data_path 15 | self.map_file_path = map_file_path 16 | self.batch_size = batch_size 17 | self.seq_length = seq_length 18 | self.is_train = is_training 19 | self.encoding = encoding 20 | # load label std mapping index 21 | self.std_label_map = {} 22 | self.label_num_map = {} 23 | label_set = set() 24 | word_set = set() 25 | if is_training: 26 | with open(map_file_path, 'r', encoding=encoding) as map_index_file: 27 | for line in map_index_file: 28 | tokens = line.strip().split('\t') 29 | assert len(tokens) == 3 30 | label = tokens[0] 31 | label_set.add(label) 32 | std_id = tokens[1] 33 | self.std_label_map[std_id] = label 34 | words = tokens[2].split(" ") 35 | for token in words: 36 | word_set.add(token) 37 | 38 | train_file = data_path 39 | with open(train_file, 'r', encoding=encoding) as fin: 40 | for line in fin: 41 | tokens = line.strip().split('\t') 42 | assert len(tokens) == 3 43 | std_ids = tokens[0].split(",") 44 | words = tokens[2].split(" ") 45 | if len(std_ids) > 1: 46 | label = '__label__list' # answer list 47 | label_set.add(label) 48 | elif std_ids[0] == '0': 49 | label = '__label__none' # refuse answer 50 | label_set.add(label) 51 | else: 52 | assert std_ids[0] in self.std_label_map 53 | label = self.std_label_map.get(std_ids[0]) # __label__xx:some label 54 | for token in words: 55 | word_set.add(token) 56 | if label not in self.label_num_map: 57 | self.label_num_map[label] = 1 58 | else: 59 | self.label_num_map[label] = self.label_num_map[label] + 1 60 | 61 | self.labels = dict( 62 | zip(list(label_set), range(0, len(label_set)))) # {__label__1:0, __label_2:1, __label_3:2, ...} 63 | # print("self.labels: " + str(self.labels)) 64 | # print("self.std_label_map: " + str(self.std_label_map)) 65 | self.id_2_label = {value: key for key, value in self.labels.items()} 66 | self.label_size = len(self.labels) 67 | self.vocab = dict(zip(list(word_set), range(1, len(word_set) + 1))) 68 | self.id_2_vocab = {value: key for key, value in self.vocab.items()} 69 | self.vocab_size = len( 70 | self.vocab) + 1 # self.vocab.size + 1, 0 for pad, not encoding unknown, if care for it, you can modify here 71 | self.load_preprocessed(data_path, is_reverse) 72 | elif vocab is not None and labels is not None and std_label_map is not None: 73 | self.vocab = vocab 74 | self.id_2_vocab = {value: key for key, value in self.vocab.items()} 75 | self.vocab_size = len(vocab) + 1 76 | self.labels = labels 77 | self.id_2_label = {value: key for key, value in self.labels.items()} 78 | self.label_size = len(self.labels) 79 | self.std_label_map = 
std_label_map 80 | self.load_preprocessed(data_path, is_reverse) 81 | self.num_batches = 1 82 | self.x_batches = None 83 | self.y_batches = None 84 | self.len_batches = None 85 | self.reset_batch_pointer() 86 | 87 | def load_preprocessed(self, data_path, is_reverse): 88 | train_file = data_path 89 | self.raw_lines = [] 90 | with open(train_file, 'r', encoding=self.encoding) as fin: 91 | train_x = [] 92 | train_y = [] 93 | train_len = [] 94 | for line in fin: 95 | temp_x = [] 96 | temp_y = [] 97 | x_len = [] 98 | tokens = line.strip().split('\t') 99 | assert len(tokens) == 3 100 | std_ids = tokens[0].split(",") 101 | words = tokens[2].split(" ") 102 | if len(std_ids) > 1: 103 | label = '__label__list' # answer list 104 | elif std_ids[0] == '0': 105 | label = '__label__none' # refuse answer 106 | else: 107 | if std_ids[0] not in self.std_label_map: 108 | label = '__label__none' 109 | else: 110 | label = self.std_label_map.get(std_ids[0]) # __label__xx:some label 111 | # if label not in self.labels: 112 | # print("label: <" + label + ">") 113 | # print("self.labels: ") 114 | # print(str(self.labels)) 115 | temp_y.append(self.labels[label]) 116 | for item in words: 117 | if item in self.vocab: # not encoding unknown, if care for it, you can modify here 118 | temp_x.append(self.vocab[item]) 119 | if len(temp_x) == 0: 120 | print("all word in line is not in vocab, line: " + line) 121 | continue 122 | if len(temp_x) >= self.seq_length: 123 | x_len.append(self.seq_length) 124 | temp_x = temp_x[:self.seq_length] 125 | if is_reverse: 126 | temp_x.reverse() 127 | else: 128 | x_len.append(len(temp_x)) 129 | if is_reverse: 130 | temp_x.reverse() 131 | temp_x = temp_x + [0] * (self.seq_length - len(temp_x)) 132 | train_x.append(temp_x) 133 | train_y.append(temp_y) 134 | train_len.append(x_len) 135 | self.raw_lines.append(tokens[2]) 136 | tensor_x = np.array(train_x) 137 | tensor_y = np.array(train_y) 138 | tensor_len = np.array(train_len) 139 | # print("tensor_x.shape: " + str(tensor_x.shape)) 140 | # print("len(self.raw_lines): " + str(len(self.raw_lines))) 141 | 142 | self.tensor = np.c_[tensor_x, tensor_y, tensor_len].astype(int) # tensor_x.size * (40+1+1) 143 | 144 | def list_split_n(self, raw_items, split_len_batches): 145 | split_items = [] 146 | j = 0 147 | for i in range(len(split_len_batches)): 148 | split_items.append(raw_items[j: j + split_len_batches[i]]) 149 | j += split_len_batches[i] 150 | return split_items 151 | 152 | def create_batches(self): 153 | self.num_batches = int(self.tensor.shape[0] / self.batch_size) 154 | if int(self.tensor.shape[0] % self.batch_size): 155 | self.num_batches = self.num_batches + 1 156 | if self.num_batches == 0: 157 | assert False, 'Not enough data, make batch_size small.' 
158 | if self.is_train: 159 | np.random.shuffle(self.tensor) 160 | # print("self.num_batches: " + str(self.num_batches)) 161 | tensor = self.tensor[:self.num_batches * self.batch_size] 162 | # print("len(tensor): " + str(len(tensor))) 163 | raw_lines = self.raw_lines[ 164 | :self.num_batches * self.batch_size] # if train raw_lines order is different from tensor 165 | # print("len(raw_lines): " + str(len(raw_lines))) 166 | 167 | self.x_batches = np.array_split(tensor[:, :-2], self.num_batches, 0) 168 | self.y_batches = np.array_split(tensor[:, -2], self.num_batches, 0) 169 | self.len_batches = np.array_split(tensor[:, -1], self.num_batches, 0) 170 | split_len_batches = [] 171 | for i in range(len(self.x_batches)): 172 | split_len_batches.append(len(self.x_batches[i])) 173 | self.raw_lines_batches = self.list_split_n(raw_lines, split_len_batches) # should split by np.array_split 174 | sum = 0 175 | # for i in range(len(self.x_batches)): 176 | # print("i: " + str(i) + "len(self.x_batches[i]): " + str(len(self.x_batches[i]))) 177 | # sum += len(self.x_batches[i]) 178 | # print("sum: " + str(sum)) 179 | # 180 | # print("len(self.x_batches): " + str(len(self.x_batches))) 181 | # print("len(self.y_batches): " + str(len(self.y_batches))) 182 | # print("len(self.len_batches): " + str(len(self.len_batches))) 183 | # print("len(self.raw_lines_batches): " + str(len(self.raw_lines_batches))) 184 | 185 | def next_batch(self): 186 | batch_x = self.x_batches[self.pointer] 187 | batch_y = self.y_batches[self.pointer] 188 | xlen = self.len_batches[self.pointer] 189 | batch_line = self.raw_lines_batches[self.pointer] 190 | 191 | self.pointer += 1 192 | return batch_x, batch_y, xlen, batch_line 193 | 194 | def reset_batch_pointer(self): 195 | self.create_batches() 196 | self.pointer = 0 197 | 198 | 199 | if __name__ == "__main__": 200 | pass 201 | -------------------------------------------------------------------------------- /utils/match_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | tools for run bi-lstm + dssm short text matching 5 | """ 6 | 7 | import random 8 | import math 9 | import copy 10 | 11 | 12 | class DataHelper(object): 13 | 14 | def __init__(self, train_path, valid_path, test_path, standard_path, batch_size, epcho_num, label2id_file, vocab2id_file, is_train): 15 | if is_train: 16 | self.train_valid_generator(train_path, valid_path, standard_path, batch_size, epcho_num) 17 | else: 18 | self.test_generator(test_path, standard_path, batch_size, label2id_file, vocab2id_file) 19 | 20 | def test_generator(self, test_path, standard_path, batch_size, label2id_file, vocab2id_file): 21 | self.label2id = {} 22 | self.id2label = {} 23 | self.vocab2id = {} 24 | self.id2vocab = {} 25 | self.std_id_ques = {} 26 | label_file = open(label2id_file, 'r', encoding='utf-8') 27 | for line in label_file.readlines(): 28 | label_ids = line.strip().split('\t') 29 | self.label2id[label_ids[0]] = label_ids[1] 30 | self.id2label[label_ids[1]] = label_ids[0] 31 | label_file.close() 32 | vocab_file = open(vocab2id_file, 'r', encoding='utf-8') 33 | for line in vocab_file.readlines(): 34 | vocab_ids = line.strip().split('\t') 35 | self.vocab2id[vocab_ids[0]] = vocab_ids[1] 36 | self.id2vocab[vocab_ids[1]] = vocab_ids[0] 37 | vocab_file.close() 38 | std_file = open(standard_path, 'r', encoding='utf-8') 39 | max_std_len = 0 40 | for line in std_file: 41 | label_words = line.strip().split("\t") 42 | label = label_words[1] 43 | w_temp = [] 44 | words = 
label_words[2].split(" ") 45 | for word in words: 46 | w_temp.append(self.vocab2id[word]) 47 | if max_std_len < len(w_temp): 48 | max_std_len = len(w_temp) 49 | self.std_id_ques[self.label2id[label]] = (len(w_temp), w_temp, label, line.strip()) 50 | std_file.close() 51 | self.std_batch = [] 52 | self.predict_label_seq = [] 53 | self.predict_id_seq = [] 54 | # when predicted test data must order by this sequence 55 | for std_id, ques_info in self.std_id_ques.items(): 56 | self.std_batch.append((ques_info[0], ques_info[1])) 57 | self.predict_label_seq.append(self.id2label[std_id]) 58 | self.predict_id_seq.append(std_id) 59 | 60 | # std question padding 61 | for ques_info in self.std_batch: 62 | for _ in range(max_std_len - ques_info[0]): 63 | ques_info[1].append(self.vocab2id['PAD']) 64 | 65 | file = open(test_path, 'r', encoding='utf-8') 66 | self.test_num = 0 67 | self.test_batch = [] 68 | for line in file.readlines(): 69 | label_words = line.strip().split('\t') 70 | #label include list answer(label likes id1,id2,...) and rufuse answer(label is 0) 71 | label = label_words[0] 72 | w_temp = [] 73 | words = label_words[2].split(' ') 74 | for word in words: 75 | if word not in self.vocab2id: 76 | w_temp.append(self.vocab2id['UNK']) 77 | else: 78 | w_temp.append(self.vocab2id[word]) 79 | self.test_batch.append((len(w_temp), w_temp, label, label_words[2])) 80 | self.test_num = self.test_num + 1 81 | file.close() 82 | self.batch_size = batch_size 83 | self.test_num_batch = math.ceil(self.test_num / self.batch_size) 84 | 85 | def train_valid_generator(self, train_path, valid_path, standard_path, batch_size, epcho_num): 86 | self.label2id = {} 87 | self.id2label = {} 88 | self.vocab2id = {} 89 | self.vocab2id['PAD'] = 0 90 | self.vocab2id['UNK'] = 1 91 | self.id2vocab = {} 92 | self.id2vocab[0] = 'PAD' 93 | self.id2vocab[1] = 'UNK' 94 | #standard question 95 | file = open(standard_path, 'r', encoding='utf-8') 96 | self.std_id_ques = {} 97 | max_std_len = 0 98 | for line in file.readlines(): 99 | label_words = line.strip().split("\t") 100 | label = label_words[1] 101 | if label not in self.label2id: 102 | self.label2id[label] = len(self.label2id) 103 | self.id2label[self.label2id[label]] = label 104 | w_temp = [] 105 | words = label_words[2].split(" ") 106 | for word in words: 107 | if word not in self.vocab2id: 108 | self.vocab2id[word] = len(self.vocab2id) 109 | self.id2vocab[self.vocab2id[word]] = word 110 | w_temp.append(self.vocab2id[word]) 111 | if max_std_len < len(w_temp): 112 | max_std_len = len(w_temp) 113 | self.std_id_ques[self.label2id[label]] = (len(w_temp), w_temp, label, line.strip()) 114 | file.close() 115 | self.std_batch = [] 116 | self.predict_label_seq = [] 117 | self.predict_id_seq = [] 118 | #when predicted valid data must order by this sequence 119 | for std_id, ques_info in self.std_id_ques.items(): 120 | self.std_batch.append((ques_info[0], ques_info[1])) 121 | self.predict_label_seq.append(self.id2label[std_id]) 122 | self.predict_id_seq.append(std_id) 123 | self.train_num = 0 124 | self.train_id_ques = {} 125 | file = open(train_path, 'r', encoding='utf-8') 126 | for line in file.readlines(): 127 | label_words = line.strip().split('\t') 128 | label = label_words[0] 129 | if ',' in label or '0' == label: 130 | continue 131 | assert label in self.label2id 132 | w_temp = [] 133 | words = label_words[2].split(' ') 134 | for word in words: 135 | if word not in self.vocab2id: 136 | self.vocab2id[word] = len(self.vocab2id) 137 | self.id2vocab[self.vocab2id[word]] = word 138 | 
w_temp.append(self.vocab2id[word]) 139 | label_id = self.label2id[label] 140 | if label_id not in self.train_id_ques: 141 | self.train_id_ques[label_id] = [] 142 | self.train_id_ques[label_id].append((len(w_temp), w_temp)) 143 | else: 144 | self.train_id_ques[label_id].append((len(w_temp), w_temp)) 145 | self.train_num = self.train_num + 1 146 | file.close() 147 | self.vocab_size = len(self.vocab2id) 148 | #std question padding 149 | for ques_info in self.std_batch: 150 | for _ in range(max_std_len - ques_info[0]): 151 | ques_info[1].append(self.vocab2id['PAD']) 152 | file = open(valid_path, 'r', encoding='utf-8') 153 | self.valid_num = 0 154 | self.valid_batch = [] 155 | for line in file.readlines(): 156 | label_words = line.strip().split('\t') 157 | label = label_words[0] 158 | #del list answer and rufuse answer 159 | if ',' in label or '0' == label: 160 | continue 161 | assert label in self.label2id 162 | w_temp = [] 163 | words = label_words[2].split(' ') 164 | for word in words: 165 | if word not in self.vocab2id: 166 | w_temp.append(self.vocab2id['UNK']) 167 | else: 168 | w_temp.append(self.vocab2id[word]) 169 | self.valid_batch.append((len(w_temp), w_temp, label, label_words[2])) 170 | self.valid_num = self.valid_num + 1 171 | file.close() 172 | 173 | self.batch_size = batch_size 174 | self.train_num_epcho = epcho_num 175 | self.train_num_batch = math.ceil(self.train_num / self.batch_size) 176 | self.valid_num_batch = math.ceil(self.valid_num / self.batch_size) 177 | 178 | 179 | def weight_random(self, label_questions, batch_size): 180 | def index_choice(weight): 181 | index_sum_weight = random.randint(0, sum(weight) - 1) 182 | for i, val in enumerate(weight): 183 | index_sum_weight -= val 184 | if index_sum_weight < 0: 185 | return i 186 | return 0 187 | batch_keys = [] 188 | keys = list(label_questions.keys()).copy() 189 | weights = [len(label_questions[key]) for key in keys] 190 | for _ in range(batch_size): 191 | index = index_choice(weights) 192 | key = keys.pop(index) 193 | batch_keys.append(key) 194 | weights.pop(index) 195 | return batch_keys 196 | 197 | def train_batch_iterator(self, label_questions, standard_label_question): 198 | ''' 199 | select a couple question for each class 200 | ''' 201 | num_batch = self.train_num_batch 202 | num_epcho = self.train_num_epcho 203 | for _ in range(num_batch * num_epcho): 204 | query_batch = [] 205 | doc_batch = [] 206 | batch_keys = self.weight_random(label_questions, self.batch_size) 207 | batch_query_max_num = 0 208 | for key in batch_keys: 209 | questions = copy.deepcopy(random.sample(label_questions[key], 1)[0]) 210 | current_num = questions[0] 211 | if current_num > batch_query_max_num: 212 | batch_query_max_num = current_num 213 | query_batch.append(questions) 214 | doc = standard_label_question[key] 215 | doc_batch.append(doc) 216 | #padding 217 | for query, doc in zip(query_batch, doc_batch): 218 | for _ in range(batch_query_max_num - query[0]): 219 | query[1].append(self.vocab2id['PAD']) 220 | yield query_batch, doc_batch 221 | 222 | def valid_batch_iterator(self): 223 | num_batch = self.valid_num_batch 224 | num_epcho = 1 225 | for i in range(num_batch * num_epcho): 226 | if i * self.batch_size + self.batch_size < self.valid_num: 227 | query_batch = copy.deepcopy(self.valid_batch[i * self.batch_size : i * self.batch_size + self.batch_size]) 228 | else: 229 | query_batch = copy.deepcopy(self.valid_batch[i * self.batch_size : ]) 230 | batch_query_max_num = 0 231 | for q_len, _, _, _ in query_batch: 232 | if 
batch_query_max_num < q_len: 233 | batch_query_max_num = q_len 234 | #padding 235 | for q_len, label_words, _, _ in query_batch: 236 | for _ in range(batch_query_max_num - q_len): 237 | label_words.append(self.vocab2id['PAD']) 238 | yield query_batch 239 | 240 | def test_batch_iterator(self): 241 | num_batch = self.test_num_batch 242 | num_epcho = 1 243 | for i in range(num_batch * num_epcho): 244 | if i * self.batch_size + self.batch_size < self.test_num: 245 | query_batch = copy.deepcopy(self.test_batch[i * self.batch_size : i * self.batch_size + self.batch_size]) 246 | else: 247 | query_batch = copy.deepcopy(self.test_batch[i * self.batch_size : ]) 248 | batch_query_max_num = 0 249 | for q_len, _, _, _ in query_batch: 250 | if batch_query_max_num < q_len: 251 | batch_query_max_num = q_len 252 | #padding 253 | for q_len, label_words, _, _ in query_batch: 254 | for _ in range(batch_query_max_num - q_len): 255 | label_words.append(self.vocab2id['PAD']) 256 | yield query_batch 257 | 258 | if __name__ == "__main__": 259 | pass 260 | --------------------------------------------------------------------------------
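The pre-training model code above (sptm/models.py) relies on two small indexing idioms that are easy to miss when reading the TensorFlow graph construction: gather_indexes offsets each sample's masked positions into the flattened [batch * length, dim] view so the LM loss is computed only on the masked tokens, and max_pool_output adds -1e5 to padded positions before tf.reduce_max so padding can never win the pooling. Below is a minimal NumPy sketch of both idioms; the shapes and values are made up purely for illustration and are not part of the repository.

# Illustration only: NumPy versions of two idioms from sptm/models.py.
import numpy as np

batch, length, dim = 2, 4, 3
seq_output = np.arange(batch * length * dim).reshape(batch, length, dim).astype(float)  # [batch, length, dim]
positions = np.array([[1, 3],
                      [0, 2]])                                   # [batch, max_predictions_per_seq]

# 1) gather_indexes: offset each row's positions into the flattened
#    [batch * length, dim] view, then gather the masked-position vectors.
flat_offsets = np.arange(batch).reshape(-1, 1) * length          # [batch, 1]
flat_positions = (positions + flat_offsets).reshape(-1)          # [batch * max_predictions_per_seq]
gathered = seq_output.reshape(-1, dim)[flat_positions]           # [batch * max_predictions_per_seq, dim]
assert (gathered[0] == seq_output[0, 1]).all()
assert (gathered[3] == seq_output[1, 2]).all()

# 2) masked max pooling: padded positions get -1e5 added so they never
#    survive the max over the length axis.
lengths = np.array([3, 2])                                       # real sequence lengths, [batch]
pad_mask = (np.arange(length)[None, :] < lengths[:, None]).astype(float)  # [batch, length]
neg_inf = (1.0 - pad_mask)[:, :, None] * -1e5                    # [batch, length, 1]
sent_vec = (seq_output + neg_inf).max(axis=1)                    # [batch, dim]
assert (sent_vec[1] == seq_output[1, 1]).all()  # last real token of sample 1 wins here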