├── LICENSE
├── README.md
└── code
├── data.py
├── data_argument.py
├── doc2unix.py
├── docker_process.py
├── inference.py
├── model.py
├── module.py
├── run.sh
├── statistic.py
├── test.py
├── train.py
├── train_evaluate_constraint.py
├── triples.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 2021CCKS-QA-Task
2 | A solution to 2021-CCKS-QA-Task(https://tianchi.aliyun.com/competition/entrance/531904/introduction)
3 |
4 | ### 写在前面
5 | 这个比赛结束已经大半年了,期间参赛者似乎都没有做方案的开源,讨论区也很冷清。
6 | 其实自结束以来一直都想做开源工作,分享自己在算法比赛过程中的思路,奈何学业科研繁忙一直搁置。最近因为机缘巧合,打算践行开源,给有需要的同学,提供一些思路(抛砖引玉233)。同时也希望NLP开源社区可以越来越活跃。
7 | 最后这个比赛是我自学NLP之后认真筹备的第一个项目,构建代码时做了很多尝试,比较乱,望见谅。
8 |
9 | ____________________________________________
10 |
11 |
12 |
13 |
14 | # CCKS2021 运营商知识图谱推理问答 - 湖人总冠军方案
15 |
16 |
17 |
18 | ### 1 任务背景
19 |
20 | 基于给定的运营商知识图谱,使用模型预测用户所提出问题的答案,任务来自2021CCKS会议在阿里云天池平台组织的算法竞赛。在自然语言处理领域该任务属于**KBQA**(**K**nowledge **B**ase **Q**uestion **A**nswering)。
21 |
22 | 场景案例
23 |
24 | ```
25 | Q:流量日包的流量有多少?
26 | A:100MB
27 |
28 | Q:不含彩铃的套餐有哪些?
29 | A:流量月包|流量年包
30 | ```
31 |
32 |
33 |
34 | ### 2 数据介绍
35 |
36 | 用以问答的数据的数据包含(数据下载地址 https://tianchi.aliyun.com/dataset/dataDetail?dataId=109340 )
37 |
38 | * 知识图谱schema:定义了结点的实体、关系类型。
39 |
40 | * 知识图谱三元组:类似于<流量日包, 流量, 100MB>形式的三元组。
41 |
42 | * 同义词文件:用户口语化问题中实体的同义词,辅助实体的分类。
43 |
44 | * 训练数据:5000条。
45 |
46 | * 测试数据:1000条。
47 |
48 | | **用户问题** | **答案类型** | **属性名** | **实体** | **约束属性名** | **约束属性值** | **约束算子** | **答案** |
49 | | ----------------------- | ------------ | ----------------- | ---------------- | -------------- | -------------- | ------------ | ------------ |
50 | | 我的20元20G流量怎么取消 | 属性值 | 档位介绍-取消方式 | 新惠享流量券活动 | 价格\|流量 | 20\|20 | =\|= | 取消方式_133 |
51 |
52 |
表1 训练集样本示例
53 |
54 | | **问题** |
55 | | ------------------------------ |
56 | | 什么时候可以办理70元的优酷会员 |
57 |
58 | 表2 测试集样本示例
59 |
60 | ### 3 数据分析
61 |
62 | 目前知识图谱问答系统在简单句(单实体单属性)上已经取得比较好的效果,而在约束句:条件约束句、时间约束句,以及推理型问句:比较句、最值句、是否型问句以及问句中带有交集、并集和取反的问句等,其逻辑推理能力还有待提升。
63 |
64 | 图1展示了问句类型的分布。训练数据中4146个样本是属性句,696条样本是并列句,158条样本是比较句。虽然预测难度较小的属性句占比较大,但是在测试集上取得较高的指标仍然需要聚焦于其他两种类型的数据。
65 |
66 | 图2展示了数据中的实体分布,前三的实体是新惠享流量券活动,全国亲情网,北京移动plus会员。可以看出实体的分布存在长尾效应,而类别最少的样本仅1条。
67 |
68 |
69 |
70 | 图1 问句类型分布
71 |
72 |
73 |
74 | 图2 实体分布
75 |
76 |
77 |
78 | ### 4 整体设计
79 |
80 | 为了提高推理的性能,我们算法的整体思路是用NLP模型对文本进行语义解析后,将解析到的成分拼接成SPARQL查询语句,利用查询语句查询知识库得到最终的答案。
81 | 语义解析部分分为两个部分
82 |
83 | - 多任务分类模型预测答案类型、属性名、实体;
84 | - 命名实体识别模型预测约束属性名和约束属性值。
85 |
86 | 
87 |
88 | 图3 整体流程
89 |
90 | ### 5 数据预处理
91 |
92 | 数据预处理主要分为数据清洗、数据集处理、数据增强三个方面。
93 |
94 | #### 5.1 数据清洗
95 |
96 | 数据清洗工作包含了对训练、测试数据,三元组数据的清洗,主要做以下处理
97 |
98 | - 英文字母转化成小写形式,例如A—>a
99 | - 中文的数字表示用阿拉伯数字表示,例如五十—>50
100 | - 去除问句中包含的空格
101 |
102 | #### 5.2 数据集处理
103 |
104 | 这里的数据集处理主要将原数据转化成可训练的数据格式。
105 |
106 | | **用户问题** | **答案类型** | **属性名** | **实体** | **约束属性名** | **约束属性值** | **约束算子** | **答案** |
107 | | ----------------------- | ------------ | ----------------- | ---------------- | -------------- | -------------- | ------------ | ------------ |
108 | | 我的20元20G流量怎么取消 | 属性值 | 档位介绍-取消方式 | 新惠享流量券活动 | 价格|流量 | 20|20 | =|= | 取消方式_133 |
109 |
110 | 表3 训练集样本示例
111 |
112 | 答案类型、属性名、实体的预测被作为视作是文本的分类,其中答案类型、实体预测属于多分类任务,属性名是多标签分类任务。约束属性名、约束属性值是在字符级别上的分类,这被我们理解为命名实体识别任务,这就需要将约束属性名对应到原问句中的约束属性值。
113 | 在仔细分析标注数据后发现约束属性值基本都会出现在原问句中,如表3中的两个20。因此使用正则匹配的方法自动进行约束属性名和属性值的标注。
114 | **标注数据** 采用BIEO方法,用train.xlsx中的「约束属性」和「约束值」对「用户问题」进行标注。注意标注前需要统一数字和字母的大小写。标注示例如下:
115 |
116 | '我 的 **2** **0** 元 **2** **0** G 流 量 怎 么 取 消'
117 |
118 | 'O O **_B-PRICE E-PRICE_** O **B-FLOW E-FLOW** O O O O O O O'
119 |
120 | #### 5.3 数据增强
121 |
122 | 数据增强从主要考虑三个角度
123 |
124 | **少样本增强** 少样本增强的动机是观察到数据中的标签存在分布不平衡的现象。我们利用`nlpcda`包,使用EDA的方案对数据量少于某个阈值的标签样本进行增强。由于EDA中的随机删除操作会导致模型性能变差,随机删除操作被我们去除。
125 |
126 | **同义词增强** 比赛数据中包含了实体的同义词文件,为了尽最大化利用该数据,提高模型在测试集中的泛化能力,我们利用该文件对问句中的实体用其同义词进行替换,生成新的问句,从而达到扩充训练数据的目的。但是由于同义词较多,初次生成的样本达20000+例,是原训练样本的四倍,增加了训练时间开销且提升效果不明显。通过对数据的分析,我们发现训练集中样本量多的实体其同义词一般也较多。如果直接替换,会加剧标签分布不平衡的问题,从而导致提升不明显。基于这个问题,我们遵从少样本增强的思想,对样本量较多的实体按概率来进行同义词替换,对样本量较少的实体全量进行替换。基于该方法,得到增强样本6000+例,提高了训练速度,明显地增加了模型的泛化性能。
127 |
128 | **基于Bert的文本生成增强** 前两种增强方式都是基于原样本的小范围修改,本次增强方式利用`nlpcda`包中的simbert模型来生成与原句意思相近,但表达方式不同的样本,进一步提升模型的泛化能力。基于该方法得到增强样本4000+例。
129 |
130 |
131 |
132 | ### 6 分类模块
133 |
134 | 
135 |
136 | 图4 分类模型结构
137 |
138 | 分类模块作用是对答案类型、属性名、实体的分类。由于一个问句只有一个答案类型和一个实体,而可以有多个属性名,所以我们把答案类型和实体分类视作多分类任务,把属性名视作多标签分类任务。
139 | 使用多任务学习的考虑是多个相近的任务一起来训练可以提高准确率、学习速度和泛化能力。我们的实验证明多任务模型比单任务模型准确率整体更高且更稳定。
140 | 最终的分类结果由五个模型平均融合得到,分别是BERT-base、XLNET-base、RoBERTa-base、ELECTRA-base、MacBERT-base。
141 | 接下来开始介绍单模型中使用的方法。
142 |
143 | 1. 任务级的注意力机制
144 |
145 | 
146 |
147 | 图5 任务级注意力机制
148 |
149 | 对BERT等预训练模型产生的字符级embedding使用任务级别的Attention机制,得到适合于当前子任务的句子级embedding,进而用于后续的分类器。实验证明结果该注意力机制对于结果的稳定与提升有明显的效果,比赛实验结果无记录。方法详情可见[Same Representation, Different Attentions: Shareable Sentence Representation Learning from Multiple Tasks](https://arxiv.org/pdf/1804.08139.pdf)。
150 |
151 | 2. 样本有效率
152 |
153 | 在经过数据增强操作后,得到了近15000例的增强文本。这些样本的质量或多或少低于原数据,直接用于模型的训练会使模型向这些低质量样本拟合,进而导致模型无法准确学习到原标注样本中数据分布,最终导致训练费时且效果低下。
154 | 设置样本有效率权重是解决该问题的一个思路。我们在仔细分析了不同数据增强文本后,为不同方法得到的样本设置了不同的有效率,以让模型可以从这堆样本中学习到东西同时又不会在低质量的样本上过拟合。以下是实验结果。
155 |
156 | | | Base | Base + **加入样本权重** |
157 | | -------- | ------ | ----------------------- |
158 | | 复赛结果 | 0.9423 | 0.9453 |
159 |
160 | 3. 二阶段的训练模式
161 |
162 | 当原数据混合了增强样本,在训练时如果还是按照原来的模式——将训练样本随机划分20%作为验证集,必然会在验证集中混入增强的样本。挑选验证集上表现最佳的模型时,这个模型是在低质量的数据上表现最优,这会导致模型在高质量的测试集下表现得远不如验证集。
163 | 我们在解决这个问题的思路是使用两个阶段来训练模型。一阶段时,模型在全量的增强样本上进行训练,不划分验证集,训练λ个epoch,λ作为超参来调整,在训练完后保存该模型参数。二阶段时,加载一阶段的模型并在原训练集上进行“微调”训练,划分20%的数据作为验证集,挑选在验证集上表现最优的模型。实验结果显示,模型有少量的提升,更重要的是由于减少了整体训练的epoch,训练的时间开销减少约一倍。
164 |
165 | | | 原训练方式 | **二阶段训练** |
166 | | -------- | ---------- | -------------- |
167 | | 复赛结果 | 0.9477 | 0.9485 |
168 | | 运行时间 | 50min | 29min |
169 |
170 | 4. 损失函数优化
171 |
172 | 多任务模型的损失值来自于三个子任务的损失值的加权和,加权方式使用可学习的参数来进行自动调整,相比于固定权重的方法,该方法可以使模型结果更加稳定。比赛实验结果无记录。
173 |
174 | $$
175 | \alpha = softmax(W)
176 | $$
177 |
178 | $$
179 | \mathcal{L_{total}} = - \sum_{i=1}\alpha_i\mathcal{L_{Task_i}}
180 | $$
181 |
182 | ### 7 实体识别模块
183 |
184 | 本模块目的是为了预测出句子中的「约束属性」和「约束值」。
185 | 
186 |
187 | 图6 实体识别模型结构
188 |
189 | **重采样** 实体对应很多同义词,但是有些同义词在训练集中出现次数很少,甚至只出现一次,如果该同义词中又出现了价格或者流量等,就很有可能导致预测错误。比如「和家庭流量包」有一个同义词「100元5g」,「100元5g」是一个同义词,其中的「100元」不应该被预测成约束;但是由于「100元5g」在训练集中就只出现了一次,并且其他样本中「100元」经常被标注出来,所以导致模型错误地把同义词中的价格也预测出来了。因此对句子实体存在同义词的样本进行重采样,让模型充分学习这些样本。
190 |
191 | **同义词增强** 同一个实体中不同约束出现的频率不同,比如「天气预报」的价格有「3元」「5元」「6元」三档,但「3元」在训练集中出现次数远少于其他两个。针对这些样本我们对实体进行同义词替换后加入到了训练集。
192 |
193 | | | *base* | *base* + **重采样** | *base* + **同义词增强** |
194 | | -------- | ------ | -------------------- | ------------------------ |
195 | | 复赛结果 | 0.9518 | 0.9534 | 0.9567 |
196 |
197 | **模型训练** 模型采用了BiLSTM-crf 模型进行训练,数据分为5折交叉训练,batch_size 选用64,lr 选用0.001,epoch选用50。
198 |
199 | **推理前的同义词替换** 很多实体对应大量同义词,其中部分同义词会对实体识别造成干扰(干扰主要是由于同义词中包含价格、流量、子业务),比如「天气预报」和「天气预报年包」是同义词。模型会把「年包」标注出来(但实际不应该标注出来),所以在推理前采用同义词替换,将对标注造成干扰的词进行同义词替换。
200 |
201 | **推理部分** 5折训练的模型分别推理,并将5折模型结果进行平均融合,得分相加取平均再转换到BIEO标的预测结果。
202 |
203 | | | **base** | *base* + **5折平均融合** | *base* + **5折平均融合+推理前的同义词替换** |
204 | | -------- | -------- | -------------------------- | --------------------------------------------- |
205 | | 复赛结果 | 0.9486 | 0.9518 | 0.9544 |
206 |
207 | ### 8 后处理
208 |
209 | **实体识别预测结果补全** 实体识别预测出的结果存在部分错误,比如「680」只识别了「68」,或「半年包」只识别了「年包」。针对以上情况,通过判断「年包」前面的字符是否是「半」,「68」紧邻的前后字符是否是数字等,对结果进行补全。
210 |
211 | **实体识别结果规则处理** 对预测结果不合理的部分进行规则处理。比如:(1)价格和流量预测反。通过判断‘元’‘G’等单位,来判断是否出现价格预测成了流量,或者流量预测成了价格的情况。(2)针对容易预测错误的句子,采用规则方法进行纠正。比如「新惠享流量券活动」只需要判断句子中是否存在「20」或者「30」就可以得到约束属性。
212 |
213 | **空结果处理** 在使用SPARQL查询过程中可能会有空结果的情况出现。当发生此类情况时,将模型预测的标签转化成其可能会混淆的标签,设置多层的判定条件,可以尽最大可能性避免空结果的产生。
214 |
215 | **属性名细分处理** 分类模型在“开通方式”和“开通条件”上混淆较为严重。当多任务模型只预测了上述标签中的其中一个,我们使用一个二分类模型和规则方法去修正标签是否正确。二分类模型的分类标签为“开通方式”和“开通条件”。
216 |
217 | **实体规则处理** 使用同义词表中的同义词来修正模型预测的属性名标签结果。
218 |
219 | ### 9 总结
220 |
221 | 从比赛初期的方案构思到复赛白热化的“军备竞赛”,团队成员在这次比赛中成长了很多,收获了很多。以下是我们本次比赛的总结。
222 |
223 | - 知识图谱推理问答与一般的文本匹配或分类任务有较大的不同,获取答案需要有多个维度的信息,系统中每一个薄弱点都可能成为木桶里的“短板”,因此既要兼顾大局又要各个击破。
224 | - 比赛由于任务自身的问题,标注信息没有那么准确,因此需要对数据进行仔细的甄别,有疑问的数据需要去除;
225 | - 由于复赛的docker运行时长只有6小时,所以不光要考虑模型的性能还需要考虑训练的效率;
226 | - 没有万金油的方案,深度学习需要不断实现方案并实验;
227 |
228 | 此外,由于实验室的学业压力,仍有很多方案没有进行尝试,同时很多方案尝试了但没有提升结果,与前排大佬也有不少差距。希望吸取本次比赛的经验和教训,继续提升自身的实力。
229 |
230 |
231 | ### 10 附录
232 |
233 | ##### 10.1 项目目录树
234 |
235 | ```sh
236 | .
237 | └── code
238 | ├── data.py # 数据模块
239 | ├── data_argument.py # 数据增强模块
240 | ├── doc2unix.py # 修改回车符
241 | ├── docker_process.py # docker中预处理模块
242 | ├── inference.py # 推理模块
243 | ├── model.py # 模型结构
244 | ├── module.py # 必要模块
245 | ├── run.sh # 一键数据处理、训练、推理
246 | ├── statistic.py # 数据统计模块
247 | ├── test.py # 测试类
248 | ├── train.py # 多任务学习训练模块
249 | ├── train_evaluate_constraint.py # 命名实体识别训练模块
250 | ├── triples.py # 知识图谱与三元组处理模块
251 | └── utils.py # 工具类
252 | ```
253 |
254 | ##### 10.2 环境依赖
255 |
256 | ```
257 | bert4keras==0.7.7
258 | jieba==0.42.1
259 | Keras==2.3.1
260 | nlpcda==2.5.6
261 | numpy==1.16.5
262 | openpyxl==3.0.7
263 | pandas==1.2.3
264 | scikit-learn==0.23.2
265 | tensorboard==1.14.0
266 | tensorflow-gpu==1.14.0
267 | textda==0.1.0.6
268 | torch==1.5.0
269 | tqdm==4.50.2
270 | transformers==4.0.1
271 | xlrd==2.0.1
272 | rdflib==5.0.0
273 | nvidia-ml-py3==7.352.0
274 | ```
275 |
276 | ##### 10.3 运行
277 |
278 | ```shell
279 | sh run.sh
280 | ```
281 |
282 |
283 |
284 |
285 |
--------------------------------------------------------------------------------
/code/data.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: hezf
3 | @Time: 2021/6/3 14:41
4 | @desc:
5 | """
6 | from torch.utils.data import Dataset
7 | import torch
8 | import copy
9 | import json
10 | from utils import label_to_multi_hot, remove_stop_words
11 | import re
12 | import pandas as pd
13 |
14 |
15 | class BertDataset(Dataset):
16 | """
17 | 用于bert的Dataset类
18 | """
19 | def __init__(self, train_file, tokenizer, label_hub, init=True):
20 | super(BertDataset, self).__init__()
21 | self.train_file = train_file
22 | self.data = []
23 | if init:
24 | self.tokenizer = tokenizer
25 | self.label_hub = label_hub
26 | # self.stopwords = []
27 | self.init()
28 |
29 | def init(self):
30 | # print('加载停用词表...')
31 | # with open('../data/file/stopword.txt', 'r', encoding='utf-8') as f:
32 | # for line in f:
33 | # self.stopwords.append(line.strip())
34 | print('读取数据...')
35 | with open(self.train_file, 'r', encoding='utf-8') as f:
36 | for line in f:
37 | # blocks: 0:问题;1:答案类型;2:属性名;3:实体;4:答案;5:样本有效率
38 | blocks = line.strip().split('\t')
39 | # 单独处理”属性名“
40 | prop_label_ids = [self.label_hub.prop_label2id[label] for label in blocks[2].split('|')]
41 | prop_label = label_to_multi_hot(len(self.label_hub.prop_label2id), prop_label_ids)
42 | self.data.append({'token': self.tokenizer(blocks[0],
43 | add_special_tokens=True, max_length=100,
44 | padding='max_length', return_tensors='pt',
45 | truncation=True),
46 | 'ans_label': self.label_hub.ans_label2id[blocks[1]],
47 | 'prop_label': prop_label,
48 | 'entity_label': self.label_hub.entity_label2id[blocks[3]],
49 | 'efficiency': float(blocks[-1])})
50 |
51 | def __getitem__(self, item):
52 | return self.data[item]
53 |
54 | def __len__(self):
55 | return len(self.data)
56 |
57 |
58 | class BinaryDataset(Dataset):
59 | """
60 | 用于bert的Dataset类
61 | """
62 | def __init__(self, train_file, tokenizer, label_hub, init=True):
63 | super(BinaryDataset, self).__init__()
64 | self.train_file = train_file
65 | self.data = []
66 | if init:
67 | self.tokenizer = tokenizer
68 | self.label_hub = label_hub
69 | # self.stopwords = []
70 | self.init()
71 |
72 | def init(self):
73 | # print('加载停用词表...')
74 | # with open('../data/file/stopword.txt', 'r', encoding='utf-8') as f:
75 | # for line in f:
76 | # self.stopwords.append(line.strip())
77 | print('读取数据...')
78 | with open(self.train_file, 'r', encoding='utf-8') as f:
79 | for line in f:
80 | # blocks: 0:问题;1:标签;2:样本有效率
81 | blocks = line.strip().split('\t')
82 | # 单独处理”属性名“
83 | self.data.append({'token': self.tokenizer(blocks[0],
84 | add_special_tokens=True, max_length=100,
85 | padding='max_length', return_tensors='pt',
86 | truncation=True),
87 | 'label': self.label_hub.binary_label2id[blocks[1]],
88 | 'efficiency': float(blocks[-1])})
89 |
90 | def __getitem__(self, item):
91 | return self.data[item]
92 |
93 | def __len__(self):
94 | return len(self.data)
95 |
96 |
97 | class PredDataset(Dataset):
98 | """
99 | 用于预测的Dataset类
100 | """
101 | def __init__(self, data_path, tokenizer):
102 | super(PredDataset, self).__init__()
103 | self.data_path = data_path
104 | self.tokenizer = tokenizer
105 | self.data = []
106 | self.init()
107 |
108 | def init(self):
109 | with open(self.data_path, 'r', encoding='utf-8') as f:
110 | for line in f:
111 | line = line.strip()
112 | line = line.split('\t')
113 | self.data.append({'token': self.tokenizer(line[0], add_special_tokens=True, max_length=100,
114 | padding='max_length', return_tensors='pt',
115 | truncation=True),
116 | 'question': line[0]})
117 |
118 | def __len__(self):
119 | return len(self.data)
120 |
121 | def __getitem__(self, item):
122 | return self.data[item]
123 |
124 |
125 | class PredEnsembleDataset(Dataset):
126 | """
127 | 用于集成预测的Dataset类
128 | """
129 | def __init__(self, data_path):
130 | super(PredEnsembleDataset, self).__init__()
131 | self.data_path = data_path
132 | self.data = []
133 | self.init()
134 |
135 | def init(self):
136 | with open(self.data_path, 'r', encoding='utf-8') as f:
137 | for line in f:
138 | line = line.strip()
139 | self.data.append(line)
140 |
141 | def __len__(self):
142 | return len(self.data)
143 |
144 | def __getitem__(self, item):
145 | return self.data[item]
146 |
147 |
148 | def process_data(data_lists, bieo_dict=None, test=False):
149 | '''
150 | 由[question] 得到 [['你','好','世','界']] [['O','O','B','E']]
151 | 另外 返回的word_list后面需要加上'',非test tag_list后面也要加
152 | @param data_lists: question的list
153 | @param bieo_dict: 由create_test_BIEO得到的标注字典 {question:['O','O','B','E']}
154 | @return: question逐字的list 和 标注的list
155 | '''
156 | data_word_lists = []
157 | tag_word_lists = []
158 | for data_list in data_lists:
159 | data_word_list = []
160 | for word in data_list:
161 | data_word_list.append(word)
162 | data_word_list.append('')
163 | data_word_lists.append(data_word_list)
164 | if bieo_dict is not None:
165 | tag_word_list = bieo_dict[data_list] + []
166 | if not test: tag_word_list.append('')
167 | tag_word_lists.append(tag_word_list)
168 | if bieo_dict is None:
169 | return data_word_lists
170 | return data_word_lists, tag_word_lists
171 |
172 |
173 | def process_syn(data_list, entities, ans_labels):
174 | '''
175 | 识别出句子中的实体或者同义词,并将同义词换成实体
176 | @param data_list:
177 | @param entities:
178 | @return:
179 | '''
180 | new_data_list = []
181 | from data_argument import read_synonyms
182 | syn_dict = read_synonyms()
183 | for index in range(len(data_list)):
184 | question = data_list[index]
185 | ans_label = ans_labels[index]
186 | entity = entities[index]
187 | if entity in syn_dict:
188 | # 排除以下实体,因为同义词中包含有价格、流量、子业务等 #
189 | if entity in ['视频会员通用流量月包', '1元5gb流量券', '校园卡活动', '新惠享流量券活动', '专属定向流量包',
190 | '任我看视频流量包', '快手定向流量包', '5g智享套餐家庭版', '视频会员5gb通用流量7天包',
191 | '语音信箱', '全国亲情网', '通话圈', '和家庭分享']:
192 | new_data_list.append(question)
193 | continue
194 | candidate_list = []
195 | for syn in syn_dict[entity] + [entity]:
196 | if syn == '无': continue
197 | if syn == '': continue
198 | if syn in question:
199 | candidate_list.append(syn)
200 | candidate_list = sorted(candidate_list, key=lambda x:len(x))
201 | if len(candidate_list) == 0:
202 | new_data_list.append(question)
203 | continue
204 | syn_curr = candidate_list[-1]
205 | question = question.replace(syn_curr, entity)
206 | new_data_list.append(question)
207 | else:
208 | new_data_list.append(question)
209 | return new_data_list
210 |
211 |
212 | def process_postdo(subs, question ,entity):
213 | if entity == '彩铃':
214 | if len(subs) == 0:
215 | if '3元' in question: subs.append(('价格', '3'))
216 | if '5元' in question: subs.append(('价格', '5'))
217 | if entity == '承诺低消优惠话费活动':
218 | if '百度包' in question:
219 | if ('子业务', '百度包') not in subs:
220 | subs.append(('子业务', '百度包'))
221 | if entity == '新惠享流量券活动':
222 | if '30' in question:
223 | subs = [('价格', '30')]
224 | if '40' in question:
225 | subs.append(('流量', '40'))
226 | if '20' in question:
227 | subs = [('价格', '20')]
228 | if question.count('20') > 1:
229 | subs.append(('流量', '20'))
230 | if entity == '神州行5元卡':
231 | if len(subs) == 0:
232 | if question.count('5') == 2 or ('5' in question and '30' in question):
233 | subs.append(('价格', '5'))
234 | if entity == '承诺低消优惠话费活动':
235 | if '38' in question:
236 | subs = [('价格', '38')]
237 | if '18' in question:
238 | subs = [('价格', '18')]
239 |
240 | # todo 加入其他后处理
241 | return subs
242 |
243 |
244 | def process_postdo_last(question, entity_label, prop_label):
245 | if '家庭亲情网' in question or '家庭亲情通话' in question or '互打免费' in question:
246 | entity_label = '全国亲情网'
247 | elif '家庭亲情号' in question:
248 | entity_label = '和家庭分享'
249 | if '成员' in question:
250 | entity_label = '通话圈'
251 | elif '家庭亲情' in question:
252 | entity_label = '和家庭分享'
253 | # 北京移动plus会员权益卡 和 北京移动plus会员 混淆
254 | if '权益卡' in question:
255 | entity_label = '北京移动plus会员权益卡'
256 | # 30元5gb半价体验版 和 30元5gb 混淆
257 | if '半价' in question:
258 | entity_label = '30元5gb半价体验版'
259 | # 流量无忧包 和 畅享套餐 混淆
260 | if '无忧' in question:
261 | entity_label = '流量无忧包'
262 | if '全国亲情网' in question and '全国亲情网功能费优惠活动' not in question:
263 | entity_label = '全国亲情网'
264 | # 和留言服务中的留言服务二字,同时是语音信箱的同义词,容易预测错误
265 | if '和留言服务' in question:
266 | entity_label = '和留言'
267 | # 训练集中所有'帮我开通'都对应'开通方式'
268 | if '帮我开通' in question:
269 | if len(prop_label) == 1:
270 | prop_label = ['档位介绍-开通方式']
271 | # 训练集中的'怎么取消不了'都对应'取消方式'
272 | if '怎么取消不了' in question:
273 | if len(prop_label) == 1:
274 | prop_label = ['档位介绍-取消方式']
275 | return entity_label, prop_label
276 |
277 | def get_anno_dict(question, anno_list, service_names=None):
278 | '''
279 | 从bmes标注 得到 实际的 约束属性名 和 约束属性值
280 | :param question: 用户问题(去掉空格的)
281 | :param anno_list: bmes的标注 ['O','B','I','E']
282 | :return: anno_map: {'价钱':['100','20']}
283 | '''
284 | type_map = {'PRICE': '价格', 'FLOW': '流量', 'SERVICE': '子业务', 'EXPIRE': '有效期'}
285 | anno_index = 0
286 | anno_map = {}
287 | while anno_index < len(anno_list):
288 | if anno_list[anno_index] == 'O':
289 | anno_index += 1
290 | else:
291 | if anno_list[anno_index].startswith('S'):
292 | anno_type = anno_list[anno_index][2:]
293 | anno_type = type_map[anno_type]
294 | anno_index += 1
295 | if anno_type not in anno_map:
296 | anno_map[anno_type] = []
297 | anno_map[anno_type].append(question[anno_index - 1: anno_index])
298 | continue
299 | anno_type = anno_list[anno_index][2:]
300 | anno_type = type_map[anno_type]
301 | anno_start_index = anno_index
302 | # B-xxx的次数统计
303 | B_count = 0
304 | while not anno_list[anno_index].startswith('E'):
305 | if anno_list[anno_index] == 'O':
306 | anno_index -= 1
307 | break
308 | if anno_index == len(anno_list)-1:
309 | break
310 | if anno_list[anno_index][0] == 'B':
311 | B_count += 1
312 | if B_count > 1:
313 | anno_index -= 1
314 | break
315 | anno_index += 1
316 | anno_index += 1
317 | # 对 子业务 和 数字 进行补全
318 | anno_value = question[anno_start_index: anno_index]
319 | if service_names != None:
320 | if anno_type == '子业务':
321 | candidate_list = []
322 | for service in service_names:
323 | if anno_value in service:
324 | candidate_list.append(service)
325 | # 如果没有符合的实体则筛去该数据
326 | if len(candidate_list) == 0:
327 | continue
328 | drop = True
329 | for candidate in candidate_list:
330 | if candidate in question:
331 | drop = False
332 | if drop: continue
333 | # 对照句子,找到最符合的对象
334 | candidate_list = sorted(candidate_list, key=lambda x:len(x), reverse=True)
335 | for candidate in candidate_list:
336 | if candidate == '半年包' and anno_value=='年包':break
337 | if candidate == '腾讯视频' and anno_value=='腾讯':
338 | if '任我看' in question or '24' in question:
339 | pass
340 | else:
341 | break
342 | if candidate in question:
343 | anno_value = candidate
344 | break
345 | if anno_type == '流量' or anno_type == '价格':
346 | l, r = anno_start_index, anno_index
347 | while r < len(question) and question[r] == '0':
348 | anno_value = anno_value + '0'
349 | r += 1
350 | while l > 0 and question[l-1].isdigit():
351 | anno_value = question[l-1] + anno_value
352 | l -= 1
353 | if anno_type not in anno_map:
354 | anno_map[anno_type] = []
355 | # 去除 同一个属性值 出现两次的现象(如果是合约版(比较句出现)、20(价格20流量20)、18(train.xlsx出现),则是正常现象)
356 | if anno_value in anno_map[anno_type]:
357 | if anno_value == '合约版' or anno_value == '20' or anno_value == '18':
358 | pass
359 | else: continue
360 | anno_map[anno_type].append(anno_value)
361 | # if service_names is not None:
362 | # if '和留言' in question and '年包' in question and '月包和留言年包' not in question:
363 | # if '子业务' in anno_map:
364 | # if '年包' not in anno_map['子业务']:
365 | # anno_map['子业务'].append('年包')
366 | # else:
367 | # anno_map['子业务'] = ['年包']
368 | return anno_map
369 |
370 |
371 | def get_anno_dict_with_pos(question, anno_list, service_names=None):
372 | '''
373 | 从bmes标注 得到 实际的 约束属性名 和 约束属性值, 以及位置
374 | :param question: 用户问题(去掉空格的)
375 | :param anno_list: bmes的标注 ['O','B','I','E']
376 | :return: anno_map: {'价钱':[['100', 1], ['20', 3]]} 记录值和在句子中出现的位置
377 | '''
378 | type_map = {'PRICE': '价格', 'FLOW': '流量', 'SERVICE': '子业务', 'EXPIRE': '有效期'}
379 | anno_index = 0
380 | anno_map = {}
381 | while anno_index < len(anno_list):
382 | if anno_list[anno_index] == 'O':
383 | anno_index += 1
384 | else:
385 | if anno_list[anno_index].startswith('S'):
386 | anno_type = anno_list[anno_index][2:]
387 | anno_type = type_map[anno_type]
388 | anno_index += 1
389 | if anno_type not in anno_map:
390 | anno_map[anno_type] = []
391 | anno_map[anno_type].append([question[anno_index - 1: anno_index], anno_index - 1])
392 | continue
393 | anno_type = anno_list[anno_index][2:]
394 | anno_type = type_map[anno_type]
395 | anno_start_index = anno_index
396 | # B-xxx的次数统计
397 | B_count = 0
398 | while not anno_list[anno_index].startswith('E'):
399 | if anno_list[anno_index] == 'O':
400 | anno_index -= 1
401 | break
402 | if anno_index == len(anno_list)-1:
403 | break
404 | if anno_list[anno_index][0] == 'B':
405 | B_count += 1
406 | if B_count > 1:
407 | anno_index -= 1
408 | break
409 | anno_index += 1
410 | anno_index += 1
411 | # 对 子业务 和 数字 进行补全
412 | anno_value = question[anno_start_index: anno_index]
413 | pos = anno_start_index
414 | if service_names != None:
415 | if anno_type == '子业务':
416 | candidate_list = []
417 | for service in service_names:
418 | if anno_value in service:
419 | candidate_list.append(service)
420 | # 如果没有符合的实体则筛去该数据
421 | if len(candidate_list) == 0:
422 | continue
423 | drop = True
424 | for candidate in candidate_list:
425 | if candidate in question:
426 | drop = False
427 | if drop: continue
428 | # 对照句子,找到最符合的对象
429 | candidate_list = sorted(candidate_list, key=lambda x:len(x), reverse=True)
430 | for candidate in candidate_list:
431 | if candidate == '半年包' and anno_value=='年包':break
432 | if candidate == '腾讯视频' and anno_value=='腾讯':
433 | if '任我看' in question or '24' in question:
434 | pass
435 | else:
436 | break
437 | if candidate in question:
438 | anno_value = candidate
439 | break
440 | if anno_type == '流量' or anno_type == '价格':
441 | l, r = anno_start_index, anno_index
442 | while r < len(question) and question[r] == '0':
443 | anno_value = anno_value + '0'
444 | r += 1
445 | while l > 0 and question[l-1].isdigit():
446 | anno_value = question[l-1] + anno_value
447 | l -= 1
448 | if anno_type not in anno_map:
449 | anno_map[anno_type] = []
450 | # 去除 同一个属性值 出现两次的现象(如果是合约版(比较句出现)、20(价格20流量20)、18(train.xlsx出现),则是正常现象)
451 | # enhence版 不需要去重
452 | # if anno_value in anno_map[anno_type]:
453 | # if anno_value == '合约版' or anno_value == '20' or anno_value == '18':
454 | # pass
455 | # else: continue
456 | anno_map[anno_type].append([anno_value, pos])
457 | return anno_map
458 |
459 |
460 | def pred_collate_fn(batch_data):
461 | """
462 | 用于用于预测的collate函数
463 | :param batch_data:
464 | :return:
465 | """
466 | input_ids, token_type_ids, attention_mask = [], [], []
467 | questions = []
468 | for instance in copy.deepcopy(batch_data):
469 | questions.append(instance['question'])
470 | input_ids.append(instance['token']['input_ids'][0].squeeze(0))
471 | token_type_ids.append(instance['token']['token_type_ids'][0].squeeze(0))
472 | attention_mask.append(instance['token']['attention_mask'][0].squeeze(0))
473 | return torch.stack(input_ids), torch.stack(token_type_ids), \
474 | torch.stack(attention_mask), questions
475 |
476 |
477 | def bert_collate_fn(batch_data):
478 | """
479 | 用于BERT训练的collate函数
480 | :param batch_data:
481 | :return:
482 | """
483 | input_ids, token_type_ids, attention_mask = [], [], []
484 | ans_labels, prop_labels, entity_labels = [], [], []
485 | efficiency_list = []
486 | for instance in copy.deepcopy(batch_data):
487 | input_ids.append(instance['token']['input_ids'][0].squeeze(0))
488 | token_type_ids.append(instance['token']['token_type_ids'][0].squeeze(0))
489 | attention_mask.append(instance['token']['attention_mask'][0].squeeze(0))
490 | ans_labels.append(instance['ans_label'])
491 | prop_labels.append(torch.tensor(instance['prop_label']))
492 | entity_labels.append(instance['entity_label'])
493 | efficiency_list.append(instance['efficiency'])
494 | return torch.stack(input_ids), torch.stack(token_type_ids), \
495 | torch.stack(attention_mask), torch.tensor(ans_labels), \
496 | torch.stack(prop_labels), torch.tensor(entity_labels), \
497 | torch.tensor(efficiency_list, dtype=torch.float)
498 |
499 |
500 | def binary_collate_fn(batch_data):
501 | """
502 | 用于二类训练的collate函数
503 | """
504 | input_ids, token_type_ids, attention_mask = [], [], []
505 | labels = []
506 | efficiency_list = []
507 | for instance in copy.deepcopy(batch_data):
508 | input_ids.append(instance['token']['input_ids'][0].squeeze(0))
509 | token_type_ids.append(instance['token']['token_type_ids'][0].squeeze(0))
510 | attention_mask.append(instance['token']['attention_mask'][0].squeeze(0))
511 | labels.append(instance['label'])
512 | efficiency_list.append(instance['efficiency'])
513 | return torch.stack(input_ids), torch.stack(token_type_ids), \
514 | torch.stack(attention_mask), torch.tensor(labels), \
515 | torch.tensor(efficiency_list, dtype=torch.float)
516 |
517 |
518 | class LabelHub(object):
519 | """
520 | 分类任务的Label数据中心
521 | """
522 | def __init__(self, label_file):
523 | super(LabelHub, self).__init__()
524 | self.label_file = label_file
525 | self.ans_label2id = {}
526 | self.ans_id2label = {}
527 | self.prop_label2id = {}
528 | self.prop_id2label = {}
529 | self.entity_label2id = {}
530 | self.entity_id2label = {}
531 | self.binary_label2id = {}
532 | self.binary_id2label = {}
533 | self.load_label()
534 |
535 | def load_label(self):
536 | with open(self.label_file, 'r', encoding='utf-8') as f:
537 | label_dict = json.load(f)
538 | self.ans_label2id = label_dict['ans_type']
539 | self.prop_label2id = label_dict['main_property']
540 | self.entity_label2id = label_dict['entity']
541 | self.binary_label2id = label_dict['binary_type']
542 | for k, v in self.ans_label2id.items():
543 | self.ans_id2label[v] = k
544 | for k, v in self.prop_label2id.items():
545 | self.prop_id2label[v] = k
546 | for k, v in self.entity_label2id.items():
547 | self.entity_id2label[v] = k
548 | for k, v in self.binary_label2id.items():
549 | self.binary_id2label[v] = k
550 |
551 |
552 | def create_test_BIEO(excel_path, test = True):
553 | """
554 | 由excel生成对应的bieo标注,并保存,同时测试标注方法的准确性
555 | @param excel_path: 数据路径
556 | @return:
557 | """
558 | def get_bieo(data):
559 | '''
560 | 得到bmes
561 | :param data: train_denoised.xlsx的dataframe
562 | :return: dict ['question': ['O','B','E']]
563 | '''
564 | # TODO
565 | # 标注失败:最便宜最优惠的标注 应该为 价格1;
566 | # 三十、四十、 十元(先改三十元)、 六块、 一个月
567 | # 有效期太少了
568 | type_map = {'价格': 'PRICE', '流量': 'FLOW', '子业务': 'SERVICE'}
569 | result_dict = {}
570 |
571 | for index, row in data.iterrows():
572 | question = row['用户问题']
573 | char_list = ['O' for _ in range(len(question))]
574 |
575 | constraint_names = row['约束属性名']
576 | constraint_values = row['约束属性值']
577 | constraint_names_list = re.split(r'[|\|]', str(constraint_names))
578 | constraint_values_list = re.split(r'[|\|]', str(constraint_values))
579 | constraint_names_list = [name.strip() for name in constraint_names_list]
580 | constraint_values_list = [value.strip() for value in constraint_values_list]
581 | question_len = len(question)
582 | question_index = 0
583 | # 在句子中标注constraint
584 | for cons_index in range(len(constraint_values_list)):
585 | name = constraint_names_list[cons_index]
586 | if name == '有效期': continue
587 | value = constraint_values_list[cons_index]
588 | if value in question[question_index:]:
589 | temp_index = question[question_index:].find(value) + question_index
590 | if len(value) == 1:
591 | char_list[temp_index] = 'S-' + type_map[name]
592 | continue
593 | else:
594 | for temp_i in range(temp_index + 1, temp_index + len(value) - 1):
595 | char_list[temp_i] = 'I-' + type_map[name]
596 | char_list[temp_index] = 'B-' + type_map[name]
597 | char_list[temp_index + len(value) - 1] = 'E-' + type_map[name]
598 | question_index = min(temp_index + len(value), question_len)
599 | elif value in question:
600 | temp_index = question.find(value)
601 | if len(value) == 1:
602 | char_list[temp_index] = 'S-' + type_map[name]
603 | continue
604 | else:
605 | for temp_i in range(temp_index + 1, temp_index + len(value) - 1):
606 | char_list[temp_i] = 'I-' + type_map[name]
607 | char_list[temp_index] = 'B-' + type_map[name]
608 | char_list[temp_index + len(value) - 1] = 'E-' + type_map[name]
609 | result_dict[question] = char_list
610 |
611 | return result_dict
612 |
613 | def test_bieo(excel_data):
614 | '''
615 | 用来测试正则化生成的bieo相比于excel中的真值,准确度如何
616 | :param: excel中读取的dataframe
617 | :return:
618 | '''
619 |
620 | annos = get_bieo(excel_data)
621 |
622 | TP, TN, FP = 0, 0, 0
623 |
624 | for index, row in excel_data.iterrows():
625 | question = row['用户问题']
626 | # 获取约束的标注
627 | anno_list = annos[question]
628 | anno_dict = get_anno_dict(question, anno_list)
629 | # 获取约束的真值
630 | constraint_names = row['约束属性名']
631 | constraint_values = row['约束属性值']
632 | constraint_names_list = re.split(r'[|\|]', str(constraint_names))
633 | constraint_values_list = re.split(r'[|\|]', str(constraint_values))
634 | constraint_dict = {}
635 | for constraint_index in range(len(constraint_names_list)):
636 | constraint_name = constraint_names_list[constraint_index].strip()
637 | if constraint_name == '有效期': continue
638 | constraint_value = constraint_values_list[constraint_index].strip()
639 | if constraint_name not in constraint_dict:
640 | constraint_dict[constraint_name] = []
641 | constraint_dict[constraint_name].append(constraint_value)
642 | # 比较约束的真值和标注
643 | tp = 0
644 | anno_kv = []
645 | constraint_kv = []
646 | for k, vs in anno_dict.items():
647 | for v in vs:
648 | anno_kv.append(k + v)
649 | for k, vs in constraint_dict.items():
650 | for v in vs:
651 | constraint_kv.append(k + v)
652 | # 排除 二者均为空 的情况和 比较句 的情况
653 | if len(anno_kv) == 0 and constraint_kv[0] == 'nannan': continue
654 | if len(anno_kv) == 0 and (constraint_kv[0] == '价格1' or constraint_kv[0] == '流量1'): continue
655 |
656 | anno_len = len(anno_kv)
657 | cons_len = len(constraint_kv)
658 | for kv in constraint_kv:
659 | if kv in anno_kv:
660 | tp += 1
661 | anno_kv.remove(kv)
662 | if tp != cons_len:
663 | print('-------')
664 | print(question)
665 | print('anno: ', anno_kv)
666 | print('cons: ', constraint_kv)
667 | TP += tp
668 | FP += (cons_len - tp)
669 | TN += (anno_len - tp)
670 | print('测试bmes结果:' + 'TP: {} FP: {} TN:{} '.format(TP, FP, TN))
671 |
672 | def add_lost_anno(bieo_dict):
673 | from triples import KnowledgeGraph
674 |
675 | kg = KnowledgeGraph('../data/process_data/triples.rdf')
676 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx')
677 | df.fillna('')
678 | id_list = set()
679 | ans_list = []
680 | for iter, row in df.iterrows():
681 | ans_true = list(set(row['答案'].split('|')))
682 | question = row['用户问题']
683 | ans_type = row['答案类型']
684 | # 只对属性值的句子做处理
685 | if ans_type != '属性值':
686 | continue
687 | entity = row['实体']
688 | main_property = row['属性名'].split('|')
689 | # 排除属性中没有'-'的情况,只要'档位介绍-xx'的情况
690 | if '-' not in main_property[0]:
691 | continue
692 | operator = row['约束算子']
693 | # 排除operator为min或max的情况
694 | if operator != 'min' and operator != 'max':
695 | operator == 'other'
696 | else:
697 | continue
698 | sub_properties = {}
699 | cons_names = str(row['约束属性名']).split('|')
700 | cons_values = str(row['约束属性值']).split('|')
701 | if cons_names == ['nan']: cons_names = []
702 | for index in range(len(cons_names)):
703 | if cons_names[index] not in sub_properties:
704 | sub_properties[cons_names[index]] = []
705 | sub_properties[cons_names[index]].append(cons_values[index])
706 | price_ans, flow_ans, service_ans = kg.fetch_wrong_ans(question, ans_type, entity, main_property, operator,
707 | [])
708 | rdf_properties = {}
709 | rdf_properties['价格'] = price_ans
710 | rdf_properties['流量'] = flow_ans
711 | rdf_properties['子业务'] = service_ans
712 | compare_result = []
713 | for name, values in rdf_properties.items():
714 | for value in values:
715 | if value in question:
716 | if name in sub_properties and value in sub_properties[name]:
717 | continue
718 | elif name in sub_properties:
719 | if value == '年包' and '半年包' in sub_properties[name]:
720 | continue
721 | if value == '百度' and '百度包' in sub_properties[name]:
722 | continue
723 | elif value in entity:
724 | continue
725 | else:
726 | compare_result.append(name + '_' + value)
727 | id_list.add(iter)
728 | if compare_result != []:
729 | ans_list.append(compare_result)
730 |
731 | raw_data = pd.read_excel(excel_path)
732 |
733 | if test:
734 | test_bieo(raw_data)
735 | bieo_dict = get_bieo(raw_data)
736 |
737 | with open(r'../data/file/train_bieo.json', 'w') as f:
738 | json.dump(bieo_dict, f, indent=2, ensure_ascii=False)
739 |
740 | return bieo_dict
741 |
742 |
743 | def create_test_BIO(excel_path, test = True):
744 | """
745 | 由excel生成对应的BIO标注,并保存,同时测试标注方法的准确性
746 | @param excel_path: 数据路径
747 | @return:
748 | """
749 | def get_bio(data):
750 | '''
751 | 得到bmes
752 | :param data: train_denoised.xlsx的dataframe
753 | :return: dict ['question': ['O','B','I']]
754 | '''
755 | # TODO
756 | # 标注失败:最便宜最优惠的标注 应该为 价格1;
757 | # 三十、四十、 十元(先改三十元)、 六块、 一个月
758 | # 有效期太少了 可用规则处理
759 | type_map = {'价格': 'PRICE', '流量': 'FLOW', '子业务': 'SERVICE'}
760 | result_dict = {}
761 |
762 | for index, row in data.iterrows():
763 | question = row['用户问题']
764 | char_list = ['O' for _ in range(len(question))]
765 |
766 | constraint_names = row['约束属性名']
767 | constraint_values = row['约束属性值']
768 | constraint_names_list = re.split(r'[|\|]', str(constraint_names))
769 | constraint_values_list = re.split(r'[|\|]', str(constraint_values))
770 | constraint_names_list = [name.strip() for name in constraint_names_list]
771 | constraint_values_list = [value.strip() for value in constraint_values_list]
772 | question_len = len(question)
773 | question_index = 0
774 | # 在句子中标注constraint
775 | for cons_index in range(len(constraint_values_list)):
776 | name = constraint_names_list[cons_index]
777 | if name == '有效期':
778 | continue
779 | value = constraint_values_list[cons_index]
780 | if value in question[question_index:]:
781 | temp_index = question[question_index:].find(value) + question_index
782 | char_list[temp_index] = 'B-' + type_map[name]
783 | for temp_i in range(temp_index + 1, temp_index + len(value)):
784 | char_list[temp_i] = 'I-' + type_map[name]
785 | question_index = min(temp_index + len(value), question_len)
786 | elif value in question:
787 | temp_index = question.find(value)
788 | if char_list[temp_index] == 'O':
789 | char_list[temp_index] = 'B-' + type_map[name]
790 | for temp_i in range(temp_index + 1, temp_index + len(value)):
791 | if char_list[temp_i] == 'O':
792 | char_list[temp_i] = 'I-' + type_map[name]
793 | else:
794 | print('标注冲突:"{}"'.format(question))
795 | break
796 | else:
797 | print('标注冲突。"{}"'.format(question))
798 | result_dict[question] = char_list
799 |
800 | return result_dict
801 |
802 | def test_bio(excel_data):
803 | '''
804 | 用来测试正则化生成的BIO相比于excel中的真值,准确度如何
805 | :param: excel中读取的dataframe
806 | :return:
807 | '''
808 |
809 | annos = get_bio(excel_data)
810 | TP, TN, FP = 0, 0, 0
811 | for index, row in excel_data.iterrows():
812 | question = row['用户问题']
813 | # 获取约束的标注
814 | anno_list = annos[question]
815 | anno_dict = get_anno_dict(question, anno_list)
816 | # 获取约束的真值
817 | constraint_names = row['约束属性名']
818 | constraint_values = row['约束属性值']
819 | constraint_names_list = re.split(r'[|\|]', str(constraint_names))
820 | constraint_values_list = re.split(r'[|\|]', str(constraint_values))
821 | constraint_dict = {}
822 | for constraint_index in range(len(constraint_names_list)):
823 | constraint_name = constraint_names_list[constraint_index].strip()
824 | if constraint_name == '有效期':
825 | continue
826 | constraint_value = constraint_values_list[constraint_index].strip()
827 | if constraint_name not in constraint_dict:
828 | constraint_dict[constraint_name] = []
829 | constraint_dict[constraint_name].append(constraint_value)
830 | # 比较约束的真值和标注
831 | tp = 0
832 | anno_kv = []
833 | constraint_kv = []
834 | for k, vs in anno_dict.items():
835 | for v in vs:
836 | anno_kv.append(k + v)
837 | for k, vs in constraint_dict.items():
838 | for v in vs:
839 | constraint_kv.append(k + v)
840 | # 排除 二者均为空 的情况和 比较句 的情况
841 | if len(anno_kv) == 0 and constraint_kv[0] == 'nannan':
842 | continue
843 | if len(anno_kv) == 0 and (constraint_kv[0] == '价格1' or constraint_kv[0] == '流量1'):
844 | continue
845 |
846 | anno_len = len(anno_kv)
847 | cons_len = len(constraint_kv)
848 | for kv in constraint_kv:
849 | if kv in anno_kv:
850 | tp += 1
851 | anno_kv.remove(kv)
852 | if tp != cons_len:
853 | print('-------')
854 | print(question)
855 | print('anno: ', anno_kv)
856 | print('cons: ', constraint_kv)
857 | TP += tp
858 | FP += (cons_len - tp)
859 | TN += (anno_len - tp)
860 | print('测试bmes结果:' + 'TP: {} FP: {} TN:{} '.format(TP, FP, TN))
861 |
862 | raw_data = pd.read_excel(excel_path)
863 |
864 | if test:
865 | test_bio(raw_data)
866 | bio_dict = get_bio(raw_data)
867 |
868 | with open(r'../data/file/train_bio.json', 'w') as f:
869 | json.dump(bio_dict, f, ensure_ascii=False, indent=2)
870 | return bio_dict
871 |
872 |
873 | if __name__ == '__main__':
874 | # ----------------自动创建NER标注文件--------------------
875 | create_test_BIEO(excel_path='../data/raw_data/train_denoised.xlsx')
876 | create_test_BIO(excel_path='../data/raw_data/train_denoised.xlsx')
877 |
--------------------------------------------------------------------------------
/code/data_argument.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 | from utils import setup_seed, get_time_dif, view_gpu_info
4 | import pandas as pd
5 | from copy import deepcopy
6 | import math
7 | import random
8 | from tqdm import tqdm
9 | import numpy as np
10 | import json
11 | import re
12 |
13 |
14 | def denoising(source_file=r'../data/raw_data/train.xlsx', target_file=r'../data/raw_data/train_denoised.xlsx'):
15 | """
16 | 对原数据 train.xlsx 进行去噪,并保存为 train_denoised.xlsx
17 | 主要进行了以下处理:1 去除多余空格 2 统一'|' 3 字母转换为小写 4 去除约束算子 5 float(nan)转换为''
18 | @return:
19 | """
20 | def process_field(field):
21 | """
22 | 对字段进行处理: 1 去掉空格 2 统一'|' 3 全部用小写
23 | @param field
24 | @return: field
25 | """
26 | field = field.replace(' ','').replace('|', '|')
27 | field = field.lower()
28 | return field
29 |
30 | def question_replace(question):
31 | """
32 | 对问题进行去噪
33 | :param question:
34 | :return:
35 | """
36 | question = question.replace('二十', '20')
37 | question = question.replace('三十', '30')
38 | question = question.replace('四十', '40')
39 | question = question.replace('五十', '50')
40 | question = question.replace('六十', '60')
41 | question = question.replace('七十', '70')
42 | question = question.replace('八十', '80')
43 | question = question.replace('九十', '90')
44 | question = question.replace('一百', '100')
45 | question = question.replace('十块', '10块')
46 | question = question.replace('十元', '10元')
47 | question = question.replace('六块', '6块')
48 | question = question.replace('一个月', '1个月')
49 | question = question.replace('2O', '20')
50 | if '一元一个g' not in question:
51 | question = question.replace('一元', '1元')
52 | if 'train' in source_file:
53 | question = question.replace(' ', '_')
54 | else:
55 | question = question.replace(' ', '')
56 | question = question.lower()
57 | return question
58 | raw_data = pd.read_excel(source_file)
59 | # 训练数据
60 | if 'train' in source_file:
61 | raw_data['有效率'] = [1] * len(raw_data)
62 | for index, row in raw_data.iterrows():
63 | row['用户问题'] = question_replace(row['用户问题'])
64 | # TODO 检查是否有效
65 | # if row['实体'] == '畅享套餐促销优惠':
66 | # row['实体'] = '畅享套餐促销优惠活动'
67 | for index_row in range(len(row)):
68 | field = str(row.iloc[index_row])
69 | if field == 'nan':
70 | field = ''
71 | row.iloc[index_row] = process_field(field)
72 | if not (row['约束算子'] == 'min' or row['约束算子'] == 'max'):
73 | row['约束算子'] = ''
74 | raw_data.iloc[index] = row
75 | # 测试数据
76 | else:
77 | length = []
78 | columns = raw_data.columns
79 | for column in columns:
80 | length.append(len(str(raw_data.iloc[1].at[column])))
81 | max_id = np.argmax(length)
82 | for index, row in raw_data.iterrows():
83 | # raw_data.loc[index, 'query'] = question_replace(row['query'])
84 | raw_data.loc[index, columns[max_id]] = question_replace(row[columns[max_id]])
85 | # print(raw_data)
86 | raw_data.to_excel(target_file, index=False)
87 |
88 |
89 | def read_synonyms():
90 | """
91 | 读取synonyms.txt 并转换成dict
92 | @return:
93 | """
94 | synonyms_dict = {}
95 | with open(r'../data/raw_data/synonyms.txt', 'r') as f:
96 | lines = f.readlines()
97 | for line in lines:
98 | line = line.strip('\n')
99 | ent, syn_str = line.split()
100 | syns = syn_str.split('|')
101 | synonyms_dict[ent] = syns
102 |
103 | return synonyms_dict
104 |
105 |
106 | def create_argument_data(synonyms_dict):
107 | """
108 | 对数据进行同义词替换,并保存为新的xlsx
109 | @param synonyms_dict: 来自synonyms.txt
110 | @return:
111 | """
112 | raw_data = pd.read_excel(r'../data/raw_data/train_denoised.xlsx')
113 | new_data = []
114 | for index, row in raw_data.iterrows():
115 | question = row['用户问题']
116 | for k, vs in synonyms_dict.items():
117 | if k in question:
118 | if '无' in vs:
119 | continue
120 | for v in vs:
121 | row_temp = deepcopy(row)
122 | row_temp['用户问题'] = question.replace(k, v)
123 | new_data.append(row_temp)
124 | new_data = pd.DataFrame(new_data)
125 | save_data = pd.concat([raw_data, new_data], ignore_index=True)
126 | print(save_data)
127 | save_data.to_excel(r'../data/raw_data/train_syn.xlsx', index=False)
128 |
129 |
130 | def augment_for_few_data(source_file, target_file, t=20, efficiency=0.8):
131 | """
132 | 扩充样本量少于t的到t
133 | :return:
134 | """
135 | from nlpcda import Similarword, RandomDeleteChar, CharPositionExchange
136 |
137 | def augment2t_nlpcda(column2id, threshold=t):
138 | """
139 | 使用nlpcda包来扩充
140 | """
141 | from utils import setup_seed
142 | setup_seed(1)
143 | a1 = CharPositionExchange(create_num=2, change_rate=0.1, seed=1)
144 | # a2 = RandomDeleteChar(create_num=2, change_rate=0.1, seed=1)
145 | a3 = Similarword(create_num=2, change_rate=0.1, seed=1)
146 | for k, v in tqdm(column2id.items()):
147 | if len(v) < threshold:
148 | p = int(math.ceil(threshold / len(v) * 1.0))
149 | for row_id in v:
150 | aug_list = []
151 | row = raw_data.loc[row_id]
152 | question = str(row['用户问题']).strip()
153 | while len(aug_list) < p:
154 | aug_list += a1.replace(question)[1:]
155 | # aug_list += a2.replace(question)[1:]
156 | aug_list += a3.replace(question)[1:]
157 | aug_list = list(set(aug_list))
158 | for a_q in sorted(aug_list):
159 | copy_row = deepcopy(row)
160 | copy_row['用户问题'] = a_q
161 | target_data.loc[len(target_data)] = copy_row
162 | # 为并列句去除逗号增强
163 | if k == '并列句':
164 | for row_id in v:
165 | row = raw_data.loc[row_id]
166 | question = str(row['用户问题']).strip()
167 | if ',' in question:
168 | question = question.replace(',', '')
169 | copy_row = deepcopy(row)
170 | copy_row['用户问题'] = question
171 | target_data.loc[len(target_data)] = copy_row
172 | target_data['有效率'] = [efficiency] * len(target_data)
173 | print('开始扩充少量数据...目前阈值为{}'.format(t))
174 | ans2id, prop2id, entity2id = {}, {}, {}
175 | raw_data = pd.read_excel(source_file)
176 | target_data = raw_data.drop(raw_data.index)
177 | for idx, row in raw_data.iterrows():
178 | # 答案类型
179 | if row['答案类型'] not in ans2id:
180 | ans2id[row['答案类型']] = set()
181 | ans2id[row['答案类型']].add(idx)
182 | # 属性名
183 | prop_list = row['属性名'].split('|')
184 | for prop in prop_list:
185 | if prop not in prop2id:
186 | prop2id[prop] = set()
187 | prop2id[prop].add(idx)
188 | # 实体
189 | if row['实体'] not in entity2id:
190 | entity2id[row['实体']] = set()
191 | entity2id[row['实体']].add(idx)
192 | # 添加
193 | # augment_to_t(ans2id)
194 | # augment_to_t(prop2id)
195 | # augment_to_t(entity2id)
196 | augment2t_nlpcda(ans2id)
197 | augment2t_nlpcda(prop2id)
198 | augment2t_nlpcda(entity2id)
199 | target_data.to_excel(target_file, index=False)
200 |
201 |
202 | def augment_for_synonyms(source_file, target_file):
203 | """
204 | 根据一定的规则,有选择的将同义词替换成原句子
205 | :param source_file:
206 | :param target_file:
207 | :return:
208 | """
209 | print('开始进行同义词增强...')
210 | with open('../data/file/entity_count.json', 'r', encoding='utf-8') as f:
211 | entity_count = json.load(f)
212 | with open('../data/file/synonyms.json', 'r', encoding='utf-8') as f:
213 | synonym_dict = json.load(f)
214 | entity2synonym = synonym_dict['entity2synonym']
215 | raw_data = pd.read_excel(source_file)
216 | target_data = raw_data.drop(raw_data.index)
217 | random.seed(1)
218 | for idx, row in tqdm(raw_data.iterrows(), total=len(raw_data)):
219 | question = row['用户问题']
220 | entity = row['实体']
221 | if entity in entity2synonym:
222 | synonym_list = sorted(entity2synonym[entity] + [entity], key=lambda item: len(item), reverse=True)
223 | contain_word = ''
224 | for s in synonym_list:
225 | if s in question:
226 | contain_word = s
227 | break
228 | if contain_word != '':
229 | synonym_list.remove(contain_word)
230 | rest_word = synonym_list
231 | # 根据entity的样本量来进行按概率采样,防止增强的样本量太大
232 | # 同义词增强,本质上还是需要侧重考虑少样本
233 | count = entity_count[entity]
234 | alpha = 1
235 | # 不需要修改rest_word
236 | if count <= 30:
237 | pass
238 | elif 30 < count <= 50:
239 | rest_word = random.sample(rest_word, int(math.ceil(len(rest_word)*0.5)))
240 | elif 50 < count <= 100:
241 | rest_word = random.sample(rest_word, int(math.ceil(len(rest_word)*0.2)))
242 | else:
243 | if random.random() < 0.5:
244 | rest_word = random.sample(rest_word, k=1)
245 | else:
246 | rest_word = []
247 | for s in rest_word:
248 | copy_row = deepcopy(row)
249 | copy_row['用户问题'] = question.replace(contain_word, s)
250 | target_data.loc[len(target_data)] = copy_row
251 | target_data['有效率'] = [0.99] * len(target_data)
252 | target_data.to_excel(target_file, index=False)
253 |
254 |
255 | def augment_from_simbert(source_file, target_file,
256 | model_file='/data2/hezhenfeng/other_model_files/chinese_simbert_L-6_H-384_A-12',
257 | gpu_id=0,
258 | start_line=0,
259 | end_line=5000,
260 | efficiency=0.95):
261 | """
262 | 利用Simbert生成相似句,取相似度大于95%的句子
263 | :param source_file:
264 | :param target_file:
265 | :param model_file:
266 | :param start_line: 起始行号
267 | :param end_line: 结束行号
268 | :param gpu_id:
269 | :param efficiency:
270 | :return:
271 | """
272 | from nlpcda import Simbert
273 | print('加载simbert模型...')
274 | config = {
275 | 'model_path': model_file,
276 | 'CUDA_VISIBLE_DEVICES': '{}'.format(gpu_id),
277 | 'max_len': 40,
278 | 'seed': 1,
279 | 'device': 'cuda',
280 | 'threshold': efficiency
281 | }
282 | simbert = Simbert(config=config)
283 | raw_data = pd.read_excel(source_file)
284 | target_data = raw_data.drop(raw_data.index)
285 | for idx, row in tqdm(raw_data.iterrows(), total=len(raw_data)):
286 | if start_line <= idx < end_line:
287 | # 用pandas自带的str类的数据会无法复现结果
288 | synonyms = simbert.replace(sent=str(row['用户问题']).strip(), create_num=5)
289 | for synonym, similarity in synonyms:
290 | if similarity >= config['threshold']:
291 | copy_row = deepcopy(row)
292 | copy_row['用户问题'] = synonym
293 | target_data.loc[len(target_data)] = copy_row
294 | else:
295 | break
296 | elif idx >= end_line:
297 | break
298 | target_data['有效率'] = [efficiency] * len(target_data)
299 | target_data.to_excel(target_file, index=False)
300 |
301 |
302 | def function1(args):
303 | """
304 | simbert多进程增强的子进程
305 | :param args:
306 | :return:
307 | """
308 | idx, gpu_id, start, end = args['idx'], args['gpu_id'], args['start'], args['end']
309 | augment_from_simbert(source_file='../data/raw_data/train_denoised.xlsx',
310 | target_file=f'../data/raw_data/train_augment_simbert_{idx}.xlsx',
311 | efficiency=0.95,
312 | start_line=start,
313 | end_line=end,
314 | gpu_id=gpu_id)
315 |
316 |
317 | def run_process():
318 | """
319 | simbert多进程增强的子进程
320 | :return:
321 | """
322 | from multiprocessing import Pool
323 | args = []
324 | # 初始设置为1
325 | cpu_worker_num = 1
326 | span = 5000//cpu_worker_num
327 | start = 0
328 | end = span
329 | for i in range(cpu_worker_num):
330 | args.append({'idx': i, 'gpu_id': 0, 'start': start, 'end': end})
331 | start = end
332 | end += span
333 | if i == cpu_worker_num-2:
334 | end = 5000
335 | with Pool(cpu_worker_num) as p:
336 | p.map(function1, args)
337 | print('生成完毕,开始合并结果...')
338 | df = None
339 | for i in range(len(args)):
340 | temp_df = pd.read_excel(f'../data/raw_data/train_augment_simbert_{i}.xlsx')
341 | if i == 0:
342 | df = temp_df
343 | else:
344 | df = pd.concat([df, temp_df], ignore_index=True)
345 | df.to_excel('../data/raw_data/train_augment_simbert.xlsx', index=False)
346 |
347 |
348 | def multi_process_augment_from_simbert():
349 | """
350 | simbert用多进程来完成
351 | :return:
352 | """
353 | import time
354 | start_time = time.time()
355 | run_process()
356 | print('生成时间为:', get_time_dif(start_time))
357 |
358 |
359 | def augment_for_binary(source_file, target_file):
360 | """
361 | 同义词增强以及随机去除标点符号
362 | :return:
363 | """
364 | from utils import rm_symbol
365 | print('开始进行开通方式、条件数据增强...')
366 | with open('../data/file/synonyms.json', 'r', encoding='utf-8') as f:
367 | synonym_dict = json.load(f)
368 | synonym2entity = synonym_dict['synonym2entity']
369 | for entity in set(synonym2entity.values()):
370 | synonym2entity[entity] = entity
371 | synonym_list = sorted(list(synonym2entity.keys()), key=lambda item: [len(item), item], reverse=True)
372 | raw_data = pd.read_excel(source_file)
373 | target_data = raw_data.drop(raw_data.index)
374 | random.seed(1)
375 | neg_word = ('取消掉', '不需要', '不用', '不想', '不要', '什么时候可以', '可以取消', '可以退订')
376 | for idx, row in tqdm(raw_data.iterrows(), total=len(raw_data)):
377 | copy_row = deepcopy(row)
378 | question_text = str(copy_row['用户问题'])
379 | p = random.random()
380 | if '开通' in row['属性名']:
381 | # 删除这些样本
382 | if question_text in ('20元20g还可以办理吗', ):
383 | continue
384 | # 随机删除逗号
385 | if p > 0.5 and ',' in row['用户问题']:
386 | copy_row['用户问题'] = rm_symbol(question_text)
387 | copy_row['有效率'] = 0.99
388 | target_data.loc[len(target_data)] = copy_row
389 | else:
390 | for synonym in synonym_list:
391 | if synonym in question_text:
392 | rw = random.choice(synonym_list)
393 | if rw != synonym:
394 | question_text.replace(synonym, rw)
395 | break
396 | copy_row['有效率'] = 0.99
397 | copy_row['用户问题'] = question_text
398 | target_data.loc[len(target_data)] = copy_row
399 | elif '取消' in row['属性名']:
400 | flag = False
401 | for n_w in neg_word:
402 | if n_w in question_text:
403 | flag = True
404 | break
405 | if flag:
406 | continue
407 | if '退订' in question_text or '取消' in question_text:
408 | replace_word = '办理' if p >= 0.5 else '开通'
409 | question_text = question_text.replace('退订', replace_word)
410 | question_text = question_text.replace('取消', replace_word)
411 | copy_row['用户问题'] = question_text
412 | copy_row['属性名'] = str(copy_row['属性名']).replace('取消', '开通')
413 | copy_row['有效率'] = 0.9
414 | target_data.loc[len(target_data)] = copy_row
415 | target_data.to_excel(target_file, index=False)
416 |
417 |
418 | def augment_for_ner(source_file, target_file):
419 | synonyms_dict = read_synonyms()
420 | df = pd.read_excel(source_file)
421 | df_syn1 = pd.DataFrame(columns = df.columns)
422 | df_syn2 = pd.DataFrame(columns = df.columns)
423 | import collections
424 | entity_dict = collections.defaultdict(dict)
425 | # 遍历得到对于每个实体,各个约束出现的次数
426 | for index, row in df.iterrows():
427 | if row['约束算子'] == 'min' or row['约束算子'] == 'max':
428 | continue
429 | constraint_names = row['约束属性名']
430 | constraint_values = row['约束属性值']
431 | constraint_names_list = re.split(r'[|\|]', str(constraint_names))
432 | constraint_values_list = re.split(r'[|\|]', str(constraint_values))
433 | constraint_names_list = [name.strip() for name in constraint_names_list]
434 | constraint_values_list = [value.strip() for value in constraint_values_list]
435 | for i in range(len(constraint_values_list)):
436 | name = constraint_names_list[i]
437 | value = constraint_values_list[i]
438 | if name == '有效期': continue
439 | if name == 'nan': continue
440 | if value == '流量套餐': continue
441 | if name + '_' + value not in entity_dict[row['实体']]:
442 | entity_dict[row['实体']][name + '_' + value] = []
443 | entity_dict[row['实体']][name + '_' + value].append(index)
444 | # 对约束次数少的实体进行同义词替换增强
445 | for entity, cons in entity_dict.items():
446 | # 喜马拉雅的ner增强会因错别字无法标注
447 | if entity == '喜马拉雅流量包': continue
448 | if entity not in synonyms_dict: continue
449 | for con, con_list in cons.items():
450 | if len(con_list) < 3:
451 | for index in con_list:
452 | question = df.iloc[index].at['用户问题']
453 | syn_curr = None
454 | syn_curr_list = []
455 | for syn in synonyms_dict[entity] + [entity]:
456 | if syn == '无': continue
457 | if syn in question:
458 | syn_curr_list.append(syn)
459 | syn_curr_list = sorted(syn_curr_list, key=lambda x:len(x))
460 | syn_curr = syn_curr_list[-1]
461 | for syn in synonyms_dict[entity] + [entity]:
462 | if syn == '无': continue
463 | new_row = df.iloc[index].copy()
464 | new_row.at['用户问题'] = question.replace(syn_curr, syn)
465 | df_syn1 = df_syn1.append(new_row, ignore_index=True)
466 | for index, row in df.iterrows():
467 | entity = row['实体']
468 | question = row['用户问题']
469 | if entity not in synonyms_dict: continue
470 | syn_curr = None
471 | syn_curr_list = []
472 | for syn in synonyms_dict[entity] + [entity]:
473 | if syn == '无': continue
474 | if syn in question:
475 | syn_curr_list.append(syn)
476 | if len(syn_curr_list) == 0: continue
477 | syn_curr_list = sorted(syn_curr_list, key=lambda x:len(x))
478 | syn_curr = syn_curr_list[-1]
479 | if any(char.isdigit() for char in syn_curr) or '年包' in syn_curr:
480 | df_syn2 = df_syn2.append(row.copy(), ignore_index=True)
481 | df_syn2 = df_syn2.append(row.copy(), ignore_index=True)
482 | df = df.append(df_syn1, ignore_index=True)
483 | df = df.append(df_syn2, ignore_index=True)
484 | df.to_excel(target_file, index=False)
485 |
486 |
487 | if __name__ == '__main__':
488 | # denoising(source_file='../data/raw_data/test.xlsx', target_file='../data/raw_data/test_denoised.xlsx')
489 | # synonyms_dict = read_synonyms()
490 | # create_argument_data(synonyms_dict)
491 | # setup_seed(1)
492 | # augment_for_few_data(source_file='../data/raw_data/train_denoised.xlsx',
493 | # target_file='../data/raw_data/train_augment_few_nlpcda.xlsx',
494 | # efficiency=0.8)
495 | # augment_for_synonyms(source_file='../data/raw_data/train_denoised.xlsx',
496 | # target_file='../data/raw_data/train_augment_synonyms_test2.xlsx')
497 | # multi_process_augment_from_simbert()
498 | # augment_from_simbert(source_file='../data/raw_data/train_denoised.xlsx',
499 | # target_file=f'../data/raw_data/train_augment_simbert.xlsx',
500 | # efficiency=0.95,
501 | # start_line=0,
502 | # end_line=5000,
503 | # gpu_id=8)
504 | augment_for_binary(source_file='../data/raw_data/train_denoised.xlsx',
505 | target_file='../data/raw_data/train_augment_binary.xlsx')
506 |
--------------------------------------------------------------------------------
/code/doc2unix.py:
--------------------------------------------------------------------------------
1 | if __name__ == '__main__':
2 | with open('./run.sh', 'r') as f:
3 | text = f.read()
4 | text.replace('\r', '')
5 |
6 | with open('./run.sh', 'w') as f:
7 | f.write(text)
8 |
--------------------------------------------------------------------------------
/code/docker_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | from data_argument import *
5 | from statistic import *
6 | from utils import *
7 |
8 | def prepare_dir():
9 | '''
10 | 用于生成存放数据的文件夹
11 | @return:
12 | '''
13 | if not os.path.exists('/data'):
14 | os.mkdir('/data')
15 | if not os.path.exists('/data/raw_data'):
16 | os.mkdir('/data/raw_data')
17 | if not os.path.exists('/data/process_data'):
18 | os.mkdir('/data/process_data')
19 | if not os.path.exists('/data/file'):
20 | os.mkdir('/data/file')
21 | if not os.path.exists('/data/dataset'):
22 | os.mkdir('/data/dataset')
23 | if not os.path.exists('/data/Log'):
24 | os.mkdir('/data/Log')
25 | if not os.path.exists('/data/trained_model'):
26 | os.mkdir('/data/trained_model')
27 | if not os.path.exists('/data/results'):
28 | os.mkdir('/data/results')
29 |
30 |
31 | def prepare_copy_file():
32 | '''
33 | 为统一路径,将天池数据/tcdata中的内容复制到/data/raw_data中
34 | @return:
35 | '''
36 | if os.path.exists('/tcdata'):
37 | for file_name in os.listdir('/tcdata'):
38 | shutil.copy(os.path.join('/tcdata', file_name), '/data/raw_data')
39 |
40 | # # 查看是否复制成功
41 | # raw_data = os.listdir('/data/raw_data')
42 | # for file_name in os.listdir('/tcdata'):
43 | # if file_name not in raw_data:
44 | # print('复制失败: ', file_name)
45 |
46 |
47 | def prepare_data():
48 | '''
49 | 一系列文件生成工作
50 | @return:
51 | '''
52 | # 生成denoiseed.xlsx
53 | denoising(source_file='../data/raw_data/train.xlsx',
54 | target_file='../data/raw_data/train_denoised.xlsx')
55 | denoising(source_file='../tcdata/test2.xlsx',
56 | target_file='../data/raw_data/test_denoised.xlsx')
57 | # 生成统计文件:实体个数、同义词个数
58 | statistic_entity()
59 | statistic_synonyms()
60 | # entity_map()
61 | # 生成 数据增强 文件
62 | augment_for_few_data(source_file='../data/raw_data/train_denoised.xlsx',
63 | target_file='../data/raw_data/train_augment_few_nlpcda.xlsx')
64 | multi_process_augment_from_simbert()
65 | augment_for_synonyms(source_file='../data/raw_data/train_denoised.xlsx',
66 | target_file='../data/raw_data/train_augment_synonyms.xlsx')
67 | # 生成标注好的训练数据
68 | make_dataset(['../data/raw_data/train_denoised.xlsx'],
69 | target_file='../data/dataset/cls_labeled.txt',
70 | label_file='../data/dataset/cls_label2id.json',
71 | train=True)
72 | # 将增强文件合并
73 | make_dataset(['../data/raw_data/train_augment_few_nlpcda.xlsx',
74 | '../data/raw_data/train_augment_simbert.xlsx',
75 | '../data/raw_data/train_augment_synonyms.xlsx'],
76 | target_file='../data/dataset/augment3.txt', # 修改此处应该修改run.sh中的文件
77 | label_file=None,
78 | train=True)
79 | make_dataset(['../data/raw_data/test_denoised.xlsx'],
80 | target_file='../data/dataset/cls_unlabeled.txt',
81 | label_file=None,
82 | train=False)
83 | # ner数据增强
84 | augment_for_ner(source_file='../data/raw_data/train_denoised.xlsx',
85 | target_file='../data/raw_data/train_denoised_ner.xlsx')
86 | # 制作二分类训练数据
87 | augment_for_binary(source_file='../data/raw_data/train_denoised.xlsx',
88 | target_file='../data/raw_data/train_augment_binary.xlsx')
89 | make_dataset_for_binary(['../data/raw_data/train_denoised.xlsx'],
90 | target_file='../data/dataset/binary_labeled.txt')
91 | make_dataset_for_binary(['../data/raw_data/train_augment_binary.xlsx'],
92 | target_file='../data/dataset/binary_augment3.txt')
93 | # 制作rdf文件
94 | parse_triples_file('../data/raw_data/triples.txt')
95 |
96 |
97 | if __name__ == '__main__':
98 | print('docker运行开始,时间为', get_time_str())
99 | print('---------------------开始process-----------------------')
100 | setup_seed(1) # todo
101 | prepare_dir()
102 | prepare_copy_file()
103 | prepare_data()
104 | print('---------------------process结束-----------------------')
105 | end = time.time()
106 |
107 |
--------------------------------------------------------------------------------
/code/module.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: hezf
3 | @Time: 2021/6/3 19:39
4 | @desc:
5 | """
6 | from sklearn.metrics import f1_score, jaccard_score, confusion_matrix
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch
10 | import numpy as np
11 | from typing import List
12 |
13 |
14 | class MutiTaskLoss(nn.Module):
15 | def __init__(self, use_efficiency=False):
16 | """
17 | :param use_efficiency: 样本有效率标志,需要搭配起forward中的efficiency参数
18 | """
19 | super(MutiTaskLoss, self).__init__()
20 | self.use_efficiency = use_efficiency
21 | if use_efficiency:
22 | self.ans_loss = nn.CrossEntropyLoss(reduction='none')
23 | else:
24 | self.ans_loss = nn.CrossEntropyLoss()
25 | self.prop_loss = FocalLoss(alpha=[0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996],
26 | multi_label=True)
27 | # self.prop_loss = nn.BCELoss()
28 | if use_efficiency:
29 | self.entity_loss = nn.CrossEntropyLoss(reduction='none')
30 | else:
31 | self.entity_loss = nn.CrossEntropyLoss()
32 | # 可学习的权重
33 | self.loss_weight = nn.Parameter(torch.ones(3), requires_grad=True)
34 |
35 | def forward(self, ans_true, ans_pred, prop_true, prop_pred, entity_true, entity_pred, efficiency=None):
36 | # 使用样本有效率
37 | if self.use_efficiency:
38 | assert efficiency is not None, 'efficiency is None'
39 | batch_size = ans_true.shape[0]
40 | loss1 = self.ans_loss(ans_pred, ans_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size
41 | loss2 = self.prop_loss(prop_pred, prop_true.float(), efficiency)
42 | loss3 = self.entity_loss(entity_pred, entity_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size
43 | else:
44 | loss1 = self.ans_loss(ans_pred, ans_true)
45 | loss2 = self.prop_loss(prop_pred, prop_true.float())
46 | loss3 = self.entity_loss(entity_pred, entity_true)
47 | weight = self.loss_weight.softmax(dim=0)
48 | loss = weight[0] * loss1 + weight[1]*loss2 + weight[2]*loss3
49 | return loss
50 |
51 |
52 | class MutiTaskLossV1(nn.Module):
53 | def __init__(self, use_efficiency=False):
54 | """
55 | :param use_efficiency: 样本有效率标志,需要搭配起forward中的efficiency参数
56 | """
57 | super(MutiTaskLossV1, self).__init__()
58 | self.use_efficiency = use_efficiency
59 | if use_efficiency:
60 | self.ans_loss = nn.CrossEntropyLoss(reduction='none')
61 | self.entity_loss = nn.CrossEntropyLoss(reduction='none')
62 | self.method_loss = nn.CrossEntropyLoss(reduction='none')
63 | self.condition_loss = nn.CrossEntropyLoss(reduction='none')
64 | else:
65 | self.ans_loss = nn.CrossEntropyLoss()
66 | self.entity_loss = nn.CrossEntropyLoss()
67 | self.method_loss = nn.CrossEntropyLoss()
68 | self.condition_loss = nn.CrossEntropyLoss()
69 | self.prop_loss = FocalLoss(alpha=[0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996],
70 | multi_label=True)
71 | # self.prop_loss = nn.BCELoss()
72 | # 可学习的权重
73 | self.loss_weight = nn.Parameter(torch.ones(4), requires_grad=True)
74 |
75 | def forward(self, ans_true, ans_pred, prop_true, prop_pred, entity_true, entity_pred, method_true, method_pred, condition_true, condition_pred, efficiency=None):
76 | # 使用样本有效率
77 | if self.use_efficiency:
78 | assert efficiency is not None, 'efficiency is None'
79 | batch_size = ans_true.shape[0]
80 | loss1 = self.ans_loss(ans_pred, ans_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size
81 | loss2 = self.prop_loss(prop_pred, prop_true.float(), efficiency)
82 | loss3 = self.entity_loss(entity_pred, entity_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size
83 | loss4 = self.method_loss(method_pred, method_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size
84 | loss5 = self.condition_loss(condition_pred, condition_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size
85 | else:
86 | loss1 = self.ans_loss(ans_pred, ans_true)
87 | loss2 = self.prop_loss(prop_pred, prop_true.float())
88 | loss3 = self.entity_loss(entity_pred, entity_true)
89 | loss4 = self.method_loss(method_pred, method_true)
90 | loss5 = self.condition_loss(condition_pred, condition_true)
91 | weight = self.loss_weight.softmax(dim=0)
92 | loss = weight[0] * loss1 + weight[1]*loss2 + weight[2]*loss3 + weight[3]*(loss4 + loss5)/2
93 | return loss
94 |
95 |
96 | class CrossEntropyWithEfficiency(nn.Module):
97 | def __init__(self):
98 | super(CrossEntropyWithEfficiency, self).__init__()
99 | self.loss = nn.CrossEntropyLoss(reduction='none')
100 |
101 | def forward(self, true, pred, efficiency):
102 | batch_size = true.shape[0]
103 | loss = self.loss(pred, true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0) / batch_size
104 | return loss
105 |
106 | # # rdrop 加入有效率前的备份
107 | # class RDropLoss(nn.Module):
108 | # def __init__(self, alpha=4):
109 | # super(RDropLoss, self).__init__()
110 | # self.ans_loss = nn.CrossEntropyLoss()
111 | # self.prop_loss = FocalLoss(alpha=[0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996],
112 | # multi_label=True)
113 | # self.entity_loss = nn.CrossEntropyLoss()
114 | # self.kl = nn.KLDivLoss(reduction='batchmean')
115 | # self.alpha = alpha
116 | # # 可学习的权重
117 | # self.loss_weight = nn.Parameter(torch.ones(3), requires_grad=True)
118 | #
119 | # def r_drop(self, loss_func, pred: List, true, multi_label=True):
120 | # loss_0 = loss_func(pred[0], true)
121 | # loss_1 = loss_func(pred[1], true)
122 | # if multi_label:
123 | # kl_loss = (F.kl_div(pred[0].log(), pred[1], reduction='batchmean') + F.kl_div(pred[1].log(), pred[0], reduction='batchmean')) / 2
124 | # else:
125 | # kl_loss = (F.kl_div(F.log_softmax(pred[0], -1), F.softmax(pred[1], -1), reduction='batchmean') + F.kl_div(F.log_softmax(pred[1], -1), F.softmax(pred[0], -1), reduction='batchmean')) / 2
126 | # return loss_0 + loss_1 + self.alpha * kl_loss
127 | #
128 | # def forward(self, ans_true, ans_pred: List, prop_true, prop_pred: List, entity_true, entity_pred: List):
129 | # loss1 = self.r_drop(self.ans_loss, ans_pred, ans_true, multi_label=False)
130 | # # loss1 = (self.ans_loss(ans_pred[0], ans_true) + self.ans_loss(ans_pred[1], ans_true))/2 # prop_only
131 | # loss2 = self.r_drop(self.prop_loss, prop_pred, prop_true.float(), multi_label=True)
132 | # loss3 = self.r_drop(self.entity_loss, entity_pred, entity_true, multi_label=False)
133 | # # loss3 = (self.entity_loss(entity_pred[0], entity_true) + self.entity_loss(entity_pred[1], entity_true))/2 # prop_only
134 | # weight = self.loss_weight.softmax(dim=0)
135 | # loss = weight[0] * loss1 + weight[1]*loss2 + weight[2]*loss3
136 | # return loss
137 |
138 |
139 | class RDropLoss(nn.Module):
140 | def __init__(self, alpha=4):
141 | super(RDropLoss, self).__init__()
142 | self.ans_loss = nn.CrossEntropyLoss(reduction='none')
143 | self.prop_loss = FocalLoss(alpha=[0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996],
144 | multi_label=True)
145 | self.entity_loss = nn.CrossEntropyLoss(reduction='none')
146 | self.kl = nn.KLDivLoss(reduction='batchmean')
147 | self.alpha = alpha
148 | # 可学习的权重
149 | self.loss_weight = nn.Parameter(torch.ones(3), requires_grad=True)
150 |
151 | def r_drop(self, loss_func, pred: List, true, efficiency, multi_label=True):
152 | if efficiency is None:
153 | efficiency = torch.tensor([1.0] * pred[0].shape[0])
154 | # 多标签时,是用focalloss
155 | if multi_label:
156 | loss_0 = loss_func(pred[0], true, efficiency)
157 | loss_1 = loss_func(pred[1], true, efficiency)
158 | else:
159 | batch_size = pred[0].shape[0]
160 | loss_0 = loss_func(pred[0], true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size
161 | loss_1 = loss_func(pred[1], true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size
162 | if multi_label:
163 | kl_loss = (F.kl_div(pred[0].log(), pred[1], reduction='batchmean') + F.kl_div(pred[1].log(), pred[0], reduction='batchmean')) / 2
164 | else:
165 | kl_loss = (F.kl_div(F.log_softmax(pred[0], -1), F.softmax(pred[1], -1), reduction='batchmean') + F.kl_div(F.log_softmax(pred[1], -1), F.softmax(pred[0], -1), reduction='batchmean')) / 2
166 | return loss_0 + loss_1 + self.alpha * kl_loss
167 |
168 | def forward(self, ans_true, ans_pred: List, prop_true, prop_pred: List, entity_true, entity_pred: List, efficiency):
169 | loss1 = self.r_drop(self.ans_loss, ans_pred, ans_true, efficiency, multi_label=False)
170 | loss2 = self.r_drop(self.prop_loss, prop_pred, prop_true.float(), efficiency, multi_label=True)
171 | loss3 = self.r_drop(self.entity_loss, entity_pred, entity_true, efficiency, multi_label=False)
172 | weight = self.loss_weight.softmax(dim=0)
173 | loss = weight[0] * loss1 + weight[1]*loss2 + weight[2]*loss3
174 | return loss
175 |
176 |
177 | class MutiTaskLossFocal(nn.Module):
178 | def __init__(self):
179 | super(MutiTaskLossFocal, self).__init__()
180 | self.ans_loss = FocalLoss()
181 | self.prop_loss = FocalLoss(alpha=[0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996],
182 | multi_label=True)
183 | self.entity_loss = FocalLoss()
184 | # 可学习的权重
185 | self.loss_weight = nn.Parameter(torch.ones(3), requires_grad=True)
186 |
187 | def forward(self, ans_true, ans_pred, prop_true, prop_pred, entity_true, entity_pred):
188 | loss1 = self.ans_loss(ans_pred, ans_true)
189 | loss2 = self.prop_loss(prop_pred, prop_true.float())
190 | loss3 = self.entity_loss(entity_pred, entity_true)
191 | weight = self.loss_weight.softmax(dim=0)
192 | loss = weight[0] * loss1 + weight[1]*loss2 + weight[2]*loss3
193 | return loss
194 |
195 |
196 | class Metric(object):
197 | def __init__(self):
198 | super(Metric, self).__init__()
199 |
200 | def calculate(self, ans_pred, ans_true, prop_pred, prop_true, entity_pred, entity_true):
201 | ans_f1 = f1_score(ans_true, ans_pred, average='micro')
202 | # ans_matrix = confusion_matrix(ans_true, ans_pred)
203 | prop_j = jaccard_score(prop_true, prop_pred, average='micro')
204 | entity_f1 = f1_score(entity_true, entity_pred, average='micro')
205 | return ans_f1, prop_j, entity_f1
206 |
207 |
208 | class Attention(nn.Module):
209 | def __init__(self, hidden_size):
210 | super(Attention, self).__init__()
211 | self.linear = nn.Linear(hidden_size, 1)
212 |
213 | def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor):
214 | x = self.linear(hidden_states).squeeze(-1)
215 | x = x.masked_fill(mask, -np.inf)
216 | attention_value = x.softmax(dim=-1).unsqueeze(1)
217 | x = torch.bmm(attention_value, hidden_states).squeeze(1)
218 | return x
219 |
220 |
221 | class FocalLoss(nn.Module):
222 | def __init__(self, gamma=2, alpha=1.0, size_average=True, multi_label=False):
223 | super(FocalLoss, self).__init__()
224 | self.gamma = gamma
225 | self.alpha = alpha
226 | if isinstance(self.alpha, list):
227 | self.alpha = torch.tensor(self.alpha)
228 | self.size_average = size_average
229 | self.multi_label = multi_label
230 |
231 | def forward(self, logits, labels, efficiency=None):
232 | """
233 | logits: batch_size * n_class
234 | labels: batch_size
235 | """
236 | batch_size, n_class = logits.shape[0], logits.shape[1]
237 | # 多分类分类
238 | if not self.multi_label:
239 | one_hots = torch.zeros([batch_size, n_class]).to(logits.device).scatter_(-1, labels.unsqueeze(-1), 1)
240 | p = torch.nn.functional.softmax(logits, dim=-1)
241 | log_p = torch.log(p)
242 | loss = - one_hots * (self.alpha * ((1 - p) ** self.gamma) * log_p)
243 | # 多标签分类
244 | else:
245 | p = logits
246 | pt = (labels - (1 - p)) * (2*labels-1)
247 | if isinstance(self.alpha, float):
248 | alpha_t = (labels - (1 - self.alpha)) * (2*labels-1)
249 | else:
250 | alpha_t = (labels - (1 - self.alpha.to(logits.device))) * (2*labels-1)
251 | loss = - alpha_t * ((1 - pt)**self.gamma) * torch.log(pt)
252 | # 加入有效率的计算
253 | if efficiency is not None:
254 | loss = torch.diag_embed(efficiency).mm(loss)
255 | if self.size_average:
256 | return loss.sum()/batch_size
257 | else:
258 | return loss.sum()
259 |
260 |
261 | class FGM(object):
262 | """
263 | 对抗攻击
264 | """
265 | def __init__(self, model: nn.Module, eps=1.):
266 | self.model = (
267 | model.module if hasattr(model, "module") else model
268 | )
269 | self.eps = eps
270 | self.backup = {}
271 |
272 | # only attack word embedding
273 | def attack(self, emb_name='word_embeddings'):
274 | for name, param in self.model.named_parameters():
275 | if param.requires_grad and emb_name in name:
276 | self.backup[name] = param.data.clone()
277 | norm = torch.norm(param.grad)
278 | if norm and not torch.isnan(norm):
279 | r_at = self.eps * param.grad / norm
280 | param.data.add_(r_at)
281 |
282 | def restore(self, emb_name='word_embeddings'):
283 | for name, para in self.model.named_parameters():
284 | if para.requires_grad and emb_name in name:
285 | assert name in self.backup
286 | para.data = self.backup[name]
287 |
288 | self.backup = {}
289 |
--------------------------------------------------------------------------------
/code/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | mode='multi'
3 | debug_status='no'
4 |
5 | if test $debug_status = 'yes'
6 | then
7 | batch_size1=16
8 | batch_size2=16
9 | batch_size3=16
10 | ptm_epoch=1
11 | ptm_epoch_binary=1
12 | train_epoch=1
13 | train_epoch_binary=1
14 | pt_file=augment3.txt
15 | elif test $debug_status = 'no'
16 | then
17 | batch_size1=350
18 | batch_size2=256
19 | batch_size3=190
20 | ptm_epoch=14
21 | ptm_epoch_binary=5
22 | train_epoch=5
23 | train_epoch_binary=5
24 | pt_file=augment3.txt
25 | fi
26 | # TODO 注意回车标志只能是LF
27 | # 单模型
28 | if test $mode = 'single'
29 | then
30 | python docker_process.py
31 | # BERT等模型的batch_size=350,XLNet的batch_size=256
32 | # GPT2batch_size=180暂定,学习率1e-4
33 | # bert+rdrop的batch_size=180
34 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm roberta -d 0 -m roberta -b $batch_size1 -l 0.0002
35 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm roberta -d 0 -m roberta_model -b $batch_size1 -l 0.00005 -tm pretrained_roberta.pth
36 | python train_evaluate_constraint.py -m ner_model.pth -d 0 -tv bieo -k 1
37 | # 修改模型时,这里模型名也要修改
38 | python inference.py -cm roberta_model.pth -nm ner_model.pth -tv bieo -k 1
39 | # 多模型
40 | elif test $mode = 'multi'
41 | then
42 | python docker_process.py
43 | # bert
44 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm bert -d 0 -m bert -b $batch_size1 -l 0.0002
45 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm bert -d 0 -m bert_model -b $batch_size1 -l 0.00005 -tm pretrained_bert.pth
46 | # roberta
47 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm roberta -d 0 -m roberta -b $batch_size1 -l 0.0002
48 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm roberta -d 0 -m roberta_model -b $batch_size1 -l 0.00005 -tm pretrained_roberta.pth
49 | # electra
50 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm electra -d 0 -m electra -b $batch_size1 -l 0.0002
51 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm electra -d 0 -m electra_model -b $batch_size1 -l 0.00005 -tm pretrained_electra.pth
52 | # xlnet
53 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm xlnet -d 0 -m xlnet -b $batch_size2 -l 0.0002
54 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm xlnet -d 0 -m xlnet_model -b $batch_size2 -l 0.00005 -tm pretrained_xlnet.pth
55 | # gpt2
56 | #python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm gpt2 -d 0 -m gpt2 -b $batch_size3 -l 0.0001
57 | #python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm gpt2 -d 0 -m gpt2_model -b $batch_size3 -l 0.00005 -tm pretrained_gpt2.pth
58 | # macbert
59 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm macbert -d 0 -m macbert -b $batch_size1 -l 0.0002
60 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm macbert -d 0 -m macbert_model -b $batch_size1 -l 0.00005 -tm pretrained_macbert.pth
61 | # albert
62 | #python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm albert -d 0 -m albert -b $batch_size1 -l 0.001
63 | #python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm albert -d 0 -m albert_model -b $batch_size1 -l 0.00025 -tm pretrained_albert.pth
64 | # 二分类模型
65 | python train.py -tt only_binary -e $ptm_epoch_binary -tf binary_augment3.txt -ptm bert -d 0 -m bert -b 32 -l 0.00005
66 | python train.py -tt binary -e $train_epoch_binary -tf binary_labeled.txt -ptm bert -d 0 -m bert_binary -b 32 -l 0.00001 -tm pretrained_binary_bert.pth
67 | # ner
68 | python train_evaluate_constraint.py -m ner_model.pth -d 0 -tv bieo -k 5
69 | # inference
70 | python inference.py -cm bert_model.pth -nm ner_model.pth -tv bieo -e 1 -de 0 -b 4 -xm xlnet_model.pth -em electra_model.pth -rm roberta_model.pth -mm macbert_model.pth -k 5 -wb
71 | elif test $mode = 'orig'
72 | then
73 | python docker_process.py
74 | # BERT等模型的batch_size=350,XLNet的batch_size=256
75 | python train.py -tt kfold -e 50 -tf cls_labeled.txt -ptm bert -d 0 -m bert_model -b 350 -l 0.0002
76 | # eh为标注增强,可设为0
77 | python train_evaluate_constraint.py -m ner_model.pth -d 0 -tv bieo -k 5
78 | # 修改模型时,这里模型名也要修改, eh为标注增强,可设为0
79 | python inference.py -cm bert_model.pth -nm ner_model.pth -tv bieo -dp ../data/dataset/cls_unlabeled.txt -k 5
80 | else
81 | echo "其他"
82 | fi
83 |
--------------------------------------------------------------------------------
/code/statistic.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: hezf
3 | @Time: 2021/6/19 17:20
4 | @desc: 统计数据模块
5 | """
6 | import os
7 | import pandas as pd
8 | import re
9 | import json
10 |
11 |
12 | def statistic_entity(labeled_file: str = '../data/raw_data/train.xlsx'):
13 | """
14 | 统计”实体“分布
15 | :param labeled_file:
16 | :return:
17 | """
18 | labeled_data = pd.read_excel(labeled_file)
19 | entity_list = list(labeled_data.loc[:, '实体'])
20 | entity_count = {}
21 | for entity in entity_list:
22 | if entity not in entity_count:
23 | entity_count[entity] = 0
24 | entity_count[entity] += 1
25 | sorted_dict = sorted(entity_count.items(), key=lambda item: (item[1], item[0]), reverse=True)
26 | # for entity, count in sorted_dict:
27 | # print('实体:{}, 数量:{}'.format(entity, count))
28 | with open('../data/file/entity_count.json', 'w', encoding='utf-8') as f:
29 | json.dump(entity_count, f, ensure_ascii=False, indent=2)
30 |
31 |
32 | def statistic_property(labeled_file: str = '../data/raw_data/train.xlsx'):
33 | """
34 | 统计”属性名“分布
35 | :param labeled_file:
36 | :return:
37 | """
38 | labeled_data = pd.read_excel(labeled_file)
39 | property_list = list(labeled_data.loc[:, '属性名'])
40 | property_count = {}
41 | for prop in property_list:
42 | prop = prop.split('|')
43 | for p in prop:
44 | if p not in property_count:
45 | property_count[p] = 0
46 | property_count[p] += 1
47 | sorted_dict = sorted(property_count.items(), key=lambda item: (item[1], item[0]), reverse=True)
48 | for prop, count in sorted_dict:
49 | print('属性名:{}, 数量:{}'.format(prop, count))
50 | with open('../data/file/prop_count.json', 'w', encoding='utf-8') as f:
51 | json.dump(property_count, f, ensure_ascii=False, indent=2)
52 |
53 |
54 | def statistic_frequent_char(label_file: str = '../data/dataset/cls_labeled.txt'):
55 | """
56 | 统计频繁使用的字符,据此推断停用词
57 | :param label_file:
58 | :return:
59 | """
60 | char_count = {}
61 | less_5 = set()
62 | with open(label_file, 'r', encoding='utf-8') as f:
63 | for line in f:
64 | line = line.strip()
65 | question = line.split('\t')[0]
66 | for c in question:
67 | if c not in char_count:
68 | char_count[c] = 0
69 | char_count[c] += 1
70 | sorted_dict = sorted(char_count.items(), key=lambda item: (item[1], item[0]), reverse=True)
71 | for c, count in sorted_dict:
72 | if count < 10:
73 | less_5.add(c)
74 | print('字符:{}, 数量:{}'.format(c, count))
75 | print('少于5个字符的有:')
76 | for c in less_5:
77 | print(c)
78 |
79 |
80 | def get_all_service(data):
81 | """
82 | 得到子业务名称
83 | @param data: excel读取的dataframe数据
84 | @return: 包含所有子业务名称的list
85 | """
86 | service_name = set()
87 | for index, row in data.iterrows():
88 | constraint_names = row['约束属性名']
89 | constraint_values = row['约束属性值']
90 | constraint_names_list = re.split(r'[|\|]', str(constraint_names))
91 | constraint_values_list = re.split(r'[|\|]', str(constraint_values))
92 | constraint_names_list = [name.strip() for name in constraint_names_list]
93 | constraint_values_list = [value.strip() for value in constraint_values_list]
94 | for i in range(len(constraint_names_list)):
95 | if constraint_names_list[i] == '子业务':
96 | service_name.add(constraint_values_list[i])
97 | return list(service_name)
98 |
99 |
100 | def statistic_duplicate_constraint(labeled_file: str):
101 | """
102 | 统计约束属性值的子业务中,值相同的有哪些 // 结果 10条,全是'合约版'
103 | @param labeled_file:
104 | @return:
105 | """
106 | duplicate = 0
107 |
108 | labeled_data = pd.read_excel(labeled_file)
109 | con_names = labeled_data['约束属性名'].tolist()
110 | con_values = labeled_data['约束属性值'].tolist()
111 | con_names = [re.split(r'[|\|]', str(item)) for item in con_names]
112 | con_values = [re.split(r'[|\|]', str(item)) for item in con_values]
113 | for index in range(len(con_names)):
114 | services = []
115 | con_name = con_names[index]
116 | con_value = con_values[index]
117 | for name_id in range(len(con_name)):
118 | if con_name[name_id] == '子业务':
119 | if con_value[name_id] in services:
120 | print(services)
121 | duplicate += 1
122 | services.append(con_value[name_id])
123 | print(duplicate)
124 |
125 |
126 | def statistic_synonyms():
127 | """
128 | 统计包含的同义词,英文全小写。并将清洗后的同义词写入json文件中
129 | :return:
130 | """
131 | def in_ner_set(word):
132 | for ner in ner_set:
133 | if ner in word:
134 | return True
135 | return False
136 | ner_set = {'咪咕直播流量包', '上网版', '成员', '咪咕', '热剧vip', 'plus版', 'pptv', '优酷', '百度', '腾讯视频', '2020版',
137 | '大陆及港澳台版', '芒果', '乐享版', '小额版', '年包', '体验版-12个月', '电影vip', '合约版', '优酷会员', '王者荣耀',
138 | '长期版', '宝藏版', '免费版', '流量套餐', '乐视会员', '2019版', '基础版', '月包', '全球通版', '畅享包', '腾讯',
139 | '爱奇艺会员', '喜马拉雅', '普通版', '流量包', '芒果tv', '百度包', '24月方案', '2018版', '个人版', '半年包',
140 | '咪咕流量包', '爱奇艺', '阶梯版', '12月方案', '网易', '家庭版', '198', '5', '12', '160', '20', '220', '398', '15',
141 | '200', '24', '30', '3', '288', '18', '238', '120',
142 | '128', '60', '9', '10', '300', '188', '59', '68', '8', '38', '2', '80', '680', '70', '158', '1',
143 | '380', '298', '11', '65', '40', '19', '23', '29', '6', '99', '500', '22', '49', '100', '40', '700', '110',
144 | '20', '100', '500', '1', '30', '300'}
145 | # 有实体冲突的同义词:(语音留言、家庭亲情号、手机视频流量包)
146 | conflict = ('语音留言', '家庭亲情号', '手机视频流量包', '流量包')
147 | synonym_list = []
148 | synonym2entity = dict()
149 | for_cls = True
150 | with open('../data/raw_data/synonyms.txt', 'r', encoding='utf-8') as f:
151 | for line in f:
152 | blocks = line.strip().lower().split(' ')
153 | entity, temp_synonym_list = blocks[0], blocks[1].split('|')
154 | if for_cls or not in_ner_set(entity):
155 | for s in temp_synonym_list:
156 | s = s.strip()
157 | if s != '无' and s != '' and s not in conflict:
158 | # 为分类模型考虑的同义词增强
159 | if for_cls or not in_ner_set(s):
160 | synonym2entity[s] = entity
161 | synonym_list.append(s)
162 | synonym_list.sort(key=lambda item: len(item), reverse=True)
163 | if not for_cls:
164 | remove_synonym = list()
165 | for i, synonym in enumerate(synonym_list):
166 | for j in range(i):
167 | if synonym in synonym_list[j] and synonym2entity[synonym] != synonym2entity[synonym_list[j]]:
168 | # 冲突则移除该词,降低误差
169 | print('冲突:', synonym, synonym_list[j])
170 | remove_synonym += [synonym, synonym_list[j]]
171 | # print('需要移除的同义词有:', set(remove_synonym))
172 | for r_s in set(remove_synonym):
173 | synonym2entity.pop(r_s)
174 | entity2synonym = {}
175 | for s, e in synonym2entity.items():
176 | if e not in entity2synonym:
177 | entity2synonym[e] = []
178 | entity2synonym[e].append(s)
179 | with open('../data/file/synonyms{}.json'.format('' if for_cls else '_ner'), 'w', encoding='utf-8') as f:
180 | json.dump({'entity2synonym': entity2synonym, 'synonym2entity': synonym2entity},
181 | f,
182 | ensure_ascii=False,
183 | indent=2)
184 | print('同义词写入成功...')
185 |
186 |
187 | def statistic_prop_labels():
188 | """
189 | 统计属性名的占样本的比例
190 | [0.241, 0.32, 0.0542, 0.0564, 0.0384, 0.085, 0.0382, 0.129, 0.1276, 0.006, 0.014, 0.006, 0.0014, 0.007, 0.009, 0.0032, 0.0014, 0.001, 0.0004]
191 | [0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996]
192 |
193 | [0.75, 0.91, 0.87, 0.67, 0.99, 0.94, 0.87, 0.96, 0.99, 0.99, 0.99, 0.96, 0.94, 0.98]
194 | """
195 | with open('../data/dataset/cls_label2id_fewer.json', 'r', encoding='utf-8') as f:
196 | label2id = json.load(f)['main_property']
197 | count = 0
198 | prop_count = [0] * len(label2id)
199 | alpha_count = []
200 | with open('../data/dataset/cls_labeled_fewer.txt', 'r', encoding='utf-8') as f:
201 | for line in f:
202 | count += 1
203 | blocks = line.strip().split('\t')
204 | props = blocks[2].split('|')
205 | for p in props:
206 | if p not in label2id:
207 | print(line)
208 | prop_count[label2id[p]] += 1
209 | for i in range(len(prop_count)):
210 | prop_count[i] /= count
211 | alpha_count.append(1 - prop_count[i])
212 | print(prop_count)
213 | print(alpha_count)
214 |
215 |
216 | def statistic_service(labeled_file: str = '../data/raw_data/train_denoised.xlsx'):
217 | """
218 |
219 | :param labeled_file:
220 | :return:
221 | """
222 | labeled_data = pd.read_excel(labeled_file)
223 | name_list = list(labeled_data.loc[:, '约束属性名'])
224 | value_list = list(labeled_data.loc[:, '约束属性值'])
225 | service = set()
226 | price = set()
227 | flow = set()
228 | for i, name in enumerate(name_list):
229 | n_list = str(name).split('|')
230 | v_list = str(value_list[i]).split('|')
231 | for j, n in enumerate(n_list):
232 | if n == '子业务':
233 | service.add(v_list[j])
234 | elif n == '流量':
235 | flow.add(v_list[j])
236 | elif n == '价格':
237 | price.add(v_list[j])
238 | print('所有的子业务:', service)
239 | print('所有的价格:', price)
240 | print('所有的流量:', flow)
241 |
242 |
243 | def statistic_wrong_rdf():
244 | '''
245 | 向kg输入train.xlsx前面的字段,输出预测结果,并和给出的答案做对比
246 | @return: train_wrong_triple.xlsx : 保存对比结果
247 | '''
248 | print('正在获取rdf结果并对比')
249 | from triples import KnowledgeGraph
250 |
251 | kg = KnowledgeGraph('../data/process_data/triples.rdf')
252 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx')
253 | df.fillna('')
254 | id_list = []
255 | ans_list = []
256 | for iter, row in df.iterrows():
257 | ans_true = list(set(row['答案'].split('|')))
258 | question = row['用户问题']
259 | ans_type = row['答案类型']
260 | entity = row['实体']
261 | main_property = row['属性名'].split('|')
262 | operator = row['约束算子']
263 | if operator != 'min' and operator != 'max':
264 | operator == 'other'
265 | sub_properties = []
266 | cons_names = str(row['约束属性名']).split('|')
267 | cons_values = str(row['约束属性值']).split('|')
268 | if cons_names == ['nan']: cons_names = []
269 | for index in range(len(cons_names)):
270 | sub_properties.append([cons_names[index], cons_values[index]])
271 | ans = kg.fetch_ans(question, ans_type, entity, main_property, operator, sub_properties)
272 |
273 | def is_same(ans, ans_true):
274 | for an in ans:
275 | if an in ans_true:
276 | ans_true.remove(an)
277 | else:
278 | return False
279 | if len(ans_true) != 0:
280 | return False
281 | return True
282 |
283 | if not is_same(ans, ans_true):
284 | id_list.append(iter)
285 | ans_list.append(ans)
286 | print(id_list)
287 | df_save = df.iloc[id_list, [0, 1, 2, 3, 4, 5, 6, 7]]
288 | df_save['预测'] = ans_list
289 | print(df_save)
290 | df_save.to_excel('/data/huangbo/project/Tianchi_nlp_git/data/raw_data/train_wrong_triple.xlsx')
291 |
292 |
293 | def statistic_wrong_cons():
294 | '''
295 | 找出属性句中(min max不要,并且只要'档位介绍-xx')中原句中出现但约束中却没出现的价格、子业务、流量
296 | @return:
297 | '''
298 | print('正在获取rdf结果并对比')
299 | from triples import KnowledgeGraph
300 |
301 | kg = KnowledgeGraph('../data/process_data/triples.rdf')
302 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx')
303 | df.fillna('')
304 | id_list = set()
305 | ans_list = []
306 | for iter, row in df.iterrows():
307 | ans_true = list(set(row['答案'].split('|')))
308 | question = row['用户问题']
309 | ans_type = row['答案类型']
310 | # 只对属性值的句子做处理
311 | if ans_type != '属性值':
312 | continue
313 | entity = row['实体']
314 | main_property = row['属性名'].split('|')
315 | # 排除属性中没有'-'的情况,只要'档位介绍-xx'的情况
316 | if '-' not in main_property[0]:
317 | continue
318 | operator = row['约束算子']
319 | # 排除operator为min或max的情况
320 | if operator != 'min' and operator != 'max':
321 | operator == 'other'
322 | else:
323 | continue
324 | sub_properties = {}
325 | cons_names = str(row['约束属性名']).split('|')
326 | cons_values = str(row['约束属性值']).split('|')
327 | if cons_names == ['nan']: cons_names = []
328 | for index in range(len(cons_names)):
329 | if cons_names[index] not in sub_properties:
330 | sub_properties[cons_names[index]] = []
331 | if cons_names[index] == '子业务':
332 | sub_properties[cons_names[index]].append(cons_values[index])
333 | else:
334 | sub_properties[cons_names[index]].append(int(cons_values[index]))
335 | price_ans, flow_ans, service_ans = kg.fetch_wrong_ans(question, ans_type, entity, main_property, operator, [])
336 | rdf_properties = {}
337 | rdf_properties['价格'] = price_ans
338 | rdf_properties['流量'] = flow_ans
339 | rdf_properties['子业务'] = service_ans
340 | compare_result = []
341 | for name, values in rdf_properties.items():
342 | for value in values:
343 | if name != '子业务': value = int(value)
344 | if name == '流量' and (value > 99 and value % 1024 == 0):
345 | value = int(value // 1024)
346 | if str(value) in question:
347 | if name in sub_properties:
348 | if value in sub_properties[name]:
349 | continue
350 | if value == '年包' and '半年包' in sub_properties[name]:
351 | continue
352 | if value == '百度' and '百度包' in sub_properties[name]:
353 | continue
354 | elif str(value) in entity:
355 | continue
356 | compare_result.append(name + '_' + str(value))
357 | id_list.add(iter)
358 | if compare_result != []:
359 | ans_list.append(compare_result)
360 |
361 | index = list(id_list)
362 | index.sort()
363 | df_save = df.iloc[index, [0, 1, 2, 3, 4, 5, 6, 7]]
364 | df_save['预测'] = ans_list
365 | df_save.to_excel('../data/raw_data/train_wrong_cons.xlsx')
366 |
367 |
368 | def statistic_wrong_cons_bieo():
369 | '''
370 | 找出属性句中(min max不要,并且只要'档位介绍-xx')中原句中出现但约束中却没出现的价格、子业务、流量
371 | 根据bieo的结果进行筛选,补全bieo
372 | @return:
373 | '''
374 | print('正在通过rdf补充信息')
375 | from triples import KnowledgeGraph
376 | from data import create_test_BIEO, create_test_BIO
377 |
378 | kg = KnowledgeGraph('../data/process_data/triples.rdf')
379 | with open('../data/file/train_bieo.json') as f:
380 | bieo_dict = json.load(f)
381 | # bieo_dict = create_test_BIEO('../data/raw_data/train_denoised_desyn.xlsx', False)
382 | type_map = {'价格': 'PRICE', '流量': 'FLOW', '子业务': 'SERVICE'}
383 |
384 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx')
385 | df.fillna('')
386 | id_list = set()
387 | ans_list = []
388 | for iter, row in df.iterrows():
389 | ans_true = list(set(row['答案'].split('|')))
390 | question = row['用户问题']
391 | ans_type = row['答案类型']
392 | # 只对属性值的句子做处理
393 | if ans_type != '属性值':
394 | pass
395 | #continue
396 | entity = row['实体']
397 | main_property = row['属性名'].split('|')
398 | # 排除属性中没有'-'的情况,只要'档位介绍-xx'的情况
399 | if '-' not in main_property[0]:
400 | # continue
401 | pass
402 | operator = row['约束算子']
403 | # 排除operator为min或max的情况
404 | if operator != 'min' and operator != 'max':
405 | operator == 'other'
406 | else:
407 | continue
408 | anno = bieo_dict[question]
409 | sub_properties = {}
410 | cons_names = str(row['约束属性名']).split('|')
411 | cons_values = str(row['约束属性值']).split('|')
412 | if cons_names == ['nan']: cons_names = []
413 | for index in range(len(cons_names)):
414 | if cons_names[index] not in sub_properties:
415 | sub_properties[cons_names[index]] = []
416 | if cons_names[index] == '子业务':
417 | sub_properties[cons_names[index]].append(cons_values[index])
418 | else:
419 | sub_properties[cons_names[index]].append(int(cons_values[index]))
420 | price_ans, flow_ans, service_ans = kg.fetch_wrong_ans(question, ans_type, entity, main_property, operator, [])
421 | rdf_properties = {}
422 | rdf_properties['价格'] = price_ans
423 | rdf_properties['流量'] = flow_ans
424 | rdf_properties['子业务'] = service_ans
425 | compare_result = []
426 | for name, values in rdf_properties.items():
427 | for value in values:
428 | if name != '子业务': value = int(value)
429 | if name == '流量' and (value > 99 and value%1024 == 0):
430 | value = int(value//1024)
431 | if str(value) in question:
432 | if name in sub_properties and value in sub_properties[name]:
433 | continue
434 | if value == '百度' and '百度包' in sub_properties['子业务']:
435 | continue
436 | question_index = 0
437 | if str(value) in entity and (question.count(str(value)) == 1 or str(value) == '1'):
438 | continue
439 | while question[question_index:].find(str(value)) != -1:
440 | temp_index = question_index + question[question_index:].find(str(value))
441 | question_index = min(len(question), temp_index + len(str(value)))
442 | if anno[temp_index] == 'O':
443 | if name == '流量':
444 | if question_index < len(question):
445 | if question[question_index] == '元' or question[question_index].isnumeric():
446 | continue
447 | if name == '价格':
448 | if question_index < len(question):
449 | if question[question_index] == 'g' or question[question_index].isnumeric():
450 | continue
451 | if len(str(value)) == 1:
452 | anno[temp_index] = 'S-' + type_map[name]
453 | else:
454 | for temp_i in range(temp_index + 1, temp_index + len(str(value)) - 1):
455 | anno[temp_i] = 'I-' + type_map[name]
456 | anno[temp_index] = 'B-' + type_map[name]
457 | anno[temp_index + len(str(value)) - 1] = 'E-' + type_map[name]
458 | compare_result.append(name + '_' + str(value))
459 | bieo_dict[question] = anno
460 | id_list.add(iter)
461 |
462 | if compare_result != []:
463 | ans_list.append(compare_result)
464 |
465 | index = list(id_list)
466 | index.sort()
467 | df_save = df.iloc[index, [0, 1, 2, 3, 4, 5, 6, 7]]
468 | df_save['预测'] = ans_list
469 | df_save.to_excel('../data/raw_data/train_wrong_cons.xlsx')
470 | with open(r'../data/file/train_bieo_enhence.json', 'w') as f:
471 | json.dump(bieo_dict, f, indent=2, ensure_ascii=False)
472 |
473 |
474 | def statistic_sym_in_question():
475 | from data_argument import read_synonyms
476 | synonyms = read_synonyms()
477 | df = pd.read_excel('../data/raw_data/ensemble_bert_aug2_use_efficiency_2021-07-20-08-29-22_seed1_fi0_gpu2080_be81_0.95411990.pth.xlsx')
478 | # df = pd.read_excel(
479 | # '../data/raw_data/train_denoised.xlsx')
480 | c = 0
481 | no = set()
482 | yes = set()
483 | for iter, row in df.iterrows():
484 | entity = row['实体']
485 | question = row['用户问题']
486 | con_name = row['约束属性名']
487 | con_value = row['约束属性值']
488 | con_names = re.split(r'[|\|]', str(con_name))
489 | con_values = re.split(r'[|\|]', str(con_value))
490 | if entity not in synonyms: continue
491 | for synonym in synonyms[entity]:
492 | if synonym in question:
493 | for value in con_values:
494 | if value in synonym:
495 | no.add(synonym)
496 | if synonym not in no:
497 | yes.add(synonym)
498 | print(no)
499 | print(yes)
500 | for iter, row in df.iterrows():
501 | entity = row['实体']
502 | question = row['用户问题']
503 | con_name = row['约束属性名']
504 | con_value = row['约束属性值']
505 | con_names = re.split(r'[|\|]', str(con_name))
506 | con_values = re.split(r'[|\|]', str(con_value))
507 | if entity not in synonyms: continue
508 | new_question = question
509 | for synonym in synonyms[entity]:
510 | if synonym in question and (synonym not in no):
511 | new_question = question.replace(synonym, entity)
512 | row['用户问题'] = new_question
513 | # df.to_excel('../data/raw_data/train_denoised_desyn.xlsx')
514 |
515 |
516 | df = pd.read_excel(
517 | '../data/raw_data/train_denoised.xlsx')
518 | c = 0
519 | no = set()
520 | yes = set()
521 | for iter, row in df.iterrows():
522 | entity = row['实体']
523 | question = row['用户问题']
524 | con_name = row['约束属性名']
525 | con_value = row['约束属性值']
526 | con_names = re.split(r'[|\|]', str(con_name))
527 | con_values = re.split(r'[|\|]', str(con_value))
528 | if entity not in synonyms: continue
529 | for synonym in synonyms[entity]:
530 | if synonym in question:
531 | for value in con_values:
532 | if value in synonym:
533 | no.add(synonym)
534 | if synonym not in no:
535 | yes.add(synonym)
536 | print(no)
537 | print(yes)
538 | for iter, row in df.iterrows():
539 | entity = row['实体']
540 | question = row['用户问题']
541 | con_name = row['约束属性名']
542 | con_value = row['约束属性值']
543 | con_names = re.split(r'[|\|]', str(con_name))
544 | con_values = re.split(r'[|\|]', str(con_value))
545 | if entity not in synonyms: continue
546 | new_question = question
547 | for synonym in synonyms[entity]:
548 | if synonym in question and (synonym not in no):
549 | new_question = question.replace(synonym, entity)
550 | row['用户问题'] = new_question
551 |
552 |
553 | # def compare_result(result1_file, result2_file):
554 | # """
555 | #
556 | # :param result1_file: 原结果
557 | # :param result2_file: 新结果
558 | # :return:
559 | # """
560 | # result_path = '../data/results/'
561 | # result1 = json.load(open(os.path.join(result_path, result1_file), 'r'))
562 | # result2 = json.load(open(os.path.join(result_path, result2_file), 'r'))
563 | # for idx, value in result1['model_result'].items():
564 | # if value != result2['model_result'][idx]:
565 | # print('结果1:', value, '\n结果2:', result2['model_result'][idx])
566 | # print('\n')
567 |
568 |
569 | def string2result(txt):
570 | with open(txt, 'r') as f:
571 | line = f.read()
572 | result_dict = {}
573 | str_list = line.split('+')
574 | for str in str_list[:-1]:
575 | pro_list = str.split('=')
576 | id = pro_list[0]
577 | result_dict[id] = {}
578 | result_dict[id]['question'] = pro_list[1]
579 | result_dict[id]['ans_type'] = pro_list[2]
580 | result_dict[id]['entity'] = pro_list[3]
581 | result_dict[id]['main_property'] = pro_list[4].split('x')
582 | result_dict[id]['operator'] = pro_list[5]
583 | result_dict[id]['sub_properties'] = []
584 | if pro_list[6] != '':
585 | sub_list = pro_list[6].split('x')
586 | for sub in sub_list:
587 | name = sub.split('_')[0]
588 | value = sub.split('_')[1]
589 | if name == '流量': value = int(value)
590 | if pro_list[5] == 'max' or pro_list[5] == 'min':
591 | value = int(value)
592 | result_dict[id]['sub_properties'].append((name, value))
593 |
594 | with open('../data/results/result_docker.json', 'w', encoding='utf-8') as f:
595 | json.dump({'model_result':result_dict}, f, ensure_ascii=False, indent=2)
596 |
597 |
598 | def compare_result(file1, file2):
599 | with open('../data/results/' + file1, 'r') as f:
600 | base = json.load(f)['model_result']
601 | with open('../data/results/' + file2, 'r') as f:
602 | result = json.load(f)['model_result']
603 |
604 | c = 0
605 | for index, dic in base.items():
606 | dic1 = result[index]
607 | # if dic1['sub_properties'] != dic['sub_properties']:
608 | # if dic1['entity'] != dic['entity']:
609 | # if dic1['main_property'] != dic['main_property']:
610 | # if dic1['ans_type'] != dic['ans_type']:
611 | if dic1['sub_properties'] != dic['sub_properties'] or dic1['entity'] != dic['entity'] or set(dic1['main_property']) != set(dic['main_property']) or dic1['ans_type'] != dic['ans_type']:
612 | c += 1
613 | print(index)
614 | print('file1:', dic)
615 | print('file2:', dic1)
616 | print()
617 | print(c)
618 |
619 |
620 | def static_confusion_entity():
621 | with open('../data/dataset/cls_label2id.json', 'r', encoding='utf-8') as f:
622 | label2id = json.load(f)
623 | entities = list(label2id['entity'])
624 | for i in range(len(entities)):
625 | entity1 = entities[i]
626 | for j in range(len(entities)):
627 | entity2 = entities[j]
628 | if i != j and entity1.find(entity2) != -1:
629 | print(entity1, 'contains', entity2)
630 |
631 |
632 | def static_no_answer(file):
633 | """
634 | 统计答案中没有属性名的样本
635 | :param file:
636 | :return:
637 | """
638 | with open('../data/results/' + file, 'r') as f:
639 | result_dict = json.load(f)
640 | result = result_dict['result']
641 | model_result = result_dict['model_result']
642 | for i in result:
643 | main_property = model_result[i]['main_property']
644 | for prop in main_property:
645 | if '-' in prop:
646 | relation = prop.split('-')[1]
647 | else:
648 | relation = prop
649 | if relation not in ('流量', '价格', '子业务', '上线时间', '语音时长', '带宽', '内含其它服务') and model_result[i]['ans_type'] != '比较句':
650 | if relation not in result[i]:
651 | print('{}\t{}'.format(model_result[i], result[i]))
652 |
653 |
654 | if __name__ == '__main__':
655 | # statistic_entity(labeled_file='../data/raw_data/train_denoised.xlsx')
656 | # statistic_property()
657 | # statistic_synonyms()
658 | # statistic_prop_labels()
659 | # statistic_service()
660 | # compare_result(result1_file='drop_1_2021-07-28-02-58-01_seed1_gpu2080_be34_0.95407602.pth1.json',
661 | # result2_file='drop_1_2021-07-28-02-58-01_seed1_gpu2080_be34_0.95407602.pth.json')
662 | # statistic_prop_labels()
663 | # compare_result(result1_file='bert_pt_augment2_2021-07-30-05-44-08_seed1_gpu2080_be6_0.99729231.pth.json',
664 | # result2_file='ensemble_bert_aug2_use_efficiency_2021-07-20-08-29-22_seed1_fi0_gpu2080_be81_0.95411990.pth.json')
665 | string2result('../data/results/result_text.txt')
666 | compare_result('9677_38.json', '9682.json')
667 |
--------------------------------------------------------------------------------
/code/test.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: hezf
3 | @Time: 2021/7/3 19:43
4 | @desc:
5 | """
6 | import warnings
7 | warnings.filterwarnings("ignore")
8 | # from nlpcda import Simbert, Similarword, RandomDeleteChar, Homophone, CharPositionExchange
9 | import time
10 | from multiprocessing import Process
11 | from data_argument import augment_from_simbert
12 | import pandas as pd
13 | import json
14 |
15 | # ---simbert---
16 | # config = {
17 | # 'model_path': '/data2/hezhenfeng/other_model_files/chinese_simbert_L-6_H-384_A-12',
18 | # 'CUDA_VISIBLE_DEVICES': '6',
19 | # 'max_len': 40,
20 | # 'seed': 1,
21 | # 'device': 'gpu'
22 | # }
23 | #
24 | # simbert = Simbert(config=config)
25 | # sent_list = ['9元百度专属定向流量包如何取消',
26 | # '你告诉我7天5g视频会员流量包怎么开通,多少钱',
27 | # '您好:怎么开通70爱奇艺',
28 | # 'plus会员领取的权益可以取消吗',
29 | # '通州卡如何取消']
30 | #
31 | # for sent in sent_list:
32 | # synonyms = simbert.replace(sent=sent, create_num=5)
33 | # print(synonyms)
34 |
35 | # ------------nlpcda 一般增强-------------
36 | # start = time.time()
37 | # sent_list = ['9元百度专属定向流量包如何取消',
38 | # '你告诉我7天5g视频会员流量包怎么开通,多少钱',
39 | # '您好:怎么开通70爱奇艺',
40 | # 'plus会员领取的权益可以取消吗',
41 | # '通州卡如何取消']
42 | #
43 | # smw = CharPositionExchange(create_num=3, change_rate=0.01)
44 | # for sent in sent_list:
45 | # rs1 = smw.replace(sent)
46 | # print(rs1)
47 | # end = time.time()
48 | # print(end-start)
49 |
50 | # ---------textda---------------
51 | # from utils import setup_seed
52 | # from textda.data_expansion import data_expansion
53 | # import random
54 | #
55 | # if __name__ == '__main__':
56 | # # setup_seed(1)
57 | # random.seed(1)
58 | # print(data_expansion('这是一句测试的句子。'))
59 |
60 | # ------------多进程------------
61 |
62 | # def function1(id): # 这里是子进程
63 | # augment_from_simbert(source_file='../data/raw_data/train_denoised.xlsx',
64 | # target_file='../data/raw_data/train_augment_simbert.xlsx')
65 | #
66 | #
67 | # def run_process(): # 这里是主进程
68 | # from multiprocessing import Process
69 | # process = [Process(target=function1, args=(1,)),
70 | # Process(target=function1, args=(2,)), ]
71 | # [p.start() for p in process] # 开启了两个进程
72 | # [p.join() for p in process] # 等待两个进程依次结束
73 |
74 |
75 | # ------------标注文件转化成预测样本的格式---------------
76 | def to_predict_file():
77 | data = []
78 | with open('../data/dataset/cls_labeled.txt', 'r', encoding='utf-8') as f:
79 | for line in f:
80 | data.append(line.strip().split('\t')[0])
81 | with open('../data/dataset/labeled_predict.txt', 'w', encoding='utf-8') as f:
82 | for line in data:
83 | f.write(line+'\n')
84 |
85 |
86 | def xlsx2json():
87 | data = {}
88 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx')
89 | for i in range(len(df)):
90 | line = df.loc[i]
91 | data[str(i)] = {
92 | 'question': line['用户问题'],
93 | 'ans_type': line['答案类型'],
94 | 'entity': line['实体'],
95 | 'main_property': line['属性名']}
96 | with open('../data/results/train_result.json', 'w', encoding='utf-8') as f:
97 | json.dump({'model_result': data}, f, indent=2, ensure_ascii=False)
98 |
99 |
100 | if __name__ == '__main__':
101 | # xlsx2json()
102 | from tqdm import tqdm
103 | from utils import judge_cancel
104 | yes = 0
105 | # with open('../data/dataset/cls_labeled.txt', 'r', encoding='utf-8') as f:
106 | # for line in tqdm(f):
107 | # instance = line.strip().split('\t')
108 | # if judge_cancel(instance[0]):
109 | # if '取消' not in instance[2]:
110 | # print(instance)
111 | # else:
112 | # yes += 1
113 | # print(yes)
114 | with open('../data/dataset/cls_unlabeled2.txt', 'r', encoding='utf-8') as f:
115 | for line in tqdm(f):
116 | instance = line.strip().split('\t')
117 | if judge_cancel(instance[0]):
118 | print(instance)
119 |
--------------------------------------------------------------------------------
/code/train_evaluate_constraint.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import re
3 | import json
4 | import os
5 | import random
6 | import time
7 | import torch
8 | from transformers import BertModel, BertTokenizer
9 |
10 | from model import Metrics, BILSTM_CRF_Model, BERT_BILSTM_CRF_Model
11 | from data import *
12 | from utils import setup_seed, load_model
13 | from statistic import get_all_service, statistic_wrong_cons_bieo
14 | import argparse
15 |
16 |
17 | def train(excel_path, model_name, tag_version='bio'):
18 |
19 | def build_map(lists):
20 | maps = {}
21 | for list_ in lists:
22 | for e in list_:
23 | if e == '': continue
24 | if e not in maps:
25 | maps[e] = len(maps)
26 | maps[''] = len(maps)
27 | maps[''] = len(maps)
28 | maps[''] = len(maps)
29 | maps[''] = len(maps)
30 | return maps
31 |
32 | # data
33 | print("读取数据...")
34 | with open(r'../data/file/train_{}.json'.format(tag_version)) as f:
35 | train_dict = json.load(f)
36 | # with open(r'../data/file/train_dev_ids.json') as f:
37 | # id_dict = json.load(f)
38 | raw_data = pd.read_excel(excel_path)
39 | question_list = raw_data['用户问题'].tolist()
40 | question_list = [question.replace(' ', '') for question in question_list]
41 | for i in range(len(question_list)):
42 | question = question_list[i]
43 | question = question.replace('三十', '30')
44 | question = question.replace('四十', '40')
45 | question = question.replace('十块', '10块')
46 | question = question.replace('六块', '6块')
47 | question = question.replace('一个月', '1个月')
48 | question = question.replace('2O', '20')
49 | question_list[i] = question
50 | # 按照划分好的id划分数据
51 | # train_list = [question_list[i] for i in id_dict['train_ids']]
52 | # dev_list = [question_list[i] for i in id_dict['dev_ids']]
53 | # test_list = [question_list[i] for i in id_dict['dev_ids']]
54 | # 随机划分数据
55 | random.shuffle(question_list)
56 | train_list = question_list[:int(0.8*len(question_list))]
57 | dev_list = question_list[int(0.8*len(question_list)): int(0.9*len(question_list))]
58 | test_list = question_list[int(0.9*len(question_list)):]
59 |
60 | train_word_lists, train_tag_lists = process_data(train_list, train_dict)
61 | dev_word_lists, dev_tag_lists = process_data(dev_list, train_dict)
62 | test_word_lists, test_tag_lists = process_data(test_list, train_dict, test=True)
63 | # 生成word2id 和 tag2id 并保存
64 | word2id = build_map(train_word_lists)
65 | tag2id = build_map(train_tag_lists)
66 | with open(r'../data/file/word2id.json', 'w') as f: json.dump(word2id, f)
67 | with open(r'../data/file/tag2id_{}.json'.format(tag_version), 'w') as f: json.dump(tag2id, f)
68 | # train
69 | print("正在训练评估Bi-LSTM+CRF模型...")
70 | start = time.time()
71 | vocab_size = len(word2id)
72 | out_size = len(tag2id)
73 | bilstmcrf_model = BILSTM_CRF_Model(vocab_size, out_size, gpu_id=args.cuda)
74 | bilstmcrf_model.train(train_word_lists, train_tag_lists,
75 | dev_word_lists, dev_tag_lists, word2id, tag2id, debug=args.debug)
76 | torch.save(bilstmcrf_model, '../data/trained_model/{}'.format(model_name))
77 | print("训练完毕,共用时{}秒.".format(int(time.time() - start)))
78 | print("评估{}模型中...".format('bilstm-crf'))
79 | pred_tag_lists, test_tag_lists = bilstmcrf_model.test(
80 | test_word_lists, test_tag_lists, word2id, tag2id)
81 |
82 | print(len(test_tag_lists))
83 | print(len(pred_tag_lists))
84 | if args.debug=='1':
85 | metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=False)
86 | metrics.report_scores()
87 | metrics.report_confusion_matrix()
88 |
89 |
90 | def train_kfold(excel_path, model_name, kfold, tag_version='bio'):
91 |
92 | def build_map(lists):
93 | maps = {}
94 | for list_ in lists:
95 | for e in list_:
96 | if e == '': continue
97 | if e not in maps:
98 | maps[e] = len(maps)
99 | maps[''] = len(maps)
100 | maps[''] = len(maps)
101 | maps[''] = len(maps)
102 | maps[''] = len(maps)
103 | return maps
104 |
105 | # data
106 | # print("读取数据...")
107 | with open(r'../data/file/train_{}.json'.format(tag_version)) as f:
108 | train_dict = json.load(f)
109 | raw_data = pd.read_excel(excel_path)
110 | question_list = raw_data['用户问题'].tolist()
111 | question_list = [question.replace(' ', '') for question in question_list]
112 | for i in range(len(question_list)):
113 | question = question_list[i]
114 | question = question.replace('三十', '30')
115 | question = question.replace('四十', '40')
116 | question = question.replace('十块', '10块')
117 | question = question.replace('六块', '6块')
118 | question = question.replace('一个月', '1个月')
119 | question = question.replace('2O', '20')
120 | question_list[i] = question
121 | # 按照划分好的id划分数据
122 | # train_list = [question_list[i] for i in id_dict['train_ids']]
123 | # dev_list = [question_list[i] for i in id_dict['dev_ids']]
124 | # test_list = [question_list[i] for i in id_dict['dev_ids']]
125 | # 随机划分数据
126 | random.shuffle(question_list)
127 | block_len = len(question_list)//kfold
128 | data_blocks = [question_list[i*block_len: (i+1)*block_len] for i in range(kfold-1)]
129 | data_blocks.append(question_list[(kfold-1)*block_len:])
130 | # 使用所有训练数据生成word2id tag2id
131 | all_train_list = []
132 | for block in data_blocks:
133 | all_train_list += block
134 | all_train_word_lists, all_train_tag_lists = process_data(all_train_list, train_dict)
135 | all_word2id = build_map(all_train_word_lists)
136 | all_tag2id = build_map(all_train_tag_lists)
137 | with open(r'../data/file/word2id_{}.json'.format('all'), 'w') as f: json.dump(all_word2id, f)
138 | with open(r'../data/file/tag2id_{}_{}.json'.format(tag_version, 'all'), 'w') as f: json.dump(all_tag2id, f)
139 | # 开始kfold训练
140 | for i, block in enumerate(data_blocks):
141 | train_list, dev_list = [], block
142 | for _ in range(kfold):
143 | if _ != i:
144 | train_list += data_blocks[_]
145 | train_word_lists, train_tag_lists = process_data(train_list, train_dict)
146 | dev_word_lists, dev_tag_lists = process_data(dev_list, train_dict)
147 | # 生成word2id 和 tag2id 并保存
148 | word2id = build_map(train_word_lists)
149 | tag2id = build_map(train_tag_lists)
150 | with open(r'../data/file/word2id_{}.json'.format(i), 'w') as f: json.dump(word2id, f)
151 | with open(r'../data/file/tag2id_{}_{}.json'.format(tag_version, i), 'w') as f: json.dump(tag2id, f)
152 | # train
153 | print("正在训练评估Bi-LSTM+CRF模型...")
154 | start = time.time()
155 | vocab_size = len(all_word2id)
156 | out_size = len(all_tag2id)
157 | bilstmcrf_model = BILSTM_CRF_Model(vocab_size, out_size, gpu_id=args.cuda)
158 | bilstmcrf_model.train(train_word_lists, train_tag_lists,
159 | dev_word_lists, dev_tag_lists, all_word2id, all_tag2id, debug=args.debug)
160 | torch.save(bilstmcrf_model, '../data/trained_model/{}'.format(model_name.replace('.pth', '_{}'.format(i)+'.pth')))
161 | print("训练完毕,共用时{}秒.".format(int(time.time() - start)))
162 | if args.debug == '1':
163 | print("评估{}模型中...".format('bilstm-crf'))
164 | pred_tag_lists, dev_tag_lists = bilstmcrf_model.test(
165 | dev_word_lists, dev_tag_lists, all_word2id, all_tag2id)
166 |
167 |
168 | print(len(dev_tag_lists))
169 | print(len(pred_tag_lists))
170 | metrics = Metrics(dev_tag_lists, pred_tag_lists, remove_O=False)
171 | metrics.report_scores()
172 | metrics.report_confusion_matrix()
173 |
174 |
175 | def train2(excel_path, model_name, tag_version='bio'):
176 |
177 | def build_map(lists):
178 | maps = {}
179 | for list_ in lists:
180 | for e in list_:
181 | if e == '': continue
182 | if e not in maps:
183 | maps[e] = len(maps)
184 | maps[''] = len(maps)
185 | maps[''] = len(maps)
186 | maps[''] = len(maps)
187 | maps[''] = len(maps)
188 | return maps
189 |
190 | # data
191 | print("读取数据...")
192 | with open(r'../data/file/train_{}.json'.format(tag_version)) as f:
193 | train_dict = json.load(f)
194 | with open(r'../data/file/train_dev_ids.json') as f:
195 | id_dict = json.load(f)
196 | raw_data = pd.read_excel(excel_path)
197 | question_list = raw_data['用户问题'].tolist()
198 | question_list = [question.replace(' ', '') for question in question_list]
199 | for i in range(len(question_list)):
200 | question = question_list[i]
201 | question = question.replace('三十', '30')
202 | question = question.replace('四十', '40')
203 | question = question.replace('十块', '10块')
204 | question = question.replace('六块', '6块')
205 | question = question.replace('一个月', '1个月')
206 | question = question.replace('2O', '20')
207 | question_list[i] = question
208 | # 按照划分好的id划分数据
209 | # train_list = [question_list[i] for i in id_dict['train_ids']]
210 | # dev_list = [question_list[i] for i in id_dict['dev_ids']]
211 | # test_list = [question_list[i] for i in id_dict['dev_ids']]
212 | # 随机划分数据
213 | random.shuffle(question_list)
214 | train_list = question_list[:int(0.8*len(question_list))]
215 | dev_list = question_list[int(0.8*len(question_list)): int(0.9*len(question_list))]
216 | test_list = question_list[int(0.9*len(question_list)):]
217 |
218 | train_word_lists, train_tag_lists = process_data(train_list, train_dict)
219 | dev_word_lists, dev_tag_lists = process_data(dev_list, train_dict)
220 | test_word_lists, test_tag_lists = process_data(test_list, train_dict, test=True)
221 | # bert
222 | BERT_PRETRAINED_PATH = '../data/bert_pretrained_model'
223 | word2id = {}
224 | with open(os.path.join(BERT_PRETRAINED_PATH, 'vocab.txt'), 'r') as f:
225 | count = 0
226 | for line in f:
227 | word2id[line.split('\n')[0]] = count
228 | count += 1
229 | bert_model = BertModel.from_pretrained(BERT_PRETRAINED_PATH)
230 | # 生成word2id 和 tag2id 并保存
231 | tag2id = build_map(train_tag_lists)
232 | with open(r'../data/file/word2id_bert.json', 'w') as f: json.dump(word2id, f)
233 | with open(r'../data/file/tag2id_{}.json'.format(tag_version), 'w') as f: json.dump(tag2id, f)
234 |
235 | # train
236 | print("正在训练评估Bert-Bi-LSTM+CRF模型...")
237 | start = time.time()
238 | vocab_size = len(word2id)
239 | out_size = len(tag2id)
240 | bilstmcrf_model = BERT_BILSTM_CRF_Model(vocab_size, out_size, bert_model, gpu_id=args.cuda)
241 | bilstmcrf_model.train(train_word_lists, train_tag_lists,
242 | dev_word_lists, dev_tag_lists, word2id, tag2id)
243 | torch.save(bilstmcrf_model, '../data/trained_model/{}'.format(model_name))
244 | print("训练完毕,共用时{}秒.".format(int(time.time() - start)))
245 | print("评估{}模型中...".format('bert-bilstm-crf'))
246 | pred_tag_lists, test_tag_lists = bilstmcrf_model.test(
247 | test_word_lists, test_tag_lists, word2id, tag2id)
248 |
249 | print(len(test_tag_lists))
250 | print(len(pred_tag_lists))
251 | metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=False)
252 | metrics.report_scores()
253 | metrics.report_confusion_matrix()
254 |
255 |
256 | def evaluate(excel_path, word2id_version='word2id.json', model_name='bilstmcrf.pth', tag_version='bio'):
257 | with open(r'../data/file/train_{}.json'.format(tag_version), 'r') as f:
258 | train_dict = json.load(f)
259 | with open(r'../data/file/{}'.format(word2id_version), 'r') as f: word2id = json.load(f)
260 | with open(r'../data/file/tag2id_{}.json'.format(tag_version), 'r') as f: tag2id = json.load(f)
261 | # 得到训练数据
262 | raw_data = pd.read_excel(excel_path)
263 | question_list = raw_data['用户问题'].tolist()
264 | question_list = [question.replace(' ', '') for question in question_list]
265 | for i in range(len(question_list)):
266 | question = question_list[i]
267 | question = question.replace('三十', '30')
268 | question = question.replace('四十', '40')
269 | question = question.replace('十块', '10块')
270 | question = question.replace('六块', '6块')
271 | question = question.replace('一个月', '1个月')
272 | question = question.replace('2O', '20')
273 | question_list[i] = question
274 | index_list = list(range(len(question_list)))
275 | random.shuffle(index_list)
276 | question_list = [question_list[i] for i in index_list]
277 | test_word_lists, test_tag_lists = process_data(question_list[:], train_dict)
278 | # 得到原来的约束属性名和约束值
279 | con_names = raw_data['约束属性名'].tolist()
280 | con_values = raw_data['约束属性值'].tolist()
281 | con_names = [re.split(r'[|\|]', str(con_names[i])) for i in index_list]
282 | con_values = [re.split(r'[|\|]', str(con_values[i])) for i in index_list]
283 | con_dict_list = []
284 | for i in range(len(con_names)):
285 | con_dict = {}
286 | con_name = con_names[i]
287 | con_value = con_values[i]
288 | for j in range(len(con_name)):
289 | name = con_name[j].strip()
290 | value = con_value[j].strip()
291 | if name == 'nan' or name == '有效期': continue
292 | if name == '价格' and value == '1' and '最' in question_list[i]: continue
293 | if name == '流量' and value == '1' and '最' in question_list[i]: continue
294 | if name not in con_dict: con_dict[name] = []
295 | con_dict[name].append(value)
296 | con_dict_list.append(con_dict)
297 |
298 | service_names = get_all_service(raw_data)
299 |
300 | bilstmcrf_model = load_model(r'../data/trained_model/{}'.format(model_name), map_location=lambda storage, loc: storage.cuda('cuda:{}'.format(args.cuda)))
301 | bilstmcrf_model.device = torch.device('cuda:{}'.format(args.cuda))
302 | pred_tag_lists, test_tag_lists = bilstmcrf_model.test(
303 | test_word_lists, test_tag_lists, word2id, tag2id)
304 | c = 0
305 | for index in range(len(pred_tag_lists)):
306 | pre_dict = get_anno_dict(question_list[index], pred_tag_lists[index], service_names)
307 | # 对比 官方标注 和 非增强预测标注
308 | if set(pre_dict) != set(con_dict_list[index]):
309 | c+=1
310 | print(question_list[index])
311 | print('true:',test_tag_lists[index])
312 | print('pred:',pred_tag_lists[index])
313 | print('true:',con_dict_list[index])
314 | print('pred:',pre_dict)
315 | print(c)
316 |
317 |
318 | def evaluate_kfold(excel_path, kfold, word2id_version='word2id.json', model_name='bilstmcrf.pth', tag_version='bio'):
319 | for i in range(kfold):
320 | with open(r'../data/file/train_{}.json'.format(tag_version), 'r') as f:
321 | train_dict = json.load(f)
322 | with open(r'../data/file/{}'.format(word2id_version.replace('.json', '_{}.json'.format('all'))), 'r') as f: word2id = json.load(f)
323 | with open(r'../data/file/tag2id_{}_{}.json'.format(tag_version, 'all'), 'r') as f: tag2id = json.load(f)
324 | # 得到训练数据
325 | raw_data = pd.read_excel(excel_path)
326 | question_list = raw_data['用户问题'].tolist()
327 | question_list = [question.replace(' ', '') for question in question_list]
328 | for i in range(len(question_list)):
329 | question = question_list[i]
330 | question = question.replace('三十', '30')
331 | question = question.replace('四十', '40')
332 | question = question.replace('十块', '10块')
333 | question = question.replace('六块', '6块')
334 | question = question.replace('一个月', '1个月')
335 | question = question.replace('2O', '20')
336 | question_list[i] = question
337 | index_list = list(range(len(question_list)))
338 | random.shuffle(index_list)
339 | question_list = [question_list[i] for i in index_list]
340 | test_word_lists, test_tag_lists = process_data(question_list[:], train_dict)
341 | # 得到原来的约束属性名和约束值
342 | con_names = raw_data['约束属性名'].tolist()
343 | con_values = raw_data['约束属性值'].tolist()
344 | con_names = [re.split(r'[|\|]', str(con_names[i])) for i in index_list]
345 | con_values = [re.split(r'[|\|]', str(con_values[i])) for i in index_list]
346 | con_dict_list = []
347 | for i in range(len(con_names)):
348 | con_dict = {}
349 | con_name = con_names[i]
350 | con_value = con_values[i]
351 | for j in range(len(con_name)):
352 | name = con_name[j].strip()
353 | value = con_value[j].strip()
354 | if name == 'nan' or name == '有效期': continue
355 | if name == '价格' and value == '1' and '最' in question_list[i]: continue
356 | if name == '流量' and value == '1' and '最' in question_list[i]: continue
357 | if name not in con_dict: con_dict[name] = []
358 | con_dict[name].append(value)
359 | con_dict_list.append(con_dict)
360 |
361 | service_names = get_all_service(raw_data)
362 |
363 | bilstmcrf_model = load_model(r'../data/trained_model/{}'.format(model_name.replace('.pth', '_{}.pth'.format(i))), map_location=lambda storage, loc: storage.cuda('cuda:{}'.format(args.cuda)))
364 | bilstmcrf_model.device = torch.device('cuda:{}'.format(args.cuda))
365 | pred_tag_lists, test_tag_lists = bilstmcrf_model.test(
366 | test_word_lists, test_tag_lists, word2id, tag2id)
367 | # 使用了 enhence 又需要对比增强前的
368 | c = 0
369 | for index in range(len(pred_tag_lists)):
370 | pre_dict = get_anno_dict(question_list[index], pred_tag_lists[index], service_names)
371 | test_dict = get_anno_dict(question_list[index], test_tag_lists[index][:-1], service_names)
372 | # 对比 官方标注 和 非增强预测标注
373 | if set(pre_dict) != set(con_dict_list[index]):
374 | c+=1
375 | print(question_list[index])
376 | print('true:',test_tag_lists[index])
377 | print('pred:',pred_tag_lists[index])
378 | print('true:',con_dict_list[index])
379 | print('pred:',pre_dict)
380 | print(c)
381 |
382 |
383 | if __name__ == '__main__':
384 | print('开始train_ner')
385 | setup_seed(1)
386 | excel_path = r'../data/raw_data/train_denoised_ner.xlsx'
387 | # excel_path = r'../data/raw_data/train_syn.xlsx'
388 | # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
389 | # create_test_BIEO(excel_path)
390 | parser = argparse.ArgumentParser()
391 | parser.add_argument('--train_type', '-tt', default='lstm', type=str)
392 | parser.add_argument('--tag_version', '-tv', default='bieo', type=str)
393 | parser.add_argument('--model_name', '-m', default='ner_model_test.pth', type=str)
394 | parser.add_argument('--cuda', '-c', default=0)
395 | parser.add_argument('--debug', '-d', default='1', choices=['0', '1'])
396 | parser.add_argument('--kfold', '-k', default=5, type=int)
397 | # parser.add_argument('--kfold', '-k', action='store_true')
398 | args = parser.parse_args()
399 | a = args.kfold
400 | if args.tag_version == 'bio':
401 | create_test_BIO(excel_path)
402 | else:
403 | create_test_BIEO(excel_path)
404 | # statistic_wrong_cons_bieo()
405 | if args.train_type == 'lstm':
406 | if args.kfold == 1:
407 | train(excel_path, model_name=args.model_name, tag_version=args.tag_version)
408 | if args.debug == '1':
409 | evaluate(excel_path, word2id_version='word2id.json', model_name=args.model_name, tag_version=args.tag_version)
410 | else:
411 | train_kfold(excel_path, model_name=args.model_name, kfold=args.kfold, tag_version=args.tag_version)
412 | if args.debug == '1':
413 | evaluate_kfold(excel_path, kfold=args.kfold, word2id_version='word2id.json', model_name=args.model_name,
414 | tag_version=args.tag_version)
415 | else:
416 | train2(excel_path, model_name=args.model_name, tag_version=args.tag_version)
417 | if args.debug == '1':
418 | evaluate(excel_path, word2id_version='word2id_bert.json', model_name=args.model_name, tag_version=args.tag_version)
--------------------------------------------------------------------------------
/code/triples.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: hezf
3 | @Time: 2021/6/8 11:08
4 | @desc: 负责RDF图的构建以及查询
5 | """
6 | from rdflib import Graph, Namespace
7 | from utils import to_sparql
8 | from typing import List, Dict, Tuple
9 | import re
10 |
11 |
12 | class KnowledgeGraph(object):
13 | def __init__(self, rdf_file):
14 | super(KnowledgeGraph, self).__init__()
15 | self.rdf_file = rdf_file
16 | self.graph = Graph()
17 | self.init()
18 |
19 | def init(self):
20 | # print('parsing file...')
21 | self.graph.parse(self.rdf_file, format='n3')
22 |
23 | def query(self, q):
24 | """
25 | 在图谱中查询答案
26 | :param q: sparql语句
27 | :return: List[str]
28 | """
29 | answers = []
30 | results = self.graph.query(q).bindings
31 | for result in results:
32 | for value in result.values():
33 | value = value.toPython()
34 | value = value[value.rfind('/')+1:]
35 | answers.append(value)
36 | return answers
37 |
38 | def fetch_ans(self, question: str, ans_type: str, entity: str, main_property: List, operator: str,
39 | sub_properties: List[Tuple] = None):
40 | """
41 | 获取最终答案的函数
42 | :param question: 问题本身
43 | :param ans_type: 答案类型【比较句、并列句、属性值】
44 | :param entity: 实体
45 | :param main_property: 属性名["档位介绍-流量"或者"XX规则"],并列句有两个属性名,而其他句式只有一个属性名
46 | :param sub_properties: 约束属性键值对【[(key, value), (key, value), ...]】。不用字典是因为比较句中key有可能相同
47 | :param operator: 约束算子【min max other】
48 | :return:
49 | """
50 | # 问句特色:
51 | # 1 比较句:属性名只要一个,比较句的约束属性名肯定为两个;句式一般为:是否和哪个两种;
52 | # 2 并列句:有多个属性名,少量情况含有min,max算子,只要将属性名分批答案取出即可
53 | # 3 属性句:最普通的情况
54 | try:
55 | for i in range(len(sub_properties)):
56 | if sub_properties[i][0] == '流量':
57 | value = int(sub_properties[i][1])
58 | if 1 < value < 55 and entity != '流量加油包':
59 | value *= 1024
60 | sub_properties[i] = ('流量', value)
61 | # print('修改流量的句子是:', question)
62 | if ans_type == '比较句':
63 | if len(sub_properties) == 0:
64 | return []
65 | # 处理当预测的约束属性名个数大于两个时的情况
66 | if len(sub_properties) > 2:
67 | key_count = {}
68 | for k, v in sub_properties:
69 | if k not in key_count:
70 | key_count[k] = 0
71 | key_count[k] += 1
72 | real_k = None
73 | for k, count in key_count.items():
74 | if count >= 2:
75 | real_k = k
76 | break
77 | if real_k is not None:
78 | for k, v in sub_properties:
79 | if k != real_k:
80 | sub_properties.remove((k, v))
81 | q0 = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[sub_properties[0]])
82 | ans0 = self.query(q0)
83 | if len(sub_properties) >= 2:
84 | q1 = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[sub_properties[1]])
85 | ans1 = self.query(q1)
86 | else:
87 | ans1 = None
88 | if len(ans0) > 0:
89 | ans0 = ans0[0]
90 | else:
91 | ans0 = None
92 | if ans1 is not None and len(ans1) > 0:
93 | ans1 = ans1[0]
94 | else:
95 | ans1 = None
96 | keywords = ['哪', '那个']
97 | flag = False
98 | for kw in keywords:
99 | if kw in question:
100 | flag = True
101 | break
102 | # "哪个"类型的问题
103 | if flag:
104 | if ans1 is None:
105 | ans = [sub_properties[0][1]]
106 | else:
107 | if ans0 == ans1:
108 | ans = [sub_properties[0][1], sub_properties[1][1]]
109 | else:
110 | bigger_keywords = ['多', '贵']
111 | smaller_keywords = ['便宜', '优惠', '少', '实惠']
112 | big_flag = False
113 | for bkw in bigger_keywords:
114 | if bkw in question:
115 | big_flag = True
116 | break
117 | if big_flag:
118 | if int(ans0) > int(ans1):
119 | ans = [sub_properties[0][1]]
120 | else:
121 | ans = [sub_properties[1][1]]
122 | else:
123 | if int(ans0) < int(ans1):
124 | ans = [sub_properties[0][1]]
125 | else:
126 | ans = [sub_properties[1][1]]
127 | # "是否"类型的问题
128 | else:
129 | if ans1 is None:
130 | ans = ['no']
131 | else:
132 | if ans0 == ans1:
133 | ans = ['yes']
134 | else:
135 | ans = ['no']
136 | # equal_kw = ['一样', '相同', '等价', '等同', '相等']
137 | # eq_q = False
138 | # for k in equal_kw:
139 | # if question.find(k) != -1:
140 | # eq_q = True
141 | # break
142 | # # 相等问题
143 | # if eq_q:
144 | # if ans0 == ans1:
145 | # ans = ['yes']
146 | # else:
147 | # ans = ['no']
148 | # # 区别问题
149 | # else:
150 | # if ans0 != ans1:
151 | # ans = ['yes']
152 | # else:
153 | # ans = ['no']
154 | # 并列句
155 | elif ans_type == '并列句':
156 | ans = []
157 | if operator == 'min':
158 | if len(sub_properties) == 0:
159 | return []
160 | # 分两段获取答案
161 | key = sub_properties[0][0]
162 | query_str = to_sparql(entity=entity, main_property='档位介绍-'+key)
163 | temp_ans = self.query(query_str)
164 | temp_ans = [int(i) for i in temp_ans]
165 | if len(temp_ans) == 0:
166 | return []
167 | first_step_ans = min(temp_ans)
168 | for m_p in main_property:
169 | query_str = to_sparql(entity=entity, main_property=m_p, sub_properties=[(key, first_step_ans)])
170 | ans += self.query(query_str)
171 | elif operator == 'max':
172 | if len(sub_properties) == 0:
173 | return []
174 | key = sub_properties[0][0]
175 | query_str = to_sparql(entity=entity, main_property='档位介绍-'+key)
176 | temp_ans = self.query(query_str)
177 | temp_ans = [int(i) for i in temp_ans]
178 | if len(temp_ans) == 0:
179 | return []
180 | first_step_ans = max(temp_ans)
181 | for m_p in main_property:
182 | query_str = to_sparql(entity=entity, main_property=m_p, sub_properties=[(key, first_step_ans)])
183 | ans += self.query(query_str)
184 | # =!=
185 | else:
186 | for m_p in main_property:
187 | query_str = to_sparql(entity=entity, main_property=m_p, sub_properties=sub_properties)
188 | ans += self.query(query_str)
189 | # 属性值
190 | elif ans_type == '属性值':
191 | # 答案只有一个,只需要选择一个最小(大)的即可。分两阶段获取答案
192 | if operator == 'min':
193 | if len(sub_properties) == 0:
194 | # 直接来近似得到
195 | temp_ans = self.query(to_sparql(entity=entity, main_property=main_property[0]))
196 | temp_ans = [int(i) for i in temp_ans]
197 | ans = min(temp_ans)
198 | return [ans]
199 | else:
200 | key = sub_properties[0][0]
201 | query_str = to_sparql(entity=entity, main_property=main_property[0].split('-')[0]+'-'+key)
202 | temp_ans = self.query(query_str)
203 | temp_ans = [int(i) for i in temp_ans]
204 | first_step_ans = min(temp_ans)
205 | query_str = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[(key, first_step_ans)])
206 | ans = self.query(query_str)
207 | elif operator == 'max':
208 | if len(sub_properties) == 0:
209 | query_str = to_sparql(entity=entity, main_property=main_property[0])
210 | temp_ans = self.query(query_str)
211 | temp_ans = [int(i) for i in temp_ans]
212 | ans = max(temp_ans)
213 | return [ans]
214 | else:
215 | key = sub_properties[0][0]
216 | query_str = to_sparql(entity=entity, main_property=main_property[0].split('-')[0]+'-'+key)
217 | temp_ans = self.query(query_str)
218 | # 修复字符串'700'大于'10000'的bug
219 | temp_ans = [int(i) for i in temp_ans]
220 | first_step_ans = max(temp_ans)
221 | query_str = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[(key, first_step_ans)])
222 | ans = self.query(query_str)
223 | else:
224 | query_str = to_sparql(entity=entity, main_property=main_property[0], sub_properties=sub_properties)
225 | ans = self.query(query_str)
226 | else:
227 | ans = []
228 | print('-------------------{}:乱码对结果造成影响---------------------'.format(ans_type))
229 | ans = list(set(ans))
230 | return ans
231 | except Exception as e:
232 | print('---------------查找知识图谱时发生异常:{}------------------------'.format(e))
233 | print('当前参数为,问题:{}, 答案类型:{}, 实体:{}, 属性名{}, 算子:{}, 约束属性:{}'.format(question, ans_type, entity, main_property, operator, sub_properties))
234 | return []
235 |
236 | def fetch_wrong_ans(self, question: str, ans_type: str, entity: str, main_property: List, operator: str,
237 | sub_properties: List[Tuple] = None):
238 | """
239 | 在没有约束传入时,查看对应的子业务、价格、流量有哪些,从而可以与原句对比,看是否有遗漏的约束标注
240 | :param question: 问题本身
241 | :param ans_type: 答案类型【比较句、并列句、属性值】
242 | :param entity: 实体
243 | :param main_property: 属性名["档位介绍-流量"或者"XX规则"],并列句有两个属性名,而其他句式只有一个属性名
244 | :param sub_properties: 约束属性键值对【[(key, value), (key, value), ...]】。不用字典是因为比较句中key有可能相同
245 | :param operator: 约束算子【min max other】
246 | :return:
247 | """
248 | # 问句特色:
249 | # 1 比较句:属性名只要一个,比较句的约束属性名肯定为两个;句式一般为:是否和哪个两种;
250 | # 2 并列句:有多个属性名,少量情况含有min,max算子,只要将属性名分批答案取出即可
251 | # 3 属性句:最普通的情况
252 | for i in range(len(sub_properties)):
253 | if sub_properties[i][0] == '流量':
254 | value = int(sub_properties[i][1])
255 | if 1 < value < 55:
256 | value *= 1024
257 | sub_properties[i] = ('流量', value)
258 | # print('修改流量的句子是:', question)
259 | ans, query_str = [], ''
260 | # 比较句不需要判断,认为约束的标注必然是对的
261 | if ans_type == '比较句':
262 | # pass
263 | # 判断句仍然需要标注。。。
264 | price_property = '档位介绍-价格'
265 | flow_property = '档位介绍-流量'
266 | service_property = '档位介绍-子业务'
267 | query_str = to_sparql(entity=entity, main_property=price_property, sub_properties=[])
268 | price_ans = self.query(query_str)
269 | query_str = to_sparql(entity=entity, main_property=flow_property, sub_properties=[])
270 | flow_ans = self.query(query_str)
271 | query_str = to_sparql(entity=entity, main_property=service_property, sub_properties=[])
272 | service_ans = self.query(query_str)
273 | # 并列句不需要判断,认为约束一般都是没有的
274 | elif ans_type == '并列句':
275 | # pass
276 | # 并列句仍然需要标注
277 | price_property = '档位介绍-价格'
278 | flow_property = '档位介绍-流量'
279 | service_property = '档位介绍-子业务'
280 | query_str = to_sparql(entity=entity, main_property=price_property, sub_properties=[])
281 | price_ans = self.query(query_str)
282 | query_str = to_sparql(entity=entity, main_property=flow_property, sub_properties=[])
283 | flow_ans = self.query(query_str)
284 | query_str = to_sparql(entity=entity, main_property=service_property, sub_properties=[])
285 | service_ans = self.query(query_str)
286 | # 属性句
287 | else:
288 | # 答案只有一个,只需要选择一个最小(大)的即可。分两阶段获取答案
289 | if operator == 'min':
290 | if len(sub_properties) == 0:
291 | # 直接来近似得到
292 | temp_ans = self.query(to_sparql(entity=entity, main_property=main_property[0]))
293 | temp_ans = [int(i) for i in temp_ans]
294 | ans = min(temp_ans)
295 | return [ans]
296 | else:
297 | key = sub_properties[0][0]
298 | query_str = to_sparql(entity=entity, main_property=main_property[0].split('-')[0]+'-'+key)
299 | temp_ans = self.query(query_str)
300 | temp_ans = [int(i) for i in temp_ans]
301 | first_step_ans = min(temp_ans)
302 | query_str = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[(key, first_step_ans)])
303 | ans = self.query(query_str)
304 | elif operator == 'max':
305 | if len(sub_properties) == 0:
306 | query_str = to_sparql(entity=entity, main_property=main_property[0])
307 | temp_ans = self.query(query_str)
308 | temp_ans = [int(i) for i in temp_ans]
309 | ans = max(temp_ans)
310 | return [ans]
311 | else:
312 | key = sub_properties[0][0]
313 | query_str = to_sparql(entity=entity, main_property=main_property[0].split('-')[0]+'-'+key)
314 | temp_ans = self.query(query_str)
315 | # 修复字符串'700'大于'10000'的bug
316 | temp_ans = [int(i) for i in temp_ans]
317 | first_step_ans = max(temp_ans)
318 | query_str = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[(key, first_step_ans)])
319 | ans = self.query(query_str)
320 | else:
321 | # if '-' in main_property[0]:
322 | price_property = '档位介绍-价格'
323 | flow_property = '档位介绍-流量'
324 | service_property = '档位介绍-子业务'
325 | query_str = to_sparql(entity=entity, main_property=price_property, sub_properties=[])
326 | price_ans = self.query(query_str)
327 | query_str = to_sparql(entity=entity, main_property=flow_property, sub_properties=[])
328 | flow_ans = self.query(query_str)
329 | query_str = to_sparql(entity=entity, main_property=service_property, sub_properties=[])
330 | service_ans = self.query(query_str)
331 | price_ans = list(set(price_ans))
332 | flow_ans = list(set(flow_ans))
333 | service_ans = list(set(service_ans))
334 | return (price_ans, flow_ans, service_ans)
335 |
336 | if __name__ == '__main__':
337 | kg = KnowledgeGraph('../data/process_data/triples.rdf')
338 | # q = 'select ?ans where { ?instance. ?instance ?ans}'
339 | # ans = kg.query(q)
340 | ans = kg.fetch_ans(**{'question': '你好!花季守护的软件,需要在孩子的手机里安装东西吗?', 'ans_type': '属性值', 'entity': '花季守护业务', 'main_property': ['档位介绍-使用方法'], 'operator': 'other', 'sub_properties': []})
341 | print(ans)
342 |
343 | import pandas as pd
344 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx')
345 | df.fillna('')
346 | id_list = []
347 | ans_list = []
348 | for iter, row in df.iterrows():
349 | ans_true = list(set(row['答案'].split('|')))
350 | question = row['用户问题']
351 | ans_type = row['答案类型']
352 | entity = row['实体']
353 | main_property = row['属性名'].split('|')
354 | operator = row['约束算子']
355 | if operator != 'min' and operator != 'max':
356 | operator == 'other'
357 | sub_properties = []
358 | cons_names = str(row['约束属性名']).split('|')
359 | cons_values = str(row['约束属性值']).split('|')
360 | if cons_names == ['nan']: cons_names = []
361 | for index in range(len(cons_names)):
362 | sub_properties.append([cons_names[index], cons_values[index]])
363 | ans = kg.fetch_ans(question, ans_type, entity, main_property, operator, sub_properties)
364 |
365 | def is_same(ans, ans_true):
366 | for an in ans:
367 | if an in ans_true:
368 | ans_true.remove(an)
369 | else:
370 | return False
371 | if len(ans_true) != 0:
372 | return False
373 | return True
374 |
375 |
376 | if not is_same(ans, ans_true):
377 | id_list.append(iter)
378 | ans_list.append(ans)
379 | print(id_list)
380 | df_save = df.iloc[id_list, [0, 1, 2, 3, 4, 5, 6, 7]]
381 | df_save['预测'] = ans_list
382 | print(df_save)
383 | df_save.to_excel('/data/huangbo/project/Tianchi_nlp_git/data/raw_data/train_wrong_triple.xlsx', index=False)
384 |
--------------------------------------------------------------------------------
/code/utils.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import *
2 | from rdflib import Graph, Namespace
3 | import rdflib
4 | import copy
5 | from typing import List, Dict, Tuple
6 | import pandas as pd
7 | import numpy as np
8 | import torch
9 | import os
10 | import random
11 | import json
12 | import time
13 | import logging
14 | from datetime import timedelta
15 |
16 |
17 | class TrainingConfig(object):
18 | """
19 | BiLSTM模型的训练参数
20 | """
21 | batch_size = 64
22 | # 学习速率
23 | lr = 0.001
24 | epoches = 50
25 | print_step = 10
26 |
27 |
28 | class LSTMConfig(object):
29 | """
30 | BiLSTM模型中LSTM模块的参数
31 | """
32 | emb_size = 128 # 词向量的维数
33 | hidden_size = 128 # lstm隐向量的维数
34 |
35 |
36 | class BertTrainingConfig(object):
37 | """
38 | BiLSTM模型的训练参数
39 | """
40 | batch_size = 64
41 | # 学习速率
42 | lr = 0.00001
43 | other_lr = 0.0001
44 | epoches = 50
45 | print_step = 5
46 |
47 |
48 | class BERTLSTMConfig(object):
49 | """
50 | BERT+BiLSTM模型中LSTM模块的参数
51 | """
52 | emb_size = 768 # 词向量的维数
53 | hidden_size = 512 # lstm隐向量的维数
54 |
55 |
56 | def score_evalution(answers, predictions):
57 | """
58 | 用avr_F1评价预测结果
59 | @param answers: {'0':'asd|sdf', '1':'qwe|wer' }
60 | @param predictions: {'0':'asd|sdf', '1':'qwe|wer' }
61 | @return: avr_F1
62 | """
63 | avr_F1 = 0
64 | for index, answer in answers.items():
65 | prediction = predictions[index]
66 | answer_list = answer.split('|')
67 | answer_set = set()
68 | for item in answer_list:
69 | answer_set.add(item)
70 |
71 | prediction_list = prediction.split('|')
72 | prediction_set = set()
73 | for item in prediction_list:
74 | prediction_set.add(item)
75 |
76 | intersection_set = answer_set.intersection(prediction_set)
77 |
78 | A = len(answer_set)
79 | G = len(prediction_set)
80 | if G==0 or len(intersection_set) == 0:
81 | avr_F1 += 0
82 | continue
83 | P = len(intersection_set)/(A * 1.0)
84 | R = len(intersection_set)/(G * 1.0)
85 | avr_F1 += (2 * P * R)/(P + R)
86 | avr_F1 /= len(answers)
87 | return avr_F1
88 |
89 | def cal_lstm_crf_loss(crf_scores, targets, tag2id):
90 | """计算双向LSTM-CRF模型的损失
91 | 该损失函数的计算可以参考:https://arxiv.org/pdf/1603.01360.pdf
92 | """
93 | targets_copy = copy.deepcopy(targets)
94 | pad_id = tag2id.get('')
95 | start_id = tag2id.get('')
96 | end_id = tag2id.get('')
97 |
98 | device = crf_scores.device
99 |
100 | # targets_copy:[B, L] crf_scores:[B, L, T, T]
101 | batch_size, max_len = targets_copy.size()
102 | target_size = len(tag2id)
103 |
104 | # mask = 1 - ((targets_copy == pad_id) + (targets_copy == end_id)) # [B, L]
105 | mask = (targets_copy != pad_id)
106 | lengths = mask.sum(dim=1)
107 | targets_copy = indexed(targets_copy, target_size, start_id)
108 |
109 | # # 计算Golden scores方法1
110 | # import pdb
111 | # pdb.set_trace()
112 | targets_copy = targets_copy.masked_select(mask) # [real_L]
113 |
114 | flatten_scores = crf_scores.masked_select(
115 | mask.view(batch_size, max_len, 1, 1).expand_as(crf_scores)
116 | ).view(-1, target_size*target_size).contiguous()
117 |
118 | golden_scores = flatten_scores.gather(
119 | dim=1, index=targets_copy.unsqueeze(1)).sum()
120 |
121 | # 计算golden_scores方法2:利用pack_padded_sequence函数
122 | # targets_copy[targets_copy == end_id] = pad_id
123 | # scores_at_targets = torch.gather(
124 | # crf_scores.view(batch_size, max_len, -1), 2, targets_copy.unsqueeze(2)).squeeze(2)
125 | # scores_at_targets, _ = pack_padded_sequence(
126 | # scores_at_targets, lengths-1, batch_first=True
127 | # )
128 | # golden_scores = scores_at_targets.sum()
129 |
130 | # 计算all path scores
131 | # scores_upto_t[i, j]表示第i个句子的第t个词被标注为j标记的所有t时刻事前的所有子路径的分数之和
132 | scores_upto_t = torch.zeros(batch_size, target_size).to(device)
133 | for t in range(max_len):
134 | # 当前时刻 有效的batch_size(因为有些序列比较短)
135 | batch_size_t = (lengths > t).sum().item()
136 | if t == 0:
137 | scores_upto_t[:batch_size_t] = crf_scores[:batch_size_t,
138 | t, start_id, :]
139 | else:
140 | # We add scores at current timestep to scores accumulated up to previous
141 | # timestep, and log-sum-exp Remember, the cur_tag of the previous
142 | # timestep is the prev_tag of this timestep
143 | # So, broadcast prev. timestep's cur_tag scores
144 | # along cur. timestep's cur_tag dimension
145 | scores_upto_t[:batch_size_t] = torch.logsumexp(
146 | crf_scores[:batch_size_t, t, :, :] +
147 | scores_upto_t[:batch_size_t].unsqueeze(2),
148 | dim=1
149 | )
150 | all_path_scores = scores_upto_t[:, end_id].sum()
151 |
152 | # 训练大约两个epoch loss变成负数,从数学的角度上来说,loss = -logP
153 | loss = (all_path_scores - golden_scores) / batch_size
154 | return loss
155 |
156 |
157 | def tensorized(batch, maps):
158 | """
159 | BiLSTM训练使用
160 | @param batch:
161 | @param maps:
162 | @return:
163 | """
164 | PAD = maps.get('')
165 | UNK = maps.get('')
166 |
167 | max_len = len(batch[0])
168 | batch_size = len(batch)
169 |
170 | batch_tensor = torch.ones(batch_size, max_len).long() * PAD
171 | for i, l in enumerate(batch):
172 | for j, e in enumerate(l):
173 | batch_tensor[i][j] = maps.get(e, UNK)
174 | # batch各个元素的长度
175 | lengths = [len(l) for l in batch]
176 |
177 | return batch_tensor, lengths
178 |
179 |
180 | def tensorized_bert(batch, maps):
181 | """
182 | BiLSTM训练使用
183 | @param batch:
184 | @param maps:
185 | @return:
186 | """
187 | PAD = maps.get('[PAD]')
188 | UNK = maps.get('[UNK]')
189 |
190 | max_len = len(batch[0])
191 | batch_size = len(batch)
192 |
193 | batch_tensor = torch.ones(batch_size, max_len).long() * PAD
194 | for i, l in enumerate(batch):
195 | for j, e in enumerate(l):
196 | batch_tensor[i][j] = maps.get(e, UNK)
197 | # batch各个元素的长度
198 | lengths = [len(l) for l in batch]
199 |
200 | return batch_tensor, lengths
201 |
202 |
203 | def sort_by_lengths(word_lists, tag_lists):
204 | pairs = list(zip(word_lists, tag_lists))
205 | indices = sorted(range(len(pairs)),
206 | key=lambda k: len(pairs[k][0]),
207 | reverse=True)
208 | pairs = [pairs[i] for i in indices]
209 | # pairs.sort(key=lambda pair: len(pair[0]), reverse=True)
210 |
211 | word_lists, tag_lists = list(zip(*pairs))
212 |
213 | return word_lists, tag_lists, indices
214 |
215 |
216 | def indexed(targets, tagset_size, start_id):
217 | """将targets中的数转化为在[T*T]大小序列中的索引,T是标注的种类"""
218 | batch_size, max_len = targets.size()
219 | for col in range(max_len-1, 0, -1):
220 | targets[:, col] += (targets[:, col-1] * tagset_size)
221 | targets[:, 0] += (start_id * tagset_size)
222 | return targets
223 |
224 |
225 | def get_operator(question):
226 | """
227 | 得到句子的约束算子,只针对min\max的算子,对算子为‘=’或者空的情况均返回None
228 | @param question: 用户问题(str)
229 | @return: 约束算子 和 约束属性
230 | """
231 | operater, obj = None, None
232 | if '最多' in question:
233 | operater, obj = 'max', '流量'
234 | elif '最少' in question:
235 | operater, obj = 'min', '流量'
236 | elif '最便宜的' in question or '最实惠的' in question:
237 | operater, obj = 'min', '价格'
238 | elif '最贵' in question:
239 | operater, obj = 'max', '价格'
240 | return operater, obj
241 |
242 |
243 | def parse_triples_file(file_path):
244 | """
245 | 解析triples.txt为triples.rdf
246 | """
247 | triples = []
248 | # read raw_triples
249 | with open(file_path, 'r', encoding='utf-8') as f:
250 | for line in f:
251 | line = line.strip()
252 | line = line.replace(' ', '')
253 | line_block = line.split(' ')
254 | t = []
255 | flag = True
256 | for block in line_block:
257 | # 调整三元组格式与训练文件一致
258 | block = block.replace('档位介绍表', '档位介绍')
259 | # 三元组改成小写
260 | block = block.lower()
261 | if block.find('http://yunxiaomi.com/kbqa') != -1:
262 | _ = block.find('_')
263 | t.append('http://yunxiaomi.com/'+block[_+1: -1])
264 | else:
265 | block = block[1:-1]
266 | if '|' in block:
267 | blocks = block.split('|')
268 | for b in blocks:
269 | flag = False
270 | copy_t = copy.deepcopy(t)
271 | copy_t.append('http://yunxiaomi.com/'+b)
272 | triples.append(copy_t)
273 | else:
274 | t.append('http://yunxiaomi.com/'+block)
275 | if flag:
276 | triples.append(t)
277 | # transfer to RDF file
278 | graph = Graph()
279 | for triple in triples:
280 | if triple is None:
281 | print('None')
282 | continue
283 | graph.add((rdflib.term.URIRef(u'{}'.format(triple[0])),
284 | rdflib.term.URIRef(u'{}'.format(triple[1])),
285 | rdflib.term.URIRef(u'{}'.format(triple[2]))))
286 | graph.serialize('../data/process_data/triples.rdf', format='n3')
287 | graph.close()
288 |
289 |
290 | def to_sparql(entity: str, main_property: str, sub_properties: List[Tuple] = None):
291 | """
292 | 将实体信息等转化成sparql语句
293 | :param entity: str
294 | :param main_property: str
295 | :param sub_properties: List[Tuple[key, value], ...]
296 | :return: str
297 | """
298 | if sub_properties is None:
299 | sub_properties = []
300 | prefix = ''
301 | # relations第0个元素存放主要关系, 之后存放次要关系
302 | relations, conditions = [], ''
303 | # ”档位介绍-取消方式“ 类场景
304 | if main_property.find('-') != -1:
305 | relations.append(main_property.split('-')[0])
306 | relations.append(main_property.split('-')[1])
307 | # “生效规则” 类场景
308 | else:
309 | relations.append(main_property)
310 | # 填充关系到sparql
311 | for i, r in enumerate(relations):
312 | condition = ''
313 | if i == 0:
314 | condition = prefix.format(entity) + ' ' + prefix.format(relations[i])
315 | if len(relations) > 1:
316 | condition += ' ?instance'
317 | else:
318 | condition += ' ?ans'
319 | elif i == 1:
320 | condition += '?instance ' + prefix.format(relations[i]) + ' ?ans'
321 | # else:
322 | # condition += '?instance ' + prefix.format(relations[i]) + ' ' + prefix.format(relations[i])
323 | if len(relations)-1 != i or (len(relations) > 1 and len(sub_properties) > 0):
324 | condition += '. '
325 | conditions += condition
326 | idx = 0
327 | if len(relations) > 1:
328 | for key, value in sub_properties:
329 | condition = '?instance ' + prefix.format(key) + ' ' + prefix.format(value)
330 | if len(sub_properties) - 1 > idx:
331 | condition += '. '
332 | conditions += condition
333 | idx += 1
334 | s = """select ?ans where {%s}""" % (conditions, )
335 | return s
336 |
337 |
338 | def make_dataset(data_path, target_file, label_file, train=True):
339 | if isinstance(data_path, list):
340 | df = pd.DataFrame(pd.read_excel(data_path[0]))
341 | for i in range(1, len(data_path)):
342 | df = pd.concat([df, pd.DataFrame(pd.read_excel(data_path[i]))], ignore_index=True)
343 | else:
344 | df = pd.DataFrame(pd.read_excel(data_path))
345 | # # entity_map
346 | # with open('../data/file/entity_map.json', 'r', encoding='utf-8') as f:
347 | # entity_mapping = json.load(f)
348 | # 标签和id的映射
349 | ans_label2id = {}
350 | prop_label2id = {}
351 | entity_label2id = {}
352 | # 开通方式、条件
353 | binary_label2id = {'档位介绍-开通方式': 0, '档位介绍-开通条件': 1}
354 | # 标注数据xlxs转化成txt
355 | if train:
356 | df = df.loc[:, ['用户问题', '答案类型', '属性名', '实体', '答案', '有效率']]
357 | # 未标注数据xlxs转化成txt
358 | else:
359 | length = []
360 | columns = df.columns
361 | for column in columns:
362 | length.append(len(str(df.iloc[1].at[column])))
363 | max_id = np.argmax(length)
364 | # df = df.loc[:, ['query']]
365 | df = df.loc[:, [columns[max_id]]]
366 | # 转换数据
367 | with open(target_file, 'w', encoding='utf-8') as f:
368 | for i in range(len(df)):
369 | line = df.loc[i]
370 | line = list(line)
371 | for idx in range(len(line)):
372 | # 大写字母改成小写
373 | line[idx] = str(line[idx]).strip().lower()
374 | if train:
375 | # 答案类型
376 | if line[1] not in ans_label2id:
377 | ans_label2id[line[1]] = len(ans_label2id)
378 | # 属性名
379 | sub_blocks = line[2].split('|')
380 | for j, sub_b in enumerate(sub_blocks):
381 | # 把这几类统一变成”其他“类
382 | # if sub_b in ('适用app', '生效规则', '叠加规则', '封顶规则'):
383 | # sub_b = '其他'
384 | # sub_blocks[j] = sub_b
385 | if sub_b not in prop_label2id:
386 | prop_label2id[sub_b] = len(prop_label2id)
387 | # line[2] = '|'.join(list(set(sub_blocks)))
388 | # # 实体
389 | # if line[3] in entity_mapping:
390 | # line[3] = entity_mapping[line[3]]
391 | if line[3] not in entity_label2id:
392 | entity_label2id[line[3]] = len(entity_label2id)
393 | f.write('\t'.join(line)+'\n')
394 | # 整理标签和id的映射
395 | if train:
396 | if label_file is not None:
397 | with open(label_file, 'w', encoding='utf-8') as f:
398 | json.dump({'ans_type': ans_label2id, 'main_property': prop_label2id, 'entity': entity_label2id,
399 | 'binary_type': binary_label2id},
400 | fp=f,
401 | ensure_ascii=False, indent=2)
402 |
403 |
404 | def make_dataset_for_binary(data_path, target_file):
405 | """
406 | 为二分类任务制作数据集
407 | :param data_path:
408 | :param target_file:
409 | :return:
410 | """
411 | if isinstance(data_path, list):
412 | df = pd.DataFrame(pd.read_excel(data_path[0]))
413 | for i in range(1, len(data_path)):
414 | df = pd.concat([df, pd.DataFrame(pd.read_excel(data_path[i]))], ignore_index=True)
415 | else:
416 | df = pd.DataFrame(pd.read_excel(data_path))
417 | # 标注数据xlxs转化成txt
418 | df = df.loc[:, ['用户问题', '答案类型', '属性名', '实体', '答案', '有效率']]
419 | method_count, condition_text = 0, []
420 | # 转换数据
421 | with open(target_file, 'w', encoding='utf-8') as f:
422 | for i in range(len(df)):
423 | line = df.loc[i]
424 | if str(line['用户问题']) in ('20元20g还可以办理吗', ):
425 | continue
426 | props = str(line['属性名'])
427 | line = [str(line['用户问题']), '', str(line['有效率'])]
428 | if '开通方式' in props and '开通条件' not in props:
429 | line[1] = '档位介绍-开通方式'
430 | f.write('\t'.join(line)+'\n')
431 | method_count += 1
432 | elif '开通方式' not in props and '开通条件' in props:
433 | line[1] = '档位介绍-开通条件'
434 | f.write('\t'.join(line)+'\n')
435 | condition_text.append(line)
436 | # 中和结果,重采样
437 | if method_count > len(condition_text):
438 | for i in range(method_count - len(condition_text)):
439 | line = random.choice(condition_text)
440 | f.write('\t'.join(line)+'\n')
441 |
442 |
443 | def label_to_multi_hot(max_len, label_ids):
444 | """
445 | 转化成multi-hot编码
446 | :param max_len: multi-hot编码的最大长度
447 | :param label_ids:
448 | :return:
449 | """
450 | multi_hot = [0] * max_len
451 | for idx in label_ids:
452 | multi_hot[idx] = 1
453 | return multi_hot
454 |
455 |
456 | def logits_to_multi_hot(data: torch.Tensor, ans_pred: torch.Tensor, label_hub, threshold=0.5):
457 | """
458 | logits转化成multi-hot编码
459 | 1 运用规则: 当并列句中,属性名肯定有两个,而其他句子中属性名只有一个
460 | :return:
461 | """
462 | if data.is_cuda:
463 | data = data.detach().data.cpu().numpy()
464 | else:
465 | data = data.detach().numpy()
466 | result = []
467 | for i in range(data.shape[0]):
468 | temp = [0] * data.shape[1]
469 | if label_hub.ans_id2label[ans_pred[i].item()] == '并列句':
470 | # 获取分数最高的两个属性
471 | max_idx = np.argmax(data[i], axis=-1)
472 | data[i][max_idx] = -1e5
473 | temp[max_idx] = 1
474 | max_idx = np.argmax(data[i], axis=-1)
475 | temp[max_idx] = 1
476 | else:
477 | # 获取分数最高的1个属性
478 | max_idx = np.argmax(data[i], axis=-1)
479 | temp[max_idx] = 1
480 | result.append(temp)
481 | return np.array(result)
482 |
483 |
484 | def logits_to_multi_hot_old_version(data: torch.Tensor, threshold=0.5):
485 | """
486 | logits转化成multi-hot编码【没有考虑答案类型的旧版本】
487 | :return:
488 | """
489 | if data.is_cuda:
490 | data = data.detach().data.cpu().numpy()
491 | else:
492 | data = data.numpy()
493 | result = []
494 | for i in range(data.shape[0]):
495 | result.append([1 if v >= threshold else 0 for v in list(data[i])])
496 | return np.array(result)
497 |
498 |
499 | def remove_stop_words(sentence: str, stopwords: List):
500 | """
501 | 移除停用词
502 | :param sentence:
503 | :param stopwords:
504 | :return:
505 | """
506 | for word in stopwords:
507 | sentence = sentence.replace(word, '')
508 | return sentence
509 |
510 |
511 | def load_model(model_path, map_location=None):
512 | print('模型加载中...')
513 | model = torch.load(model_path, map_location=map_location)
514 | return model
515 |
516 |
517 | def save_model(model, model_path, model_name, debug='1'):
518 | # debug模式下,需要删除之前同名的模型
519 | if debug == '1':
520 | for file in os.listdir(model_path):
521 | # 删除同一个模型
522 | if model_name[:-20] in file:
523 | os.remove(os.path.join(model_path, file))
524 | print('模型保存中...')
525 | torch.save(model, os.path.join(model_path, model_name))
526 |
527 |
528 | def setup_seed(seed):
529 | """
530 | 确定随机数
531 | :param seed: 种子
532 | :return:
533 | """
534 | torch.manual_seed(seed)
535 | torch.cuda.manual_seed_all(seed)
536 | random.seed(seed)
537 | np.random.seed(seed)
538 | os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化,使得实验可复现。
539 |
540 |
541 | def split_train_dev(data_size: int, radio=0.8):
542 | """
543 | 从标注数据中随机划分训练与验证集,返回标注数据中的行号列表
544 | :param data_size:
545 | :param radio:
546 | :return: [[id1, id2, id3...], [idx,... ]]
547 | """
548 | line_ids = [i for i in range(data_size)]
549 | random.shuffle(line_ids)
550 | train_ids = line_ids[:int(data_size*radio)]
551 | dev_ids = line_ids[int(data_size*radio):]
552 | with open('../data/file/train_dev_ids.json', 'w', encoding='utf-8') as f:
553 | json.dump({'train_ids': train_ids, 'dev_ids': dev_ids}, f, ensure_ascii=False, indent=2)
554 |
555 |
556 | def split_labeled_data(source_file, train_file, dev_file):
557 | """
558 | 根据训练与验证的ID,写入文件
559 | :param source_file:
560 | :param train_file:
561 | :param dev_file:
562 | :return:
563 | """
564 | train_dev_ids = json.load(open('../data/file/train_dev_ids.json', 'r', encoding='utf-8'))
565 | train_dev_ids['dev_ids'] = set(train_dev_ids['dev_ids'])
566 | train_data, dev_data = [], []
567 | with open(source_file, 'r', encoding='utf-8') as f:
568 | for i, line in enumerate(f):
569 | if i in train_dev_ids['dev_ids']:
570 | dev_data.append(line.strip())
571 | else:
572 | train_data.append(line.strip())
573 | with open(train_file, 'w', encoding='utf-8') as f:
574 | f.write('\n'.join(train_data))
575 | with open(dev_file, 'w', encoding='utf-8') as f:
576 | f.write('\n'.join(dev_data))
577 |
578 |
579 | def get_time_str():
580 | """
581 | 返回当前时间戳字符串
582 | :return:
583 | """
584 | return time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
585 |
586 |
587 | def get_time_dif(start_time):
588 | """
589 | 获取已经使用的时间
590 | :param start_time:
591 | :return:
592 | """
593 | end_time = time.time()
594 | time_dif = end_time - start_time
595 | return timedelta(seconds=int(round(time_dif)))
596 |
597 |
598 | def get_labels(ans_pred: torch.Tensor, prop_pred: np.ndarray, entity_pred: torch.Tensor, label_hub):
599 | """
600 | 根据标签ID返回标签名称
601 | :param ans_pred:
602 | :param prop_pred:
603 | :param entity_pred:
604 | :param label_hub:
605 | :return:
606 | """
607 | ans_labels, prop_labels, entity_labels = [], [], []
608 | for ans_id in ans_pred:
609 | ans_labels.append(label_hub.ans_id2label[ans_id.item()])
610 | for entity_id in entity_pred:
611 | entity_labels.append(label_hub.entity_id2label[entity_id.item()])
612 | for line in prop_pred:
613 | temp = []
614 | for i in range(len(line)):
615 | if line[i] == 1:
616 | temp.append(label_hub.prop_id2label[i])
617 | prop_labels.append(temp)
618 | return ans_labels, prop_labels, entity_labels
619 |
620 |
621 | def fetch_error_cases(pred_data: Dict, gold_data: Dict, question: Dict):
622 | """
623 | 获取预测错误的样本
624 | :param pred_data:
625 | :param gold_data:
626 | :return:
627 | """
628 | time_str = get_time_str()
629 | time_str = time_str.replace(':', ':')
630 | count = 1
631 | with open('../data/results/error_case_{}.txt'.format(time_str), 'w', encoding='utf-8') as f:
632 | for idx, res in pred_data['result'].items():
633 | ans = gold_data[idx]
634 | ans = set(ans.split('|'))
635 | if '' in ans:
636 | ans.remove('')
637 | res = set(res.split('|'))
638 | if ans != res:
639 | f.write('第{}个错误结果\n'.format(count))
640 | f.write('问题编号为: {}, 问题为:{}\n'.format(idx, question[idx]))
641 | f.write('正确答案:{}\n'.format(ans))
642 | f.write('预测答案:{}\n'.format(res))
643 | f.write('预测中间结果:{}\n'.format(pred_data['model_result'][idx]))
644 | f.write('【问题分析】:【】\n')
645 | f.write('\n')
646 | count += 1
647 |
648 |
649 | def k_fold_data(data: List[Dict], k=5, batch_size=32, seed=1, collate_fn='1'):
650 | """
651 | data即位DataSet类中的data。将data分成k份,然后组成训练验证集
652 | :param data:
653 | :param k:
654 | :param batch_size:
655 | :param seed: 随机数种子
656 | :collate_fn:
657 | :return:
658 | """
659 | print('K折数据划分中...')
660 | from data import BertDataset, bert_collate_fn, binary_collate_fn
661 | from torch.utils.data import DataLoader
662 | temp_data = copy.deepcopy(data)
663 | random.seed(seed)
664 | random.shuffle(temp_data)
665 | block_len = len(temp_data)//k
666 | data_blocks = [temp_data[i*block_len: (i+1)*block_len] for i in range(k)]
667 | train_dev_tuples = []
668 | for i, block in enumerate(data_blocks):
669 | train, dev = [], block
670 | for _ in range(k):
671 | if _ != i:
672 | train += data_blocks[_]
673 | train_set = BertDataset(train_file=None, tokenizer=None, label_hub=None, init=False)
674 | dev_set = BertDataset(train_file=None, tokenizer=None, label_hub=None, init=False)
675 | train_set.data = train
676 | dev_set.data = dev
677 | if collate_fn == '1':
678 | train_dev_tuples.append((DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=bert_collate_fn),
679 | DataLoader(dev_set, batch_size=batch_size, shuffle=True, collate_fn=bert_collate_fn)))
680 | else:
681 | train_dev_tuples.append(
682 | (DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=binary_collate_fn),
683 | DataLoader(dev_set, batch_size=batch_size, shuffle=True, collate_fn=binary_collate_fn)))
684 | return train_dev_tuples
685 |
686 |
687 | def filter_sub_properties(sub_properties: List[Tuple], entity_label: str, ans_label: str):
688 | """
689 | 过滤 sub_properties
690 | 去掉 实体中流量或价钱被预测成约束的情况 或 去掉重复的约束属性名
691 | :param sub_properties:
692 | :param entity_label:
693 | :param ans_label:
694 | :return:
695 | """
696 | sub_properties_filter = []
697 | sub_property_set = set()
698 | for sub_property in sub_properties:
699 | if sub_property[0] not in sub_property_set:
700 | sub_property_set.add(sub_property[0])
701 | # 重复的约束属性名而且答案类型是属性值
702 | elif ans_label == '属性值':
703 | continue
704 | if sub_property[0] == '子业务':
705 | sub_properties_filter.append(sub_property)
706 | elif sub_property[1] in entity_label:
707 | if entity_label == '1元5gb流量券': # 该实体中 1元 是约束条件
708 | sub_properties_filter.append(sub_property)
709 | else:
710 | continue
711 | else:
712 | sub_properties_filter.append(sub_property)
713 | return sub_properties_filter
714 |
715 |
716 | def reverse_prop(prop_label: List[str]):
717 | """
718 | 置换标签,在查找答案为空的时候使用
719 | :param prop_label:
720 | :return:
721 | """
722 | changed = False
723 | for i, prop in enumerate(prop_label):
724 | # 假设以下条件独立
725 | if '方式' in prop:
726 | prop_label[i] = prop.replace('方式', '条件')
727 | changed = True
728 | elif '条件' in prop:
729 | prop_label[i] = prop.replace('条件', '方式')
730 | changed = True
731 | elif '档位介绍-有效期规则' == prop:
732 | prop_label[i] = '生效规则'
733 | changed = True
734 | elif '生效规则' == prop:
735 | prop_label[i] = '档位介绍-有效期规则'
736 | changed = True
737 | return changed
738 |
739 |
740 | def reverse_entity(entity_label: str):
741 | """
742 | 置换属性名,在查找答案为空的时候使用
743 | :param entity_label:
744 | :return:
745 | """
746 | if entity_label == '嗨购月包':
747 | return True, '嗨购产品'
748 | elif entity_label == '嗨购产品':
749 | return True, '嗨购月包'
750 | if entity_label == '北京移动plus会员权益卡':
751 | return True, '北京移动plus会员'
752 | elif entity_label == '北京移动plus会员':
753 | return True, '北京移动plus会员权益卡'
754 | else:
755 | return False, ''
756 |
757 |
758 | def rm_symbol(sentence):
759 | import re
760 | return re.sub(',|\.|,|。|?|\?|!|!|:', '', sentence)
761 |
762 |
763 | def judge_cancel(question):
764 | """
765 | 判断“取消”属性名
766 | :param question:
767 | :return:
768 | """
769 | import re
770 | neg_word, pos_word = ['不要', '不想', '取消', '不需要', '不用'], ['办理', '需要', '开通']
771 | sentence_blocks = re.split(',|\.|:|,|。', question)
772 | for block in sentence_blocks:
773 | for n in neg_word:
774 | if n in block:
775 | for p in pos_word:
776 | if p in block:
777 | return True
778 | return False
779 |
780 |
781 | def rules_to_judge_prop(question):
782 | """
783 | 规则判断属性名:适用app、叠加规则、封顶规则、档位介绍-带宽、档位介绍-有效期规则、生效规则
784 | :param question:
785 | :return:
786 | """
787 | # 适用app:当前只需要判断是否包含“app”和“适用”关键字
788 | if 'app' in question and '适用' in question:
789 | return '适用app'
790 | # 生效规则:判断包含“生效”并带疑问关键字
791 | if '生效' in question and ('么' in question or '吗' in question or '啥' in question):
792 | return '生效规则'
793 | # 叠加规则:包含“可以叠加”关键字。注意与“叠加包”区别
794 | if '可以叠加' in question and '叠加包' not in question:
795 | return '叠加规则'
796 | # 封顶规则:包含“限速”或“上限”关键字,不能包含“解除”、”恢复“关键字
797 | if ('限速' in question or '上限' in question) and ('解除' not in question and '恢复' not in question):
798 | return '封顶规则'
799 | # 有效期规则
800 | if ('到期' in question or '有效期' in question) and ('办理' not in question and '取消' not in question and '关闭' not in question):
801 | return '档位介绍-有效期规则'
802 | # 带宽
803 | if '通带宽' in question or '网速' in question:
804 | return '档位介绍-带宽'
805 | return None
806 |
807 |
808 | def rules_to_judge_entity(question, predict_entity):
809 | """
810 | 规则来判断实体
811 | :param question:
812 | :param predict_entity
813 | :return:
814 | """
815 | if predict_entity == '畅享套餐':
816 | if '升档' in question:
817 | return '新畅享套餐升档优惠活动'
818 | elif '促销' in question or '78元无限流量套餐' in question:
819 | return '畅享套餐促销优惠活动'
820 | elif '新全球通' in question:
821 | return '新全球通畅享套餐'
822 | elif '首充活动' in question:
823 | return '畅享套餐首充活动'
824 | elif predict_entity == '移动王卡':
825 | if '惠享合约' in question:
826 | return '移动王卡惠享合约'
827 | elif predict_entity == '5g畅享套餐':
828 | if '合约版' in question:
829 | return '5g畅享套餐合约版'
830 | elif predict_entity == '30元5gb包':
831 | if '半价' in question:
832 | return '30元5gb半价体验版'
833 | elif predict_entity == '移动花卡':
834 | if '新春升级' in question:
835 | return '移动花卡新春升级版'
836 | elif predict_entity == '随心看会员':
837 | if '合约版' in question:
838 | return '随心看会员合约版'
839 | elif predict_entity == '北京移动plus会员':
840 | if '权益卡' in question:
841 | return '北京移动plus会员权益卡'
842 | elif predict_entity == '5g智享套餐':
843 | if '合约版' in question:
844 | return '5g智享套餐合约版'
845 | elif '家庭版' in question:
846 | return '5g智享套餐家庭版'
847 | elif predict_entity == '全国亲情网':
848 | if '亲情网免费' in question:
849 | return '全国亲情网功能费优惠活动'
850 | elif predict_entity == '精灵卡':
851 | if '首充' in question and '优惠' in question:
852 | return '精灵卡首充优惠活动'
853 | # 合并的样本外的样本
854 | else:
855 | if '无忧' in question:
856 | return '流量无忧包'
857 | return None
858 |
859 |
860 | class MyLogger(object):
861 | def __init__(self, log_file, debug='0'):
862 | self.ch = logging.StreamHandler()
863 | self.formatter = logging.Formatter("%(asctime)s - %(message)s")
864 | self.fh = logging.FileHandler(log_file, mode='w')
865 | self.logger = logging.getLogger()
866 | self.debug = debug
867 | self.init()
868 |
869 | def init(self):
870 | self.logger.setLevel(logging.INFO)
871 | # 输出到文件
872 | self.fh.setLevel(logging.INFO)
873 | self.fh.setFormatter(self.formatter)
874 | # 输出到控制台
875 | self.ch.setLevel(logging.INFO)
876 | self.ch.setFormatter(self.formatter)
877 | self.logger.addHandler(self.ch)
878 | self.logger.addHandler(self.fh)
879 |
880 | def log(self, message):
881 | if self.debug == '1':
882 | self.logger.info(message)
883 |
884 |
885 | def generate_false_label(predict_file):
886 | """
887 | 根据预测的结果反向生成同名伪标签xlsx
888 | :param predict_file
889 | :return
890 | """
891 | result_path = '../data/results/'
892 | with open(result_path+predict_file, 'r', encoding='utf-8') as f:
893 | predict_dict = json.load(f)
894 | predict_results = []
895 | for i, predict_entity in predict_dict['model_result'].items():
896 | instance = []
897 | for key, value in predict_entity.items():
898 | if key == 'main_property':
899 | instance.append('|'.join(value))
900 | elif key == 'sub_properties':
901 | cons_prop, cons_value = [], []
902 | for pair in value:
903 | cons_prop.append(pair[0])
904 | cons_value.append(str(pair[1]))
905 | instance.append('|'.join(cons_prop))
906 | instance.append('|'.join(cons_value))
907 | else:
908 | if key == 'operator' and value == 'other':
909 | instance.append('')
910 | else:
911 | instance.append(value)
912 | instance.append(predict_dict['result'][i])
913 | predict_results.append(instance)
914 | header = ['用户问题', '答案类型', '实体', '属性名', '约束算子', '约束属性名', '约束属性值', '答案']
915 | pd.DataFrame(predict_results, index=None, columns=header).to_excel(result_path+predict_file[:-5]+'.xlsx', index=False)
916 | print('预测结果写入xlsx成功')
917 |
918 |
919 | def view_gpu_info(gpu_id: int):
920 | import pynvml
921 | pynvml.nvmlInit()
922 | handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
923 | mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
924 | # 这里是字节bytes,所以要想得到以兆M为单位就需要除以1024**2
925 | print('当前为第{}号显卡,总显存数:{}MB;显存使用数:{}MB;显存剩余数:{}MB'.format(gpu_id, mem_info.total/1024**2, mem_info.used/1024**2, mem_info.free/1024**2)) # 第二块显卡总的显存大小
926 | return mem_info.total/1024**2, mem_info.used/1024**2, mem_info.free/1024**2
927 |
928 |
929 | def max_count(data: dict, k=1):
930 | """
931 | 对于存在"标签"->"个数"的字典data,得到最大数量的标签
932 | :param data:
933 | :param k: 前k个统计
934 | :return:
935 | """
936 | max_count_, la = 0, None
937 | second_count_, la2 = 0, None
938 | for lab, c in data.items():
939 | if c > max_count_:
940 | max_count_ = c
941 | la = lab
942 | # 只有在标签个数至少为2时才有第二多的标签,否则直接返回空结果
943 | if k == 2 and len(data) > 1:
944 | for lab, c in data.items():
945 | if c > second_count_ and lab != la:
946 | second_count_ = c
947 | la2 = lab
948 | if k == 1:
949 | return la, max_count_
950 | else:
951 | return la, max_count_, la2, second_count_
952 |
953 |
954 | def vote_integration(ans_logits: List[torch.Tensor],
955 | prop_logits: List[torch.Tensor],
956 | entity_logits: List[torch.Tensor],
957 | label_hub,
958 | batch_size: int):
959 | """
960 | 模型投票融合
961 | :param ans_logits:
962 | :param prop_logits:
963 | :param entity_logits:
964 | :param label_hub:
965 | :param batch_size:
966 | :return: 最终的标签
967 | """
968 | vote_method = 2
969 | mid_results = {'ans': [], 'prop': [], 'entity': []}
970 | for i in range(batch_size):
971 | mid_results['ans'].append({})
972 | mid_results['prop'].append({})
973 | mid_results['entity'].append({})
974 | ans_results, prop_results, entity_results = [], [], []
975 | for i in range(len(ans_logits)):
976 | ans_pred = ans_logits[i].data.cpu().argmax(dim=-1)
977 | prop_pred = logits_to_multi_hot(prop_logits[i], ans_pred, label_hub)
978 | entity_pred = entity_logits[i].data.cpu().argmax(dim=-1)
979 | ans_labels, prop_labels, entity_labels = get_labels(ans_pred, prop_pred, entity_pred, label_hub)
980 | for j in range(batch_size):
981 | if ans_labels[j] not in mid_results['ans'][j]:
982 | mid_results['ans'][j][ans_labels[j]] = 0
983 | mid_results['ans'][j][ans_labels[j]] += 1
984 | # prop投票1: key是模型预测的直接结果
985 | if vote_method == 1:
986 | prop_label = tuple(sorted(prop_labels[j]))
987 | if prop_label not in mid_results['prop'][j]:
988 | mid_results['prop'][j][prop_label] = 0
989 | mid_results['prop'][j][prop_label] += 1
990 | # prop投票2: key是模型预测的每个个体标签
991 | else:
992 | prop_label = prop_labels[j]
993 | for lab in prop_label:
994 | if lab not in mid_results['prop'][j]:
995 | mid_results['prop'][j][lab] = 0
996 | mid_results['prop'][j][lab] += 1
997 | if entity_labels[j] not in mid_results['entity'][j]:
998 | mid_results['entity'][j][entity_labels[j]] = 0
999 | mid_results['entity'][j][entity_labels[j]] += 1
1000 | for instance in mid_results['ans']:
1001 | ans_results.append(max_count(instance)[0])
1002 | if vote_method == 1:
1003 | for i, instance in enumerate(mid_results['prop']):
1004 | prop_results.append(max_count(instance)[0])
1005 | else:
1006 | for i, instance in enumerate(mid_results['prop']):
1007 | ans = ans_results[i]
1008 | la, max_count_, la2, second_count_ = max_count(instance, k=2)
1009 | if ans == '并列句':
1010 | if second_count_ <= len(ans_logits)//3:
1011 | print('-------------------出现了不应该发生的情况-------------------')
1012 | ans_results[i] = '属性值'
1013 | prop_results.append([la])
1014 | else:
1015 | prop_results.append([la, la2])
1016 | else:
1017 | prop_results.append([la])
1018 | if second_count_ > len(ans_logits)//2:
1019 | prop_results[-1].append(la2)
1020 | for instance in mid_results['entity']:
1021 | entity_results.append(max_count(instance)[0])
1022 | return ans_results, prop_results, entity_results
1023 |
1024 |
1025 | def average_integration(ans_logits: List[torch.Tensor],
1026 | prop_logits: List[torch.Tensor],
1027 | entity_logits: List[torch.Tensor],
1028 | label_hub):
1029 | """
1030 | 模型平均融合
1031 | :param ans_logits:
1032 | :param prop_logits:
1033 | :param entity_logits:
1034 | :param label_hub:
1035 | :return: 最终的标签
1036 | """
1037 | local_ans_logits, local_prop_logits, local_entity_logits = torch.stack(ans_logits).mean(dim=0), torch.stack(prop_logits).mean(dim=0), torch.stack(entity_logits).mean(dim=0)
1038 | ans_pred = local_ans_logits.data.cpu().argmax(dim=-1)
1039 | prop_pred = logits_to_multi_hot(local_prop_logits, ans_pred, label_hub)
1040 | entity_pred = local_entity_logits.data.cpu().argmax(dim=-1)
1041 | ans_labels, prop_labels, entity_labels = get_labels(ans_pred, prop_pred, entity_pred, label_hub)
1042 | return ans_labels, prop_labels, entity_labels
1043 |
1044 |
1045 | def binary_average_integration(logits: List[torch.Tensor], label_hub):
1046 | """
1047 | 二进制模型的平均融合
1048 | :param logits:
1049 | :param label_hub:
1050 | :return:
1051 | """
1052 | local_logits = torch.stack(logits).mean(dim=0)
1053 | pred = local_logits.data.cpu().argmax(dim=-1)
1054 | labels = []
1055 | for label_id in pred:
1056 | labels.append(label_hub.binary_id2label[label_id.item()])
1057 | return labels
1058 |
1059 |
1060 | def tqdm_with_debug(data, debug=None):
1061 | """
1062 | 将tqdm加入debug模式
1063 | :param data:
1064 | :param debug:
1065 | :return:
1066 | """
1067 | from tqdm import tqdm
1068 | if debug == '1' or debug is True:
1069 | if isinstance(data, enumerate):
1070 | temp_data = list(data)
1071 | else:
1072 | temp_data = data
1073 | len_data = len(temp_data)
1074 | ans = tqdm(temp_data, total=len_data)
1075 | return ans
1076 | else:
1077 | return data
1078 |
1079 |
1080 | # def entity_map():
1081 | # a = {'新畅享套餐升档优惠活动': '畅享套餐', '畅享套餐促销优惠活动': '畅享套餐', '畅享套餐促销优惠': '畅享套餐', '新全球通畅享套餐': '畅享套餐', '畅享套餐首充活动': '畅享套餐',
1082 | # '移动王卡惠享合约': '移动王卡', '5g畅享套餐合约版': '5g畅享套餐', '30元5gb半价体验版': '30元5gb包', '移动花卡新春升级版': '移动花卡', '随心看会员合约版': '随心看会员', '北京移动plus会员权益卡': '北京移动plus会员',
1083 | # '5g智享套餐合约版': '5g智享套餐', '5g智享套餐家庭版': '5g智享套餐',
1084 | # '全国亲情网功能费优惠活动': '全国亲情网', '精灵卡首充优惠活动': '精灵卡'}
1085 | # with open('../data/file/entity_map.json', 'w', encoding='utf-8') as f:
1086 | # json.dump(a, f, ensure_ascii=False, indent=2)
1087 |
1088 |
1089 | if __name__ == '__main__':
1090 | # parse_triples_file('../data/raw_data/triples.txt')
1091 | # print(to_sparql(entity='视频会员通用流量月包', main_property='档位介绍-上线时间', sub_properties={'价格': '70', '子业务': '优酷会员'}))
1092 | # make_dataset(['../data/raw_data/train_augment_few_nlpcda.xlsx',
1093 | # '../data/raw_data/train_augment_simbert.xlsx',
1094 | # '../data/raw_data/train_augment_synonyms.xlsx'],
1095 | # target_file='../data/dataset/augment3.txt',
1096 | # label_file=None,
1097 | # train=True)
1098 | make_dataset_for_binary(['../data/raw_data/train_denoised.xlsx'],
1099 | target_file='../data/dataset/binary_labeled.txt')
1100 | make_dataset_for_binary(['../data/raw_data/train_augment_few_nlpcda.xlsx',
1101 | '../data/raw_data/train_augment_simbert.xlsx',
1102 | '../data/raw_data/train_augment_synonyms.xlsx'],
1103 | target_file='../data/dataset/binary_augment3.txt')
1104 | # make_dataset(['../data/raw_data/test2_denoised.xlsx'],
1105 | # target_file='../data/dataset/cls_unlabeled2.txt',
1106 | # label_file=None,
1107 | # train=False)
1108 | # split_train_dev(5000)
1109 | # split_labeled_data(source_file='../data/dataset/cls_labeled.txt',
1110 | # train_file='../data/dataset/cls_train.txt',
1111 | # dev_file='../data/dataset/cls_dev.txt')
1112 | # ------------------获取错误答案详情---------------
1113 | # with open('../data/results/ans_dev.json', 'r', encoding='utf-8') as f:
1114 | # pred = json.load(f)
1115 | # with open('../data/dataset/cls_dev.txt', 'r', encoding='utf-8') as f:
1116 | # answer, question = {}, {}
1117 | # for i, line in enumerate(f):
1118 | # ans = line.strip().split('\t')[-1]
1119 | # answer[str(i)] = ans
1120 | # question[str(i)] = line.strip().split('\t')[0]
1121 | # fetch_error_cases(pred, answer, question)
1122 | # generate_false_label('ensemble_bert_aug2_use_efficiency_2021-07-20-08-29-22_seed1_fi0_gpu2080_be81_0.95411990.pth.json')
1123 | # make_dataset(['../data/raw_data/test2_denoised.xlsx'],
1124 | # target_file='../data/dataset/cls_unlabeled.txt',
1125 | # label_file=None,
1126 | # train=False)
1127 | # entity_map()
1128 |
--------------------------------------------------------------------------------