├── LICENSE ├── README.md └── code ├── data.py ├── data_argument.py ├── doc2unix.py ├── docker_process.py ├── inference.py ├── model.py ├── module.py ├── run.sh ├── statistic.py ├── test.py ├── train.py ├── train_evaluate_constraint.py ├── triples.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 2021CCKS-QA-Task 2 | A solution to 2021-CCKS-QA-Task(https://tianchi.aliyun.com/competition/entrance/531904/introduction) 3 | 4 | ### 写在前面 5 | 这个比赛结束已经大半年了,期间参赛者似乎都没有做方案的开源,讨论区也很冷清。 6 | 其实自结束以来一直都想做开源工作,分享自己在算法比赛过程中的思路,奈何学业科研繁忙一直搁置。最近因为机缘巧合,打算践行开源,给有需要的同学,提供一些思路(抛砖引玉233)。同时也希望NLP开源社区可以越来越活跃。 7 | 最后这个比赛是我自学NLP之后认真筹备的第一个项目,构建代码时做了很多尝试,比较乱,望见谅。 8 | 9 | ____________________________________________ 10 | 11 | 12 | 13 | 14 | # CCKS2021 运营商知识图谱推理问答 - 湖人总冠军方案 15 | 16 | 17 | 18 | ### 1 任务背景 19 | 20 | 基于给定的运营商知识图谱,使用模型预测用户所提出问题的答案,任务来自2021CCKS会议在阿里云天池平台组织的算法竞赛。在自然语言处理领域该任务属于**KBQA**(**K**nowledge **B**ase **Q**uestion **A**nswering)。 21 | 22 | 场景案例 23 | 24 | ``` 25 | Q:流量日包的流量有多少? 26 | A:100MB 27 | 28 | Q:不含彩铃的套餐有哪些? 29 | A:流量月包|流量年包 30 | ``` 31 | 32 | 33 | 34 | ### 2 数据介绍 35 | 36 | 用以问答的数据的数据包含(数据下载地址 https://tianchi.aliyun.com/dataset/dataDetail?dataId=109340 ) 37 | 38 | * 知识图谱schema:定义了结点的实体、关系类型。 39 | 40 | * 知识图谱三元组:类似于<流量日包, 流量, 100MB>形式的三元组。 41 | 42 | * 同义词文件:用户口语化问题中实体的同义词,辅助实体的分类。 43 | 44 | * 训练数据:5000条。 45 | 46 | * 测试数据:1000条。 47 | 48 | | **用户问题** | **答案类型** | **属性名** | **实体** | **约束属性名** | **约束属性值** | **约束算子** | **答案** | 49 | | ----------------------- | ------------ | ----------------- | ---------------- | -------------- | -------------- | ------------ | ------------ | 50 | | 我的20元20G流量怎么取消 | 属性值 | 档位介绍-取消方式 | 新惠享流量券活动 | 价格\|流量 | 20\|20 | =\|= | 取消方式_133 | 51 | 52 |
表1 训练集样本示例
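表1中的「属性名」「约束属性名」「约束属性值」「约束算子」等字段都可能包含多个取值,以“|”分隔(原始数据中半角、全角竖线混用,清洗时会统一)。下面是一个读取并拆分这些字段的示意片段,列名与表1一致,文件路径沿用仓库中的约定,仅作演示:

```python
# 示意代码:读取训练数据并拆分以“|”分隔的多值字段(列名见表1,路径沿用仓库约定)
import pandas as pd

df = pd.read_excel('../data/raw_data/train.xlsx')
row = df.iloc[0]
props = str(row['属性名']).split('|')          # 属性名可能有多个
cons_names = str(row['约束属性名']).split('|')  # 约束属性名与约束属性值按位置一一对应
cons_values = str(row['约束属性值']).split('|')
print(props, list(zip(cons_names, cons_values)))
```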
53 | 54 | | **问题** | 55 | | ------------------------------ | 56 | | 什么时候可以办理70元的优酷会员 | 57 | 58 |
表2 测试集样本示例
59 | 60 | ### 3 数据分析 61 | 62 | 目前知识图谱问答系统在简单句(单实体单属性)上已经取得比较好的效果,而在约束句(条件约束句、时间约束句)以及推理型问句(比较句、最值句、是否型问句,以及带有交集、并集和取反的问句)上,其逻辑推理能力还有待提升。 63 | 64 | 图1展示了问句类型的分布:训练数据中4146条样本是属性句,696条是并列句,158条是比较句。虽然预测难度较小的属性句占比较大,但要在测试集上取得较高的指标,仍然需要聚焦于其他两种类型的数据。 65 | 66 | 图2展示了数据中的实体分布,样本量前三的实体是新惠享流量券活动、全国亲情网、北京移动plus会员。可以看出实体的分布存在长尾效应,样本量最少的实体类别仅有1条样本。 67 | 68 | 69 | 70 |
图1 问句类型分布
71 | 72 | 73 | 74 |
图2 实体分布
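上面的分布统计可以用 pandas 直接复现。以下为示意片段,这里把「答案类型」一列近似作为问句类型来统计,列名与表1一致,路径沿用仓库约定:

```python
# 示意代码:统计问句类型与实体的分布,观察长尾现象(对应图1、图2)
import pandas as pd

df = pd.read_excel('../data/raw_data/train.xlsx')
print(df['答案类型'].value_counts())       # 问句(答案)类型分布
entity_count = df['实体'].value_counts()
print(entity_count.head(3))                # 样本量前三的实体
print((entity_count == 1).sum())           # 仅有1条样本的实体类别数
```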
75 | 76 | 77 | 78 | ### 4 整体设计 79 | 80 | 为了提高推理的性能,我们算法的整体思路是用NLP模型对文本进行语义解析后,将解析到的成分拼接成SPARQL查询语句,利用查询语句查询知识库得到最终的答案。 81 | 语义解析部分分为两个部分 82 | 83 | - 多任务分类模型预测答案类型、属性名、实体; 84 | - 命名实体识别模型预测约束属性名和约束属性值。 85 | 86 | ![](https://cdn.nlark.com/yuque/0/2021/png/22425527/1629187488416-fd08ca87-c65d-4c04-ae0e-dd5e5087cff0.png) 87 | 88 |
图3 整体流程
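下面用一个极简的函数示意“把解析出的成分拼接成SPARQL查询语句”这一步。其中 ex: 前缀与谓词命名均为假设,真实的图谱 schema 与查询逻辑以 code/triples.py 为准:

```python
# 示意代码:把语义解析结果拼接成 SPARQL 查询串(简化为单实体单属性 + 可选约束)
# 注意:ex: 前缀与谓词命名均为假设,实际 schema 以 code/triples.py 为准
def build_sparql(entity, prop, constraints=None):
    """entity/prop 来自多任务分类模型,constraints 来自实体识别模型,形如 [('价格', '20')]"""
    lines = [
        'SELECT ?ans WHERE {',
        '  ?s ex:名称 "%s" .' % entity,
        '  ?s ex:%s ?ans .' % prop,
    ]
    for name, value in (constraints or []):
        lines.append('  ?s ex:%s "%s" .' % (name, value))
    lines.append('}')
    return '\n'.join(lines)

print(build_sparql('流量日包', '流量'))
print(build_sparql('新惠享流量券活动', '档位介绍-取消方式', [('价格', '20'), ('流量', '20')]))
```

拼接好的查询串可以交给 rdflib(见 10.2 节依赖)加载 ../data/process_data/triples.rdf 之后执行。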
89 | 90 | ### 5 数据预处理 91 | 92 | 数据预处理主要分为数据清洗、数据集处理、数据增强三个方面。 93 | 94 | #### 5.1 数据清洗 95 | 96 | 数据清洗包含对训练数据、测试数据和三元组数据的清洗,主要做以下处理: 97 | 98 | - 英文字母转化成小写形式,例如 A -> a 99 | - 中文的数字表示用阿拉伯数字表示,例如 五十 -> 50 100 | - 去除问句中包含的空格 101 | 102 | #### 5.2 数据集处理 103 | 104 | 这里的数据集处理主要将原数据转化成可训练的数据格式。 105 | 106 | | **用户问题** | **答案类型** | **属性名** | **实体** | **约束属性名** | **约束属性值** | **约束算子** | **答案** | 107 | | ----------------------- | ------------ | ----------------- | ---------------- | -------------- | -------------- | ------------ | ------------ | 108 | | 我的20元20G流量怎么取消 | 属性值 | 档位介绍-取消方式 | 新惠享流量券活动 | 价格\|流量 | 20\|20 | =\|= | 取消方式_133 | 109 | 110 |
表3 训练集样本示例
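与 5.1 节的清洗规则对应,下面给出一个只包含少量替换规则的最小示意;完整实现见 code/data_argument.py 中的 denoising():

```python
# 示意代码:问句清洗的最小示例(中文数字转阿拉伯数字、统一小写、去空格)
# 这里只列出部分替换规则,完整规则见 code/data_argument.py 的 denoising()
def clean_question(q):
    for zh, num in [('二十', '20'), ('三十', '30'), ('五十', '50'), ('一百', '100')]:
        q = q.replace(zh, num)
    return q.lower().replace(' ', '')

print(clean_question('五十 元的流量包怎么取消A'))  # -> '50元的流量包怎么取消a'
```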
111 | 112 | 答案类型、属性名、实体的预测被视作文本分类任务,其中答案类型、实体的预测属于多分类任务,属性名是多标签分类任务。约束属性名、约束属性值的预测是字符级别上的分类,我们将其视作命名实体识别任务,这就需要把约束属性名对应到原问句中的约束属性值。 113 | 在仔细分析标注数据后发现,约束属性值基本都会出现在原问句中,如表3中的两个20。因此使用正则匹配的方法自动进行约束属性名和约束属性值的标注。 114 | **标注数据** 采用BIEO方法,用train.xlsx中的「约束属性」和「约束值」对「用户问题」进行标注。注意标注前需要统一数字的写法和字母的大小写。标注示例如下: 115 | 116 | '我 的 **2** **0** 元 **2** **0** G 流 量 怎 么 取 消' 117 | 118 | 'O O **_B-PRICE E-PRICE_** O **B-FLOW E-FLOW** O O O O O O O' 119 | 120 | #### 5.3 数据增强 121 | 122 | 数据增强主要考虑三个角度: 123 | 124 | **少样本增强** 少样本增强的动机是观察到数据中的标签存在分布不平衡的现象。我们利用`nlpcda`包,使用EDA的方案对数据量少于某个阈值的标签样本进行增强;由于EDA中的随机删除操作会导致模型性能变差,我们去除了该操作(`nlpcda` 的最小用法示意见下文的代码片段)。 125 | 126 | **同义词增强** 比赛数据中包含了实体的同义词文件,为了最大化地利用该数据、提高模型在测试集上的泛化能力,我们利用该文件把问句中的实体替换成其同义词,生成新的问句,从而扩充训练数据。但由于同义词较多,初次生成的样本达20000+例,是原训练样本的四倍,增加了训练时间开销且提升效果不明显。通过对数据的分析,我们发现训练集中样本量多的实体其同义词一般也较多,如果直接替换,会加剧标签分布不平衡的问题,从而导致提升不明显。基于这一观察,我们遵从少样本增强的思想,对样本量较多的实体按概率进行同义词替换,对样本量较少的实体做全量替换。基于该方法得到增强样本6000+例,既提高了训练速度,也明显提升了模型的泛化性能。 127 | 128 | **基于Bert的文本生成增强** 前两种增强方式都是对原样本的小范围修改,这一增强方式利用`nlpcda`包中的simbert模型来生成与原句意思相近、但表达方式不同的样本,进一步提升模型的泛化能力。基于该方法得到增强样本4000+例。 129 | 130 | 131 | 132 | ### 6 分类模块 133 | 134 | ![图片.png](https://cdn.nlark.com/yuque/0/2021/png/22425527/1629187761737-1cf625da-f27f-491b-bc6a-11b05f19415a.png) 135 | 136 |
图4 分类模型结构
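5.3 节提到的少样本增强基于 `nlpcda` 包实现,下面是一个最小用法示意,参数取值与 code/data_argument.py 中 augment2t_nlpcda() 的设置一致,仅作演示:

```python
# 示意代码:用 nlpcda 做字符乱序与同义词替换增强(对应 5.3 节,未使用随机删除)
from nlpcda import Similarword, CharPositionExchange

a1 = CharPositionExchange(create_num=2, change_rate=0.1, seed=1)
a3 = Similarword(create_num=2, change_rate=0.1, seed=1)

question = '流量日包的流量有多少'
# replace() 返回结果的第0个元素是原句,因此从下标1开始取增强样本
aug = a1.replace(question)[1:] + a3.replace(question)[1:]
print(set(aug))
```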
137 | 138 | 分类模块作用是对答案类型、属性名、实体的分类。由于一个问句只有一个答案类型和一个实体,而可以有多个属性名,所以我们把答案类型和实体分类视作多分类任务,把属性名视作多标签分类任务。 139 | 使用多任务学习的考虑是多个相近的任务一起来训练可以提高准确率、学习速度和泛化能力。我们的实验证明多任务模型比单任务模型准确率整体更高且更稳定。 140 | 最终的分类结果由五个模型平均融合得到,分别是BERT-base、XLNET-base、RoBERTa-base、ELECTRA-base、MacBERT-base。 141 | 接下来开始介绍单模型中使用的方法。 142 | 143 | 1. 任务级的注意力机制 144 | 145 | ![image.png](https://cdn.nlark.com/yuque/0/2021/png/1074904/1629179071154-9eecf8b2-5c46-435a-b709-91900d5c802e.png) 146 | 147 |
图5 任务级注意力机制
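上文描述的多任务结构(答案类型、实体为多分类,属性名为多标签)可以概括为下面的最小草图;图5所示的任务级注意力与样本有效率加权在草图中省略(见下文说明)。该草图仅为示意,原实现位于 code/model.py 与 code/train.py:

```python
# 示意代码:多任务分类模型的最小结构(仅为示意,并非仓库中的原实现)
import torch.nn as nn
from transformers import BertModel

class MultiTaskClassifier(nn.Module):
    def __init__(self, bert_name, n_ans, n_entity, n_prop):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        hidden = self.bert.config.hidden_size
        self.ans_head = nn.Linear(hidden, n_ans)        # 答案类型:多分类(CrossEntropyLoss)
        self.entity_head = nn.Linear(hidden, n_entity)  # 实体:多分类(CrossEntropyLoss)
        self.prop_head = nn.Linear(hidden, n_prop)      # 属性名:多标签(BCEWithLogitsLoss)

    def forward(self, input_ids, token_type_ids, attention_mask):
        out = self.bert(input_ids=input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]               # 取 [CLS] 位置的句子表示
        return self.ans_head(cls), self.entity_head(cls), self.prop_head(cls)
```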
148 | 149 | 对BERT等预训练模型产生的字符级embedding使用任务级别的Attention机制,得到适合于当前子任务的句子级embedding,进而用于后续的分类器。实验结果表明,该注意力机制对结果的稳定与提升有明显效果(比赛期间未记录具体数值)。方法详情可见[Same Representation, Different Attentions: Shareable Sentence Representation Learning from Multiple Tasks](https://arxiv.org/pdf/1804.08139.pdf)。 150 | 151 | 2. 样本有效率 152 | 153 | 在经过数据增强操作后,得到了近15000例的增强文本。这些样本的质量或多或少低于原数据,直接用于模型的训练会使模型向这些低质量样本拟合,进而导致模型无法准确学习到原标注样本的数据分布,最终导致训练费时且效果低下。 154 | 设置样本有效率权重是解决该问题的一个思路。我们在仔细分析不同来源的增强文本后,为不同方法得到的样本设置了不同的有效率,让模型既能从增强样本中学到信息,又不会在低质量样本上过拟合。以下是实验结果。 155 | 156 | | | Base | Base + **加入样本权重** | 157 | | -------- | ------ | ----------------------- | 158 | | 复赛结果 | 0.9423 | 0.9453 | 159 | 160 | 3. 二阶段的训练模式 161 | 162 | 当原数据混合了增强样本,如果训练时仍按照原来的模式(随机划分20%的训练样本作为验证集),必然会在验证集中混入增强样本。此时挑选出的“验证集最优”模型实际上是在低质量数据上表现最优的模型,导致其在高质量的测试集上表现远不如验证集。 163 | 我们解决这个问题的思路是分两个阶段训练模型。一阶段,模型在全量的增强样本上训练,不划分验证集,训练λ个epoch(λ作为超参来调整),训练完后保存模型参数。二阶段,加载一阶段的模型并在原训练集上进行“微调”训练,划分20%的数据作为验证集,挑选在验证集上表现最优的模型。实验结果显示,模型效果有少量提升,更重要的是由于减少了整体训练的epoch数,训练的时间开销减少了近一半。 164 | 165 | | | 原训练方式 | **二阶段训练** | 166 | | -------- | ---------- | -------------- | 167 | | 复赛结果 | 0.9477 | 0.9485 | 168 | | 运行时间 | 50min | 29min | 169 | 170 | 4. 损失函数优化 171 | 172 | 多任务模型的损失值是三个子任务损失值的加权和,权重使用可学习的参数自动调整。相比于固定权重的方法,该方法可以使模型结果更加稳定(比赛期间未记录具体数值)。 173 | 174 | $$ 175 | \alpha = \mathrm{softmax}(W) 176 | $$ 177 | 178 | $$ 179 | \mathcal{L}_{total} = \sum_{i=1}^{3}\alpha_i\,\mathcal{L}_{Task_i} 180 | $$ 181 | 182 | ### 7 实体识别模块 183 | 184 | 本模块的目的是预测出句子中的「约束属性」和「约束值」。 185 | ![图片.png](https://cdn.nlark.com/yuque/0/2021/png/22425527/1629187518392-978f34e6-436e-46b2-a98c-031402285da6.png) 186 | 187 |
图6 实体识别模型结构
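图6 的主干是一个 BiLSTM-CRF 序列标注模型(训练细节见下文“模型训练”部分)。下面给出只含发射分数部分的最小草图,CRF 层与维特比解码从略;仅为示意,并非仓库中的原实现:

```python
# 示意代码:实体识别模型主干,BiLSTM 输出每个字符在各标签上的发射分数
# CRF 层与维特比解码在此省略;标签采用 BIEO 体系(如 B-PRICE、E-FLOW、O 等)
import torch.nn as nn

class BiLSTMTagger(nn.Module):
    def __init__(self, vocab_size, n_tags, emb_dim=128, hidden_dim=256):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, hidden_dim // 2,
                            batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim, n_tags)   # 发射分数,供 CRF 层使用

    def forward(self, char_ids):                  # char_ids: [batch, seq_len]
        h, _ = self.lstm(self.emb(char_ids))      # h: [batch, seq_len, hidden_dim]
        return self.fc(h)                         # [batch, seq_len, n_tags]
```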
188 | 189 | **重采样** 实体对应很多同义词,但是有些同义词在训练集中出现次数很少,甚至只出现一次,如果该同义词中又出现了价格或者流量等,就很有可能导致预测错误。比如「和家庭流量包」有一个同义词「100元5g」,「100元5g」是一个同义词,其中的「100元」不应该被预测成约束;但是由于「100元5g」在训练集中就只出现了一次,并且其他样本中「100元」经常被标注出来,所以导致模型错误地把同义词中的价格也预测出来了。因此对句子实体存在同义词的样本进行重采样,让模型充分学习这些样本。 190 | 191 | **同义词增强** 同一个实体中不同约束出现的频率不同,比如「天气预报」的价格有「3元」「5元」「6元」三档,但「3元」在训练集中出现次数远少于其他两个。针对这些样本我们对实体进行同义词替换后加入到了训练集。 192 | 193 | | | *base* | *base* + **重采样** | *base* + **同义词增强** | 194 | | -------- | ------ | -------------------- | ------------------------ | 195 | | 复赛结果 | 0.9518 | 0.9534 | 0.9567 | 196 | 197 | **模型训练** 模型采用了BiLSTM-crf 模型进行训练,数据分为5折交叉训练,batch_size 选用64,lr 选用0.001,epoch选用50。 198 | 199 | **推理前的同义词替换** 很多实体对应大量同义词,其中部分同义词会对实体识别造成干扰(干扰主要是由于同义词中包含价格、流量、子业务),比如「天气预报」和「天气预报年包」是同义词。模型会把「年包」标注出来(但实际不应该标注出来),所以在推理前采用同义词替换,将对标注造成干扰的词进行同义词替换。 200 | 201 | **推理部分** 5折训练的模型分别推理,并将5折模型结果进行平均融合,得分相加取平均再转换到BIEO标的预测结果。 202 | 203 | | | **base** | *base* + **5折平均融合** | *base* + **5折平均融合+推理前的同义词替换** | 204 | | -------- | -------- | -------------------------- | --------------------------------------------- | 205 | | 复赛结果 | 0.9486 | 0.9518 | 0.9544 | 206 | 207 | ### 8 后处理 208 | 209 | **实体识别预测结果补全** 实体识别预测出的结果存在部分错误,比如「680」只识别了「68」,或「半年包」只识别了「年包」。针对以上情况,通过判断「年包」前面的字符是否是「半」,「68」紧邻的前后字符是否是数字等,对结果进行补全。 210 | 211 | **实体识别结果规则处理** 对预测结果不合理的部分进行规则处理。比如:(1)价格和流量预测反。通过判断‘元’‘G’等单位,来判断是否出现价格预测成了流量,或者流量预测成了价格的情况。(2)针对容易预测错误的句子,采用规则方法进行纠正。比如「新惠享流量券活动」只需要判断句子中是否存在「20」或者「30」就可以得到约束属性。 212 | 213 | **空结果处理** 在使用SPARQL查询过程中可能会有空结果的情况出现。当发生此类情况时,将模型预测的标签转化成其可能会混淆的标签,设置多层的判定条件,可以尽最大可能性避免空结果的产生。 214 | 215 | **属性名细分处理** 分类模型在“开通方式”和“开通条件”上混淆较为严重。当多任务模型只预测了上述标签中的其中一个,我们使用一个二分类模型和规则方法去修正标签是否正确。二分类模型的分类标签为“开通方式”和“开通条件”。 216 | 217 | **实体规则处理** 使用同义词表中的同义词来修正模型预测的属性名标签结果。 218 | 219 | ### 9 总结 220 | 221 | 从比赛初期的方案构思到复赛白热化的“军备竞赛”,团队成员在这次比赛中成长了很多,收获了很多。以下是我们本次比赛的总结。 222 | 223 | - 知识图谱推理问答与一般的文本匹配或分类任务有较大的不同,获取答案需要有多个维度的信息,系统中每一个薄弱点都可能成为木桶里的“短板”,因此既要兼顾大局又要各个击破。 224 | - 比赛由于任务自身的问题,标注信息没有那么准确,因此需要对数据进行仔细的甄别,有疑问的数据需要去除; 225 | - 由于复赛的docker运行时长只有6小时,所以不光要考虑模型的性能还需要考虑训练的效率; 226 | - 没有万金油的方案,深度学习需要不断实现方案并实验; 227 | 228 | 此外,由于实验室的学业压力,仍有很多方案没有进行尝试,同时很多方案尝试了但没有提升结果,与前排大佬也有不少差距。希望吸取本次比赛的经验和教训,继续提升自身的实力。 229 | 230 | 231 | ### 10 附录 232 | 233 | ##### 10.1 项目目录树 234 | 235 | ```sh 236 | . 
237 | └── code 238 | ├── data.py # 数据模块 239 | ├── data_argument.py # 数据增强模块 240 | ├── doc2unix.py # 修改回车符 241 | ├── docker_process.py # docker中预处理模块 242 | ├── inference.py # 推理模块 243 | ├── model.py # 模型结构 244 | ├── module.py # 必要模块 245 | ├── run.sh # 一键数据处理、训练、推理 246 | ├── statistic.py # 数据统计模块 247 | ├── test.py # 测试类 248 | ├── train.py # 多任务学习训练模块 249 | ├── train_evaluate_constraint.py # 命名实体识别训练模块 250 | ├── triples.py # 知识图谱与三元组处理模块 251 | └── utils.py # 工具类 252 | ``` 253 | 254 | ##### 10.2 环境依赖 255 | 256 | ``` 257 | bert4keras==0.7.7 258 | jieba==0.42.1 259 | Keras==2.3.1 260 | nlpcda==2.5.6 261 | numpy==1.16.5 262 | openpyxl==3.0.7 263 | pandas==1.2.3 264 | scikit-learn==0.23.2 265 | tensorboard==1.14.0 266 | tensorflow-gpu==1.14.0 267 | textda==0.1.0.6 268 | torch==1.5.0 269 | tqdm==4.50.2 270 | transformers==4.0.1 271 | xlrd==2.0.1 272 | rdflib==5.0.0 273 | nvidia-ml-py3==7.352.0 274 | ``` 275 | 276 | ##### 10.3 运行 277 | 278 | ```shell 279 | sh run.sh 280 | ``` 281 | 282 | 283 | 284 | 285 | -------------------------------------------------------------------------------- /code/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: hezf 3 | @Time: 2021/6/3 14:41 4 | @desc: 5 | """ 6 | from torch.utils.data import Dataset 7 | import torch 8 | import copy 9 | import json 10 | from utils import label_to_multi_hot, remove_stop_words 11 | import re 12 | import pandas as pd 13 | 14 | 15 | class BertDataset(Dataset): 16 | """ 17 | 用于bert的Dataset类 18 | """ 19 | def __init__(self, train_file, tokenizer, label_hub, init=True): 20 | super(BertDataset, self).__init__() 21 | self.train_file = train_file 22 | self.data = [] 23 | if init: 24 | self.tokenizer = tokenizer 25 | self.label_hub = label_hub 26 | # self.stopwords = [] 27 | self.init() 28 | 29 | def init(self): 30 | # print('加载停用词表...') 31 | # with open('../data/file/stopword.txt', 'r', encoding='utf-8') as f: 32 | # for line in f: 33 | # self.stopwords.append(line.strip()) 34 | print('读取数据...') 35 | with open(self.train_file, 'r', encoding='utf-8') as f: 36 | for line in f: 37 | # blocks: 0:问题;1:答案类型;2:属性名;3:实体;4:答案;5:样本有效率 38 | blocks = line.strip().split('\t') 39 | # 单独处理”属性名“ 40 | prop_label_ids = [self.label_hub.prop_label2id[label] for label in blocks[2].split('|')] 41 | prop_label = label_to_multi_hot(len(self.label_hub.prop_label2id), prop_label_ids) 42 | self.data.append({'token': self.tokenizer(blocks[0], 43 | add_special_tokens=True, max_length=100, 44 | padding='max_length', return_tensors='pt', 45 | truncation=True), 46 | 'ans_label': self.label_hub.ans_label2id[blocks[1]], 47 | 'prop_label': prop_label, 48 | 'entity_label': self.label_hub.entity_label2id[blocks[3]], 49 | 'efficiency': float(blocks[-1])}) 50 | 51 | def __getitem__(self, item): 52 | return self.data[item] 53 | 54 | def __len__(self): 55 | return len(self.data) 56 | 57 | 58 | class BinaryDataset(Dataset): 59 | """ 60 | 用于bert的Dataset类 61 | """ 62 | def __init__(self, train_file, tokenizer, label_hub, init=True): 63 | super(BinaryDataset, self).__init__() 64 | self.train_file = train_file 65 | self.data = [] 66 | if init: 67 | self.tokenizer = tokenizer 68 | self.label_hub = label_hub 69 | # self.stopwords = [] 70 | self.init() 71 | 72 | def init(self): 73 | # print('加载停用词表...') 74 | # with open('../data/file/stopword.txt', 'r', encoding='utf-8') as f: 75 | # for line in f: 76 | # self.stopwords.append(line.strip()) 77 | print('读取数据...') 78 | with open(self.train_file, 'r', encoding='utf-8') as f: 79 | for line 
in f: 80 | # blocks: 0:问题;1:标签;2:样本有效率 81 | blocks = line.strip().split('\t') 82 | # 单独处理”属性名“ 83 | self.data.append({'token': self.tokenizer(blocks[0], 84 | add_special_tokens=True, max_length=100, 85 | padding='max_length', return_tensors='pt', 86 | truncation=True), 87 | 'label': self.label_hub.binary_label2id[blocks[1]], 88 | 'efficiency': float(blocks[-1])}) 89 | 90 | def __getitem__(self, item): 91 | return self.data[item] 92 | 93 | def __len__(self): 94 | return len(self.data) 95 | 96 | 97 | class PredDataset(Dataset): 98 | """ 99 | 用于预测的Dataset类 100 | """ 101 | def __init__(self, data_path, tokenizer): 102 | super(PredDataset, self).__init__() 103 | self.data_path = data_path 104 | self.tokenizer = tokenizer 105 | self.data = [] 106 | self.init() 107 | 108 | def init(self): 109 | with open(self.data_path, 'r', encoding='utf-8') as f: 110 | for line in f: 111 | line = line.strip() 112 | line = line.split('\t') 113 | self.data.append({'token': self.tokenizer(line[0], add_special_tokens=True, max_length=100, 114 | padding='max_length', return_tensors='pt', 115 | truncation=True), 116 | 'question': line[0]}) 117 | 118 | def __len__(self): 119 | return len(self.data) 120 | 121 | def __getitem__(self, item): 122 | return self.data[item] 123 | 124 | 125 | class PredEnsembleDataset(Dataset): 126 | """ 127 | 用于集成预测的Dataset类 128 | """ 129 | def __init__(self, data_path): 130 | super(PredEnsembleDataset, self).__init__() 131 | self.data_path = data_path 132 | self.data = [] 133 | self.init() 134 | 135 | def init(self): 136 | with open(self.data_path, 'r', encoding='utf-8') as f: 137 | for line in f: 138 | line = line.strip() 139 | self.data.append(line) 140 | 141 | def __len__(self): 142 | return len(self.data) 143 | 144 | def __getitem__(self, item): 145 | return self.data[item] 146 | 147 | 148 | def process_data(data_lists, bieo_dict=None, test=False): 149 | ''' 150 | 由[question] 得到 [['你','好','世','界']] [['O','O','B','E']] 151 | 另外 返回的word_list后面需要加上'',非test tag_list后面也要加 152 | @param data_lists: question的list 153 | @param bieo_dict: 由create_test_BIEO得到的标注字典 {question:['O','O','B','E']} 154 | @return: question逐字的list 和 标注的list 155 | ''' 156 | data_word_lists = [] 157 | tag_word_lists = [] 158 | for data_list in data_lists: 159 | data_word_list = [] 160 | for word in data_list: 161 | data_word_list.append(word) 162 | data_word_list.append('') 163 | data_word_lists.append(data_word_list) 164 | if bieo_dict is not None: 165 | tag_word_list = bieo_dict[data_list] + [] 166 | if not test: tag_word_list.append('') 167 | tag_word_lists.append(tag_word_list) 168 | if bieo_dict is None: 169 | return data_word_lists 170 | return data_word_lists, tag_word_lists 171 | 172 | 173 | def process_syn(data_list, entities, ans_labels): 174 | ''' 175 | 识别出句子中的实体或者同义词,并将同义词换成实体 176 | @param data_list: 177 | @param entities: 178 | @return: 179 | ''' 180 | new_data_list = [] 181 | from data_argument import read_synonyms 182 | syn_dict = read_synonyms() 183 | for index in range(len(data_list)): 184 | question = data_list[index] 185 | ans_label = ans_labels[index] 186 | entity = entities[index] 187 | if entity in syn_dict: 188 | # 排除以下实体,因为同义词中包含有价格、流量、子业务等 # 189 | if entity in ['视频会员通用流量月包', '1元5gb流量券', '校园卡活动', '新惠享流量券活动', '专属定向流量包', 190 | '任我看视频流量包', '快手定向流量包', '5g智享套餐家庭版', '视频会员5gb通用流量7天包', 191 | '语音信箱', '全国亲情网', '通话圈', '和家庭分享']: 192 | new_data_list.append(question) 193 | continue 194 | candidate_list = [] 195 | for syn in syn_dict[entity] + [entity]: 196 | if syn == '无': continue 197 | if syn == '': continue 198 
| if syn in question: 199 | candidate_list.append(syn) 200 | candidate_list = sorted(candidate_list, key=lambda x:len(x)) 201 | if len(candidate_list) == 0: 202 | new_data_list.append(question) 203 | continue 204 | syn_curr = candidate_list[-1] 205 | question = question.replace(syn_curr, entity) 206 | new_data_list.append(question) 207 | else: 208 | new_data_list.append(question) 209 | return new_data_list 210 | 211 | 212 | def process_postdo(subs, question ,entity): 213 | if entity == '彩铃': 214 | if len(subs) == 0: 215 | if '3元' in question: subs.append(('价格', '3')) 216 | if '5元' in question: subs.append(('价格', '5')) 217 | if entity == '承诺低消优惠话费活动': 218 | if '百度包' in question: 219 | if ('子业务', '百度包') not in subs: 220 | subs.append(('子业务', '百度包')) 221 | if entity == '新惠享流量券活动': 222 | if '30' in question: 223 | subs = [('价格', '30')] 224 | if '40' in question: 225 | subs.append(('流量', '40')) 226 | if '20' in question: 227 | subs = [('价格', '20')] 228 | if question.count('20') > 1: 229 | subs.append(('流量', '20')) 230 | if entity == '神州行5元卡': 231 | if len(subs) == 0: 232 | if question.count('5') == 2 or ('5' in question and '30' in question): 233 | subs.append(('价格', '5')) 234 | if entity == '承诺低消优惠话费活动': 235 | if '38' in question: 236 | subs = [('价格', '38')] 237 | if '18' in question: 238 | subs = [('价格', '18')] 239 | 240 | # todo 加入其他后处理 241 | return subs 242 | 243 | 244 | def process_postdo_last(question, entity_label, prop_label): 245 | if '家庭亲情网' in question or '家庭亲情通话' in question or '互打免费' in question: 246 | entity_label = '全国亲情网' 247 | elif '家庭亲情号' in question: 248 | entity_label = '和家庭分享' 249 | if '成员' in question: 250 | entity_label = '通话圈' 251 | elif '家庭亲情' in question: 252 | entity_label = '和家庭分享' 253 | # 北京移动plus会员权益卡 和 北京移动plus会员 混淆 254 | if '权益卡' in question: 255 | entity_label = '北京移动plus会员权益卡' 256 | # 30元5gb半价体验版 和 30元5gb 混淆 257 | if '半价' in question: 258 | entity_label = '30元5gb半价体验版' 259 | # 流量无忧包 和 畅享套餐 混淆 260 | if '无忧' in question: 261 | entity_label = '流量无忧包' 262 | if '全国亲情网' in question and '全国亲情网功能费优惠活动' not in question: 263 | entity_label = '全国亲情网' 264 | # 和留言服务中的留言服务二字,同时是语音信箱的同义词,容易预测错误 265 | if '和留言服务' in question: 266 | entity_label = '和留言' 267 | # 训练集中所有'帮我开通'都对应'开通方式' 268 | if '帮我开通' in question: 269 | if len(prop_label) == 1: 270 | prop_label = ['档位介绍-开通方式'] 271 | # 训练集中的'怎么取消不了'都对应'取消方式' 272 | if '怎么取消不了' in question: 273 | if len(prop_label) == 1: 274 | prop_label = ['档位介绍-取消方式'] 275 | return entity_label, prop_label 276 | 277 | def get_anno_dict(question, anno_list, service_names=None): 278 | ''' 279 | 从bmes标注 得到 实际的 约束属性名 和 约束属性值 280 | :param question: 用户问题(去掉空格的) 281 | :param anno_list: bmes的标注 ['O','B','I','E'] 282 | :return: anno_map: {'价钱':['100','20']} 283 | ''' 284 | type_map = {'PRICE': '价格', 'FLOW': '流量', 'SERVICE': '子业务', 'EXPIRE': '有效期'} 285 | anno_index = 0 286 | anno_map = {} 287 | while anno_index < len(anno_list): 288 | if anno_list[anno_index] == 'O': 289 | anno_index += 1 290 | else: 291 | if anno_list[anno_index].startswith('S'): 292 | anno_type = anno_list[anno_index][2:] 293 | anno_type = type_map[anno_type] 294 | anno_index += 1 295 | if anno_type not in anno_map: 296 | anno_map[anno_type] = [] 297 | anno_map[anno_type].append(question[anno_index - 1: anno_index]) 298 | continue 299 | anno_type = anno_list[anno_index][2:] 300 | anno_type = type_map[anno_type] 301 | anno_start_index = anno_index 302 | # B-xxx的次数统计 303 | B_count = 0 304 | while not anno_list[anno_index].startswith('E'): 305 | if anno_list[anno_index] == 'O': 306 | anno_index 
-= 1 307 | break 308 | if anno_index == len(anno_list)-1: 309 | break 310 | if anno_list[anno_index][0] == 'B': 311 | B_count += 1 312 | if B_count > 1: 313 | anno_index -= 1 314 | break 315 | anno_index += 1 316 | anno_index += 1 317 | # 对 子业务 和 数字 进行补全 318 | anno_value = question[anno_start_index: anno_index] 319 | if service_names != None: 320 | if anno_type == '子业务': 321 | candidate_list = [] 322 | for service in service_names: 323 | if anno_value in service: 324 | candidate_list.append(service) 325 | # 如果没有符合的实体则筛去该数据 326 | if len(candidate_list) == 0: 327 | continue 328 | drop = True 329 | for candidate in candidate_list: 330 | if candidate in question: 331 | drop = False 332 | if drop: continue 333 | # 对照句子,找到最符合的对象 334 | candidate_list = sorted(candidate_list, key=lambda x:len(x), reverse=True) 335 | for candidate in candidate_list: 336 | if candidate == '半年包' and anno_value=='年包':break 337 | if candidate == '腾讯视频' and anno_value=='腾讯': 338 | if '任我看' in question or '24' in question: 339 | pass 340 | else: 341 | break 342 | if candidate in question: 343 | anno_value = candidate 344 | break 345 | if anno_type == '流量' or anno_type == '价格': 346 | l, r = anno_start_index, anno_index 347 | while r < len(question) and question[r] == '0': 348 | anno_value = anno_value + '0' 349 | r += 1 350 | while l > 0 and question[l-1].isdigit(): 351 | anno_value = question[l-1] + anno_value 352 | l -= 1 353 | if anno_type not in anno_map: 354 | anno_map[anno_type] = [] 355 | # 去除 同一个属性值 出现两次的现象(如果是合约版(比较句出现)、20(价格20流量20)、18(train.xlsx出现),则是正常现象) 356 | if anno_value in anno_map[anno_type]: 357 | if anno_value == '合约版' or anno_value == '20' or anno_value == '18': 358 | pass 359 | else: continue 360 | anno_map[anno_type].append(anno_value) 361 | # if service_names is not None: 362 | # if '和留言' in question and '年包' in question and '月包和留言年包' not in question: 363 | # if '子业务' in anno_map: 364 | # if '年包' not in anno_map['子业务']: 365 | # anno_map['子业务'].append('年包') 366 | # else: 367 | # anno_map['子业务'] = ['年包'] 368 | return anno_map 369 | 370 | 371 | def get_anno_dict_with_pos(question, anno_list, service_names=None): 372 | ''' 373 | 从bmes标注 得到 实际的 约束属性名 和 约束属性值, 以及位置 374 | :param question: 用户问题(去掉空格的) 375 | :param anno_list: bmes的标注 ['O','B','I','E'] 376 | :return: anno_map: {'价钱':[['100', 1], ['20', 3]]} 记录值和在句子中出现的位置 377 | ''' 378 | type_map = {'PRICE': '价格', 'FLOW': '流量', 'SERVICE': '子业务', 'EXPIRE': '有效期'} 379 | anno_index = 0 380 | anno_map = {} 381 | while anno_index < len(anno_list): 382 | if anno_list[anno_index] == 'O': 383 | anno_index += 1 384 | else: 385 | if anno_list[anno_index].startswith('S'): 386 | anno_type = anno_list[anno_index][2:] 387 | anno_type = type_map[anno_type] 388 | anno_index += 1 389 | if anno_type not in anno_map: 390 | anno_map[anno_type] = [] 391 | anno_map[anno_type].append([question[anno_index - 1: anno_index], anno_index - 1]) 392 | continue 393 | anno_type = anno_list[anno_index][2:] 394 | anno_type = type_map[anno_type] 395 | anno_start_index = anno_index 396 | # B-xxx的次数统计 397 | B_count = 0 398 | while not anno_list[anno_index].startswith('E'): 399 | if anno_list[anno_index] == 'O': 400 | anno_index -= 1 401 | break 402 | if anno_index == len(anno_list)-1: 403 | break 404 | if anno_list[anno_index][0] == 'B': 405 | B_count += 1 406 | if B_count > 1: 407 | anno_index -= 1 408 | break 409 | anno_index += 1 410 | anno_index += 1 411 | # 对 子业务 和 数字 进行补全 412 | anno_value = question[anno_start_index: anno_index] 413 | pos = anno_start_index 414 | if service_names != None: 
415 | if anno_type == '子业务': 416 | candidate_list = [] 417 | for service in service_names: 418 | if anno_value in service: 419 | candidate_list.append(service) 420 | # 如果没有符合的实体则筛去该数据 421 | if len(candidate_list) == 0: 422 | continue 423 | drop = True 424 | for candidate in candidate_list: 425 | if candidate in question: 426 | drop = False 427 | if drop: continue 428 | # 对照句子,找到最符合的对象 429 | candidate_list = sorted(candidate_list, key=lambda x:len(x), reverse=True) 430 | for candidate in candidate_list: 431 | if candidate == '半年包' and anno_value=='年包':break 432 | if candidate == '腾讯视频' and anno_value=='腾讯': 433 | if '任我看' in question or '24' in question: 434 | pass 435 | else: 436 | break 437 | if candidate in question: 438 | anno_value = candidate 439 | break 440 | if anno_type == '流量' or anno_type == '价格': 441 | l, r = anno_start_index, anno_index 442 | while r < len(question) and question[r] == '0': 443 | anno_value = anno_value + '0' 444 | r += 1 445 | while l > 0 and question[l-1].isdigit(): 446 | anno_value = question[l-1] + anno_value 447 | l -= 1 448 | if anno_type not in anno_map: 449 | anno_map[anno_type] = [] 450 | # 去除 同一个属性值 出现两次的现象(如果是合约版(比较句出现)、20(价格20流量20)、18(train.xlsx出现),则是正常现象) 451 | # enhence版 不需要去重 452 | # if anno_value in anno_map[anno_type]: 453 | # if anno_value == '合约版' or anno_value == '20' or anno_value == '18': 454 | # pass 455 | # else: continue 456 | anno_map[anno_type].append([anno_value, pos]) 457 | return anno_map 458 | 459 | 460 | def pred_collate_fn(batch_data): 461 | """ 462 | 用于用于预测的collate函数 463 | :param batch_data: 464 | :return: 465 | """ 466 | input_ids, token_type_ids, attention_mask = [], [], [] 467 | questions = [] 468 | for instance in copy.deepcopy(batch_data): 469 | questions.append(instance['question']) 470 | input_ids.append(instance['token']['input_ids'][0].squeeze(0)) 471 | token_type_ids.append(instance['token']['token_type_ids'][0].squeeze(0)) 472 | attention_mask.append(instance['token']['attention_mask'][0].squeeze(0)) 473 | return torch.stack(input_ids), torch.stack(token_type_ids), \ 474 | torch.stack(attention_mask), questions 475 | 476 | 477 | def bert_collate_fn(batch_data): 478 | """ 479 | 用于BERT训练的collate函数 480 | :param batch_data: 481 | :return: 482 | """ 483 | input_ids, token_type_ids, attention_mask = [], [], [] 484 | ans_labels, prop_labels, entity_labels = [], [], [] 485 | efficiency_list = [] 486 | for instance in copy.deepcopy(batch_data): 487 | input_ids.append(instance['token']['input_ids'][0].squeeze(0)) 488 | token_type_ids.append(instance['token']['token_type_ids'][0].squeeze(0)) 489 | attention_mask.append(instance['token']['attention_mask'][0].squeeze(0)) 490 | ans_labels.append(instance['ans_label']) 491 | prop_labels.append(torch.tensor(instance['prop_label'])) 492 | entity_labels.append(instance['entity_label']) 493 | efficiency_list.append(instance['efficiency']) 494 | return torch.stack(input_ids), torch.stack(token_type_ids), \ 495 | torch.stack(attention_mask), torch.tensor(ans_labels), \ 496 | torch.stack(prop_labels), torch.tensor(entity_labels), \ 497 | torch.tensor(efficiency_list, dtype=torch.float) 498 | 499 | 500 | def binary_collate_fn(batch_data): 501 | """ 502 | 用于二类训练的collate函数 503 | """ 504 | input_ids, token_type_ids, attention_mask = [], [], [] 505 | labels = [] 506 | efficiency_list = [] 507 | for instance in copy.deepcopy(batch_data): 508 | input_ids.append(instance['token']['input_ids'][0].squeeze(0)) 509 | token_type_ids.append(instance['token']['token_type_ids'][0].squeeze(0)) 510 | 
attention_mask.append(instance['token']['attention_mask'][0].squeeze(0)) 511 | labels.append(instance['label']) 512 | efficiency_list.append(instance['efficiency']) 513 | return torch.stack(input_ids), torch.stack(token_type_ids), \ 514 | torch.stack(attention_mask), torch.tensor(labels), \ 515 | torch.tensor(efficiency_list, dtype=torch.float) 516 | 517 | 518 | class LabelHub(object): 519 | """ 520 | 分类任务的Label数据中心 521 | """ 522 | def __init__(self, label_file): 523 | super(LabelHub, self).__init__() 524 | self.label_file = label_file 525 | self.ans_label2id = {} 526 | self.ans_id2label = {} 527 | self.prop_label2id = {} 528 | self.prop_id2label = {} 529 | self.entity_label2id = {} 530 | self.entity_id2label = {} 531 | self.binary_label2id = {} 532 | self.binary_id2label = {} 533 | self.load_label() 534 | 535 | def load_label(self): 536 | with open(self.label_file, 'r', encoding='utf-8') as f: 537 | label_dict = json.load(f) 538 | self.ans_label2id = label_dict['ans_type'] 539 | self.prop_label2id = label_dict['main_property'] 540 | self.entity_label2id = label_dict['entity'] 541 | self.binary_label2id = label_dict['binary_type'] 542 | for k, v in self.ans_label2id.items(): 543 | self.ans_id2label[v] = k 544 | for k, v in self.prop_label2id.items(): 545 | self.prop_id2label[v] = k 546 | for k, v in self.entity_label2id.items(): 547 | self.entity_id2label[v] = k 548 | for k, v in self.binary_label2id.items(): 549 | self.binary_id2label[v] = k 550 | 551 | 552 | def create_test_BIEO(excel_path, test = True): 553 | """ 554 | 由excel生成对应的bieo标注,并保存,同时测试标注方法的准确性 555 | @param excel_path: 数据路径 556 | @return: 557 | """ 558 | def get_bieo(data): 559 | ''' 560 | 得到bmes 561 | :param data: train_denoised.xlsx的dataframe 562 | :return: dict ['question': ['O','B','E']] 563 | ''' 564 | # TODO 565 | # 标注失败:最便宜最优惠的标注 应该为 价格1; 566 | # 三十、四十、 十元(先改三十元)、 六块、 一个月 567 | # 有效期太少了 568 | type_map = {'价格': 'PRICE', '流量': 'FLOW', '子业务': 'SERVICE'} 569 | result_dict = {} 570 | 571 | for index, row in data.iterrows(): 572 | question = row['用户问题'] 573 | char_list = ['O' for _ in range(len(question))] 574 | 575 | constraint_names = row['约束属性名'] 576 | constraint_values = row['约束属性值'] 577 | constraint_names_list = re.split(r'[|\|]', str(constraint_names)) 578 | constraint_values_list = re.split(r'[|\|]', str(constraint_values)) 579 | constraint_names_list = [name.strip() for name in constraint_names_list] 580 | constraint_values_list = [value.strip() for value in constraint_values_list] 581 | question_len = len(question) 582 | question_index = 0 583 | # 在句子中标注constraint 584 | for cons_index in range(len(constraint_values_list)): 585 | name = constraint_names_list[cons_index] 586 | if name == '有效期': continue 587 | value = constraint_values_list[cons_index] 588 | if value in question[question_index:]: 589 | temp_index = question[question_index:].find(value) + question_index 590 | if len(value) == 1: 591 | char_list[temp_index] = 'S-' + type_map[name] 592 | continue 593 | else: 594 | for temp_i in range(temp_index + 1, temp_index + len(value) - 1): 595 | char_list[temp_i] = 'I-' + type_map[name] 596 | char_list[temp_index] = 'B-' + type_map[name] 597 | char_list[temp_index + len(value) - 1] = 'E-' + type_map[name] 598 | question_index = min(temp_index + len(value), question_len) 599 | elif value in question: 600 | temp_index = question.find(value) 601 | if len(value) == 1: 602 | char_list[temp_index] = 'S-' + type_map[name] 603 | continue 604 | else: 605 | for temp_i in range(temp_index + 1, temp_index + len(value) - 1): 606 | 
char_list[temp_i] = 'I-' + type_map[name] 607 | char_list[temp_index] = 'B-' + type_map[name] 608 | char_list[temp_index + len(value) - 1] = 'E-' + type_map[name] 609 | result_dict[question] = char_list 610 | 611 | return result_dict 612 | 613 | def test_bieo(excel_data): 614 | ''' 615 | 用来测试正则化生成的bieo相比于excel中的真值,准确度如何 616 | :param: excel中读取的dataframe 617 | :return: 618 | ''' 619 | 620 | annos = get_bieo(excel_data) 621 | 622 | TP, TN, FP = 0, 0, 0 623 | 624 | for index, row in excel_data.iterrows(): 625 | question = row['用户问题'] 626 | # 获取约束的标注 627 | anno_list = annos[question] 628 | anno_dict = get_anno_dict(question, anno_list) 629 | # 获取约束的真值 630 | constraint_names = row['约束属性名'] 631 | constraint_values = row['约束属性值'] 632 | constraint_names_list = re.split(r'[|\|]', str(constraint_names)) 633 | constraint_values_list = re.split(r'[|\|]', str(constraint_values)) 634 | constraint_dict = {} 635 | for constraint_index in range(len(constraint_names_list)): 636 | constraint_name = constraint_names_list[constraint_index].strip() 637 | if constraint_name == '有效期': continue 638 | constraint_value = constraint_values_list[constraint_index].strip() 639 | if constraint_name not in constraint_dict: 640 | constraint_dict[constraint_name] = [] 641 | constraint_dict[constraint_name].append(constraint_value) 642 | # 比较约束的真值和标注 643 | tp = 0 644 | anno_kv = [] 645 | constraint_kv = [] 646 | for k, vs in anno_dict.items(): 647 | for v in vs: 648 | anno_kv.append(k + v) 649 | for k, vs in constraint_dict.items(): 650 | for v in vs: 651 | constraint_kv.append(k + v) 652 | # 排除 二者均为空 的情况和 比较句 的情况 653 | if len(anno_kv) == 0 and constraint_kv[0] == 'nannan': continue 654 | if len(anno_kv) == 0 and (constraint_kv[0] == '价格1' or constraint_kv[0] == '流量1'): continue 655 | 656 | anno_len = len(anno_kv) 657 | cons_len = len(constraint_kv) 658 | for kv in constraint_kv: 659 | if kv in anno_kv: 660 | tp += 1 661 | anno_kv.remove(kv) 662 | if tp != cons_len: 663 | print('-------') 664 | print(question) 665 | print('anno: ', anno_kv) 666 | print('cons: ', constraint_kv) 667 | TP += tp 668 | FP += (cons_len - tp) 669 | TN += (anno_len - tp) 670 | print('测试bmes结果:' + 'TP: {} FP: {} TN:{} '.format(TP, FP, TN)) 671 | 672 | def add_lost_anno(bieo_dict): 673 | from triples import KnowledgeGraph 674 | 675 | kg = KnowledgeGraph('../data/process_data/triples.rdf') 676 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx') 677 | df.fillna('') 678 | id_list = set() 679 | ans_list = [] 680 | for iter, row in df.iterrows(): 681 | ans_true = list(set(row['答案'].split('|'))) 682 | question = row['用户问题'] 683 | ans_type = row['答案类型'] 684 | # 只对属性值的句子做处理 685 | if ans_type != '属性值': 686 | continue 687 | entity = row['实体'] 688 | main_property = row['属性名'].split('|') 689 | # 排除属性中没有'-'的情况,只要'档位介绍-xx'的情况 690 | if '-' not in main_property[0]: 691 | continue 692 | operator = row['约束算子'] 693 | # 排除operator为min或max的情况 694 | if operator != 'min' and operator != 'max': 695 | operator == 'other' 696 | else: 697 | continue 698 | sub_properties = {} 699 | cons_names = str(row['约束属性名']).split('|') 700 | cons_values = str(row['约束属性值']).split('|') 701 | if cons_names == ['nan']: cons_names = [] 702 | for index in range(len(cons_names)): 703 | if cons_names[index] not in sub_properties: 704 | sub_properties[cons_names[index]] = [] 705 | sub_properties[cons_names[index]].append(cons_values[index]) 706 | price_ans, flow_ans, service_ans = kg.fetch_wrong_ans(question, ans_type, entity, main_property, operator, 707 | []) 708 | rdf_properties = {} 709 | 
rdf_properties['价格'] = price_ans 710 | rdf_properties['流量'] = flow_ans 711 | rdf_properties['子业务'] = service_ans 712 | compare_result = [] 713 | for name, values in rdf_properties.items(): 714 | for value in values: 715 | if value in question: 716 | if name in sub_properties and value in sub_properties[name]: 717 | continue 718 | elif name in sub_properties: 719 | if value == '年包' and '半年包' in sub_properties[name]: 720 | continue 721 | if value == '百度' and '百度包' in sub_properties[name]: 722 | continue 723 | elif value in entity: 724 | continue 725 | else: 726 | compare_result.append(name + '_' + value) 727 | id_list.add(iter) 728 | if compare_result != []: 729 | ans_list.append(compare_result) 730 | 731 | raw_data = pd.read_excel(excel_path) 732 | 733 | if test: 734 | test_bieo(raw_data) 735 | bieo_dict = get_bieo(raw_data) 736 | 737 | with open(r'../data/file/train_bieo.json', 'w') as f: 738 | json.dump(bieo_dict, f, indent=2, ensure_ascii=False) 739 | 740 | return bieo_dict 741 | 742 | 743 | def create_test_BIO(excel_path, test = True): 744 | """ 745 | 由excel生成对应的BIO标注,并保存,同时测试标注方法的准确性 746 | @param excel_path: 数据路径 747 | @return: 748 | """ 749 | def get_bio(data): 750 | ''' 751 | 得到bmes 752 | :param data: train_denoised.xlsx的dataframe 753 | :return: dict ['question': ['O','B','I']] 754 | ''' 755 | # TODO 756 | # 标注失败:最便宜最优惠的标注 应该为 价格1; 757 | # 三十、四十、 十元(先改三十元)、 六块、 一个月 758 | # 有效期太少了 可用规则处理 759 | type_map = {'价格': 'PRICE', '流量': 'FLOW', '子业务': 'SERVICE'} 760 | result_dict = {} 761 | 762 | for index, row in data.iterrows(): 763 | question = row['用户问题'] 764 | char_list = ['O' for _ in range(len(question))] 765 | 766 | constraint_names = row['约束属性名'] 767 | constraint_values = row['约束属性值'] 768 | constraint_names_list = re.split(r'[|\|]', str(constraint_names)) 769 | constraint_values_list = re.split(r'[|\|]', str(constraint_values)) 770 | constraint_names_list = [name.strip() for name in constraint_names_list] 771 | constraint_values_list = [value.strip() for value in constraint_values_list] 772 | question_len = len(question) 773 | question_index = 0 774 | # 在句子中标注constraint 775 | for cons_index in range(len(constraint_values_list)): 776 | name = constraint_names_list[cons_index] 777 | if name == '有效期': 778 | continue 779 | value = constraint_values_list[cons_index] 780 | if value in question[question_index:]: 781 | temp_index = question[question_index:].find(value) + question_index 782 | char_list[temp_index] = 'B-' + type_map[name] 783 | for temp_i in range(temp_index + 1, temp_index + len(value)): 784 | char_list[temp_i] = 'I-' + type_map[name] 785 | question_index = min(temp_index + len(value), question_len) 786 | elif value in question: 787 | temp_index = question.find(value) 788 | if char_list[temp_index] == 'O': 789 | char_list[temp_index] = 'B-' + type_map[name] 790 | for temp_i in range(temp_index + 1, temp_index + len(value)): 791 | if char_list[temp_i] == 'O': 792 | char_list[temp_i] = 'I-' + type_map[name] 793 | else: 794 | print('标注冲突:"{}"'.format(question)) 795 | break 796 | else: 797 | print('标注冲突。"{}"'.format(question)) 798 | result_dict[question] = char_list 799 | 800 | return result_dict 801 | 802 | def test_bio(excel_data): 803 | ''' 804 | 用来测试正则化生成的BIO相比于excel中的真值,准确度如何 805 | :param: excel中读取的dataframe 806 | :return: 807 | ''' 808 | 809 | annos = get_bio(excel_data) 810 | TP, TN, FP = 0, 0, 0 811 | for index, row in excel_data.iterrows(): 812 | question = row['用户问题'] 813 | # 获取约束的标注 814 | anno_list = annos[question] 815 | anno_dict = get_anno_dict(question, anno_list) 816 
| # 获取约束的真值 817 | constraint_names = row['约束属性名'] 818 | constraint_values = row['约束属性值'] 819 | constraint_names_list = re.split(r'[|\|]', str(constraint_names)) 820 | constraint_values_list = re.split(r'[|\|]', str(constraint_values)) 821 | constraint_dict = {} 822 | for constraint_index in range(len(constraint_names_list)): 823 | constraint_name = constraint_names_list[constraint_index].strip() 824 | if constraint_name == '有效期': 825 | continue 826 | constraint_value = constraint_values_list[constraint_index].strip() 827 | if constraint_name not in constraint_dict: 828 | constraint_dict[constraint_name] = [] 829 | constraint_dict[constraint_name].append(constraint_value) 830 | # 比较约束的真值和标注 831 | tp = 0 832 | anno_kv = [] 833 | constraint_kv = [] 834 | for k, vs in anno_dict.items(): 835 | for v in vs: 836 | anno_kv.append(k + v) 837 | for k, vs in constraint_dict.items(): 838 | for v in vs: 839 | constraint_kv.append(k + v) 840 | # 排除 二者均为空 的情况和 比较句 的情况 841 | if len(anno_kv) == 0 and constraint_kv[0] == 'nannan': 842 | continue 843 | if len(anno_kv) == 0 and (constraint_kv[0] == '价格1' or constraint_kv[0] == '流量1'): 844 | continue 845 | 846 | anno_len = len(anno_kv) 847 | cons_len = len(constraint_kv) 848 | for kv in constraint_kv: 849 | if kv in anno_kv: 850 | tp += 1 851 | anno_kv.remove(kv) 852 | if tp != cons_len: 853 | print('-------') 854 | print(question) 855 | print('anno: ', anno_kv) 856 | print('cons: ', constraint_kv) 857 | TP += tp 858 | FP += (cons_len - tp) 859 | TN += (anno_len - tp) 860 | print('测试bmes结果:' + 'TP: {} FP: {} TN:{} '.format(TP, FP, TN)) 861 | 862 | raw_data = pd.read_excel(excel_path) 863 | 864 | if test: 865 | test_bio(raw_data) 866 | bio_dict = get_bio(raw_data) 867 | 868 | with open(r'../data/file/train_bio.json', 'w') as f: 869 | json.dump(bio_dict, f, ensure_ascii=False, indent=2) 870 | return bio_dict 871 | 872 | 873 | if __name__ == '__main__': 874 | # ----------------自动创建NER标注文件-------------------- 875 | create_test_BIEO(excel_path='../data/raw_data/train_denoised.xlsx') 876 | create_test_BIO(excel_path='../data/raw_data/train_denoised.xlsx') 877 | -------------------------------------------------------------------------------- /code/data_argument.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | from utils import setup_seed, get_time_dif, view_gpu_info 4 | import pandas as pd 5 | from copy import deepcopy 6 | import math 7 | import random 8 | from tqdm import tqdm 9 | import numpy as np 10 | import json 11 | import re 12 | 13 | 14 | def denoising(source_file=r'../data/raw_data/train.xlsx', target_file=r'../data/raw_data/train_denoised.xlsx'): 15 | """ 16 | 对原数据 train.xlsx 进行去噪,并保存为 train_denoised.xlsx 17 | 主要进行了以下处理:1 去除多余空格 2 统一'|' 3 字母转换为小写 4 去除约束算子 5 float(nan)转换为'' 18 | @return: 19 | """ 20 | def process_field(field): 21 | """ 22 | 对字段进行处理: 1 去掉空格 2 统一'|' 3 全部用小写 23 | @param field 24 | @return: field 25 | """ 26 | field = field.replace(' ','').replace('|', '|') 27 | field = field.lower() 28 | return field 29 | 30 | def question_replace(question): 31 | """ 32 | 对问题进行去噪 33 | :param question: 34 | :return: 35 | """ 36 | question = question.replace('二十', '20') 37 | question = question.replace('三十', '30') 38 | question = question.replace('四十', '40') 39 | question = question.replace('五十', '50') 40 | question = question.replace('六十', '60') 41 | question = question.replace('七十', '70') 42 | question = question.replace('八十', '80') 43 | question = 
question.replace('九十', '90') 44 | question = question.replace('一百', '100') 45 | question = question.replace('十块', '10块') 46 | question = question.replace('十元', '10元') 47 | question = question.replace('六块', '6块') 48 | question = question.replace('一个月', '1个月') 49 | question = question.replace('2O', '20') 50 | if '一元一个g' not in question: 51 | question = question.replace('一元', '1元') 52 | if 'train' in source_file: 53 | question = question.replace(' ', '_') 54 | else: 55 | question = question.replace(' ', '') 56 | question = question.lower() 57 | return question 58 | raw_data = pd.read_excel(source_file) 59 | # 训练数据 60 | if 'train' in source_file: 61 | raw_data['有效率'] = [1] * len(raw_data) 62 | for index, row in raw_data.iterrows(): 63 | row['用户问题'] = question_replace(row['用户问题']) 64 | # TODO 检查是否有效 65 | # if row['实体'] == '畅享套餐促销优惠': 66 | # row['实体'] = '畅享套餐促销优惠活动' 67 | for index_row in range(len(row)): 68 | field = str(row.iloc[index_row]) 69 | if field == 'nan': 70 | field = '' 71 | row.iloc[index_row] = process_field(field) 72 | if not (row['约束算子'] == 'min' or row['约束算子'] == 'max'): 73 | row['约束算子'] = '' 74 | raw_data.iloc[index] = row 75 | # 测试数据 76 | else: 77 | length = [] 78 | columns = raw_data.columns 79 | for column in columns: 80 | length.append(len(str(raw_data.iloc[1].at[column]))) 81 | max_id = np.argmax(length) 82 | for index, row in raw_data.iterrows(): 83 | # raw_data.loc[index, 'query'] = question_replace(row['query']) 84 | raw_data.loc[index, columns[max_id]] = question_replace(row[columns[max_id]]) 85 | # print(raw_data) 86 | raw_data.to_excel(target_file, index=False) 87 | 88 | 89 | def read_synonyms(): 90 | """ 91 | 读取synonyms.txt 并转换成dict 92 | @return: 93 | """ 94 | synonyms_dict = {} 95 | with open(r'../data/raw_data/synonyms.txt', 'r') as f: 96 | lines = f.readlines() 97 | for line in lines: 98 | line = line.strip('\n') 99 | ent, syn_str = line.split() 100 | syns = syn_str.split('|') 101 | synonyms_dict[ent] = syns 102 | 103 | return synonyms_dict 104 | 105 | 106 | def create_argument_data(synonyms_dict): 107 | """ 108 | 对数据进行同义词替换,并保存为新的xlsx 109 | @param synonyms_dict: 来自synonyms.txt 110 | @return: 111 | """ 112 | raw_data = pd.read_excel(r'../data/raw_data/train_denoised.xlsx') 113 | new_data = [] 114 | for index, row in raw_data.iterrows(): 115 | question = row['用户问题'] 116 | for k, vs in synonyms_dict.items(): 117 | if k in question: 118 | if '无' in vs: 119 | continue 120 | for v in vs: 121 | row_temp = deepcopy(row) 122 | row_temp['用户问题'] = question.replace(k, v) 123 | new_data.append(row_temp) 124 | new_data = pd.DataFrame(new_data) 125 | save_data = pd.concat([raw_data, new_data], ignore_index=True) 126 | print(save_data) 127 | save_data.to_excel(r'../data/raw_data/train_syn.xlsx', index=False) 128 | 129 | 130 | def augment_for_few_data(source_file, target_file, t=20, efficiency=0.8): 131 | """ 132 | 扩充样本量少于t的到t 133 | :return: 134 | """ 135 | from nlpcda import Similarword, RandomDeleteChar, CharPositionExchange 136 | 137 | def augment2t_nlpcda(column2id, threshold=t): 138 | """ 139 | 使用nlpcda包来扩充 140 | """ 141 | from utils import setup_seed 142 | setup_seed(1) 143 | a1 = CharPositionExchange(create_num=2, change_rate=0.1, seed=1) 144 | # a2 = RandomDeleteChar(create_num=2, change_rate=0.1, seed=1) 145 | a3 = Similarword(create_num=2, change_rate=0.1, seed=1) 146 | for k, v in tqdm(column2id.items()): 147 | if len(v) < threshold: 148 | p = int(math.ceil(threshold / len(v) * 1.0)) 149 | for row_id in v: 150 | aug_list = [] 151 | row = raw_data.loc[row_id] 152 | 
question = str(row['用户问题']).strip() 153 | while len(aug_list) < p: 154 | aug_list += a1.replace(question)[1:] 155 | # aug_list += a2.replace(question)[1:] 156 | aug_list += a3.replace(question)[1:] 157 | aug_list = list(set(aug_list)) 158 | for a_q in sorted(aug_list): 159 | copy_row = deepcopy(row) 160 | copy_row['用户问题'] = a_q 161 | target_data.loc[len(target_data)] = copy_row 162 | # 为并列句去除逗号增强 163 | if k == '并列句': 164 | for row_id in v: 165 | row = raw_data.loc[row_id] 166 | question = str(row['用户问题']).strip() 167 | if ',' in question: 168 | question = question.replace(',', '') 169 | copy_row = deepcopy(row) 170 | copy_row['用户问题'] = question 171 | target_data.loc[len(target_data)] = copy_row 172 | target_data['有效率'] = [efficiency] * len(target_data) 173 | print('开始扩充少量数据...目前阈值为{}'.format(t)) 174 | ans2id, prop2id, entity2id = {}, {}, {} 175 | raw_data = pd.read_excel(source_file) 176 | target_data = raw_data.drop(raw_data.index) 177 | for idx, row in raw_data.iterrows(): 178 | # 答案类型 179 | if row['答案类型'] not in ans2id: 180 | ans2id[row['答案类型']] = set() 181 | ans2id[row['答案类型']].add(idx) 182 | # 属性名 183 | prop_list = row['属性名'].split('|') 184 | for prop in prop_list: 185 | if prop not in prop2id: 186 | prop2id[prop] = set() 187 | prop2id[prop].add(idx) 188 | # 实体 189 | if row['实体'] not in entity2id: 190 | entity2id[row['实体']] = set() 191 | entity2id[row['实体']].add(idx) 192 | # 添加 193 | # augment_to_t(ans2id) 194 | # augment_to_t(prop2id) 195 | # augment_to_t(entity2id) 196 | augment2t_nlpcda(ans2id) 197 | augment2t_nlpcda(prop2id) 198 | augment2t_nlpcda(entity2id) 199 | target_data.to_excel(target_file, index=False) 200 | 201 | 202 | def augment_for_synonyms(source_file, target_file): 203 | """ 204 | 根据一定的规则,有选择的将同义词替换成原句子 205 | :param source_file: 206 | :param target_file: 207 | :return: 208 | """ 209 | print('开始进行同义词增强...') 210 | with open('../data/file/entity_count.json', 'r', encoding='utf-8') as f: 211 | entity_count = json.load(f) 212 | with open('../data/file/synonyms.json', 'r', encoding='utf-8') as f: 213 | synonym_dict = json.load(f) 214 | entity2synonym = synonym_dict['entity2synonym'] 215 | raw_data = pd.read_excel(source_file) 216 | target_data = raw_data.drop(raw_data.index) 217 | random.seed(1) 218 | for idx, row in tqdm(raw_data.iterrows(), total=len(raw_data)): 219 | question = row['用户问题'] 220 | entity = row['实体'] 221 | if entity in entity2synonym: 222 | synonym_list = sorted(entity2synonym[entity] + [entity], key=lambda item: len(item), reverse=True) 223 | contain_word = '' 224 | for s in synonym_list: 225 | if s in question: 226 | contain_word = s 227 | break 228 | if contain_word != '': 229 | synonym_list.remove(contain_word) 230 | rest_word = synonym_list 231 | # 根据entity的样本量来进行按概率采样,防止增强的样本量太大 232 | # 同义词增强,本质上还是需要侧重考虑少样本 233 | count = entity_count[entity] 234 | alpha = 1 235 | # 不需要修改rest_word 236 | if count <= 30: 237 | pass 238 | elif 30 < count <= 50: 239 | rest_word = random.sample(rest_word, int(math.ceil(len(rest_word)*0.5))) 240 | elif 50 < count <= 100: 241 | rest_word = random.sample(rest_word, int(math.ceil(len(rest_word)*0.2))) 242 | else: 243 | if random.random() < 0.5: 244 | rest_word = random.sample(rest_word, k=1) 245 | else: 246 | rest_word = [] 247 | for s in rest_word: 248 | copy_row = deepcopy(row) 249 | copy_row['用户问题'] = question.replace(contain_word, s) 250 | target_data.loc[len(target_data)] = copy_row 251 | target_data['有效率'] = [0.99] * len(target_data) 252 | target_data.to_excel(target_file, index=False) 253 | 254 | 255 | def 
augment_from_simbert(source_file, target_file, 256 | model_file='/data2/hezhenfeng/other_model_files/chinese_simbert_L-6_H-384_A-12', 257 | gpu_id=0, 258 | start_line=0, 259 | end_line=5000, 260 | efficiency=0.95): 261 | """ 262 | 利用Simbert生成相似句,取相似度大于95%的句子 263 | :param source_file: 264 | :param target_file: 265 | :param model_file: 266 | :param start_line: 起始行号 267 | :param end_line: 结束行号 268 | :param gpu_id: 269 | :param efficiency: 270 | :return: 271 | """ 272 | from nlpcda import Simbert 273 | print('加载simbert模型...') 274 | config = { 275 | 'model_path': model_file, 276 | 'CUDA_VISIBLE_DEVICES': '{}'.format(gpu_id), 277 | 'max_len': 40, 278 | 'seed': 1, 279 | 'device': 'cuda', 280 | 'threshold': efficiency 281 | } 282 | simbert = Simbert(config=config) 283 | raw_data = pd.read_excel(source_file) 284 | target_data = raw_data.drop(raw_data.index) 285 | for idx, row in tqdm(raw_data.iterrows(), total=len(raw_data)): 286 | if start_line <= idx < end_line: 287 | # 用pandas自带的str类的数据会无法复现结果 288 | synonyms = simbert.replace(sent=str(row['用户问题']).strip(), create_num=5) 289 | for synonym, similarity in synonyms: 290 | if similarity >= config['threshold']: 291 | copy_row = deepcopy(row) 292 | copy_row['用户问题'] = synonym 293 | target_data.loc[len(target_data)] = copy_row 294 | else: 295 | break 296 | elif idx >= end_line: 297 | break 298 | target_data['有效率'] = [efficiency] * len(target_data) 299 | target_data.to_excel(target_file, index=False) 300 | 301 | 302 | def function1(args): 303 | """ 304 | simbert多进程增强的子进程 305 | :param args: 306 | :return: 307 | """ 308 | idx, gpu_id, start, end = args['idx'], args['gpu_id'], args['start'], args['end'] 309 | augment_from_simbert(source_file='../data/raw_data/train_denoised.xlsx', 310 | target_file=f'../data/raw_data/train_augment_simbert_{idx}.xlsx', 311 | efficiency=0.95, 312 | start_line=start, 313 | end_line=end, 314 | gpu_id=gpu_id) 315 | 316 | 317 | def run_process(): 318 | """ 319 | simbert多进程增强的子进程 320 | :return: 321 | """ 322 | from multiprocessing import Pool 323 | args = [] 324 | # 初始设置为1 325 | cpu_worker_num = 1 326 | span = 5000//cpu_worker_num 327 | start = 0 328 | end = span 329 | for i in range(cpu_worker_num): 330 | args.append({'idx': i, 'gpu_id': 0, 'start': start, 'end': end}) 331 | start = end 332 | end += span 333 | if i == cpu_worker_num-2: 334 | end = 5000 335 | with Pool(cpu_worker_num) as p: 336 | p.map(function1, args) 337 | print('生成完毕,开始合并结果...') 338 | df = None 339 | for i in range(len(args)): 340 | temp_df = pd.read_excel(f'../data/raw_data/train_augment_simbert_{i}.xlsx') 341 | if i == 0: 342 | df = temp_df 343 | else: 344 | df = pd.concat([df, temp_df], ignore_index=True) 345 | df.to_excel('../data/raw_data/train_augment_simbert.xlsx', index=False) 346 | 347 | 348 | def multi_process_augment_from_simbert(): 349 | """ 350 | simbert用多进程来完成 351 | :return: 352 | """ 353 | import time 354 | start_time = time.time() 355 | run_process() 356 | print('生成时间为:', get_time_dif(start_time)) 357 | 358 | 359 | def augment_for_binary(source_file, target_file): 360 | """ 361 | 同义词增强以及随机去除标点符号 362 | :return: 363 | """ 364 | from utils import rm_symbol 365 | print('开始进行开通方式、条件数据增强...') 366 | with open('../data/file/synonyms.json', 'r', encoding='utf-8') as f: 367 | synonym_dict = json.load(f) 368 | synonym2entity = synonym_dict['synonym2entity'] 369 | for entity in set(synonym2entity.values()): 370 | synonym2entity[entity] = entity 371 | synonym_list = sorted(list(synonym2entity.keys()), key=lambda item: [len(item), item], reverse=True) 372 | 
raw_data = pd.read_excel(source_file) 373 | target_data = raw_data.drop(raw_data.index) 374 | random.seed(1) 375 | neg_word = ('取消掉', '不需要', '不用', '不想', '不要', '什么时候可以', '可以取消', '可以退订') 376 | for idx, row in tqdm(raw_data.iterrows(), total=len(raw_data)): 377 | copy_row = deepcopy(row) 378 | question_text = str(copy_row['用户问题']) 379 | p = random.random() 380 | if '开通' in row['属性名']: 381 | # 删除这些样本 382 | if question_text in ('20元20g还可以办理吗', ): 383 | continue 384 | # 随机删除逗号 385 | if p > 0.5 and ',' in row['用户问题']: 386 | copy_row['用户问题'] = rm_symbol(question_text) 387 | copy_row['有效率'] = 0.99 388 | target_data.loc[len(target_data)] = copy_row 389 | else: 390 | for synonym in synonym_list: 391 | if synonym in question_text: 392 | rw = random.choice(synonym_list) 393 | if rw != synonym: 394 | question_text.replace(synonym, rw) 395 | break 396 | copy_row['有效率'] = 0.99 397 | copy_row['用户问题'] = question_text 398 | target_data.loc[len(target_data)] = copy_row 399 | elif '取消' in row['属性名']: 400 | flag = False 401 | for n_w in neg_word: 402 | if n_w in question_text: 403 | flag = True 404 | break 405 | if flag: 406 | continue 407 | if '退订' in question_text or '取消' in question_text: 408 | replace_word = '办理' if p >= 0.5 else '开通' 409 | question_text = question_text.replace('退订', replace_word) 410 | question_text = question_text.replace('取消', replace_word) 411 | copy_row['用户问题'] = question_text 412 | copy_row['属性名'] = str(copy_row['属性名']).replace('取消', '开通') 413 | copy_row['有效率'] = 0.9 414 | target_data.loc[len(target_data)] = copy_row 415 | target_data.to_excel(target_file, index=False) 416 | 417 | 418 | def augment_for_ner(source_file, target_file): 419 | synonyms_dict = read_synonyms() 420 | df = pd.read_excel(source_file) 421 | df_syn1 = pd.DataFrame(columns = df.columns) 422 | df_syn2 = pd.DataFrame(columns = df.columns) 423 | import collections 424 | entity_dict = collections.defaultdict(dict) 425 | # 遍历得到对于每个实体,各个约束出现的次数 426 | for index, row in df.iterrows(): 427 | if row['约束算子'] == 'min' or row['约束算子'] == 'max': 428 | continue 429 | constraint_names = row['约束属性名'] 430 | constraint_values = row['约束属性值'] 431 | constraint_names_list = re.split(r'[|\|]', str(constraint_names)) 432 | constraint_values_list = re.split(r'[|\|]', str(constraint_values)) 433 | constraint_names_list = [name.strip() for name in constraint_names_list] 434 | constraint_values_list = [value.strip() for value in constraint_values_list] 435 | for i in range(len(constraint_values_list)): 436 | name = constraint_names_list[i] 437 | value = constraint_values_list[i] 438 | if name == '有效期': continue 439 | if name == 'nan': continue 440 | if value == '流量套餐': continue 441 | if name + '_' + value not in entity_dict[row['实体']]: 442 | entity_dict[row['实体']][name + '_' + value] = [] 443 | entity_dict[row['实体']][name + '_' + value].append(index) 444 | # 对约束次数少的实体进行同义词替换增强 445 | for entity, cons in entity_dict.items(): 446 | # 喜马拉雅的ner增强会因错别字无法标注 447 | if entity == '喜马拉雅流量包': continue 448 | if entity not in synonyms_dict: continue 449 | for con, con_list in cons.items(): 450 | if len(con_list) < 3: 451 | for index in con_list: 452 | question = df.iloc[index].at['用户问题'] 453 | syn_curr = None 454 | syn_curr_list = [] 455 | for syn in synonyms_dict[entity] + [entity]: 456 | if syn == '无': continue 457 | if syn in question: 458 | syn_curr_list.append(syn) 459 | syn_curr_list = sorted(syn_curr_list, key=lambda x:len(x)) 460 | syn_curr = syn_curr_list[-1] 461 | for syn in synonyms_dict[entity] + [entity]: 462 | if syn == '无': continue 463 | 
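# For every (constraint name, value) pair that an entity carries fewer than
# three times, the question is regenerated once per synonym: syn_curr, the
# longest synonym already present in the question, is swapped for each
# alternative from synonyms.txt and the copied row is collected in df_syn1.
# Compatibility note: DataFrame.append(), used below, was removed in pandas
# 2.0; newer pandas would need something like
# pd.concat([df_syn1, new_row.to_frame().T], ignore_index=True) instead.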
new_row = df.iloc[index].copy() 464 | new_row.at['用户问题'] = question.replace(syn_curr, syn) 465 | df_syn1 = df_syn1.append(new_row, ignore_index=True) 466 | for index, row in df.iterrows(): 467 | entity = row['实体'] 468 | question = row['用户问题'] 469 | if entity not in synonyms_dict: continue 470 | syn_curr = None 471 | syn_curr_list = [] 472 | for syn in synonyms_dict[entity] + [entity]: 473 | if syn == '无': continue 474 | if syn in question: 475 | syn_curr_list.append(syn) 476 | if len(syn_curr_list) == 0: continue 477 | syn_curr_list = sorted(syn_curr_list, key=lambda x:len(x)) 478 | syn_curr = syn_curr_list[-1] 479 | if any(char.isdigit() for char in syn_curr) or '年包' in syn_curr: 480 | df_syn2 = df_syn2.append(row.copy(), ignore_index=True) 481 | df_syn2 = df_syn2.append(row.copy(), ignore_index=True) 482 | df = df.append(df_syn1, ignore_index=True) 483 | df = df.append(df_syn2, ignore_index=True) 484 | df.to_excel(target_file, index=False) 485 | 486 | 487 | if __name__ == '__main__': 488 | # denoising(source_file='../data/raw_data/test.xlsx', target_file='../data/raw_data/test_denoised.xlsx') 489 | # synonyms_dict = read_synonyms() 490 | # create_argument_data(synonyms_dict) 491 | # setup_seed(1) 492 | # augment_for_few_data(source_file='../data/raw_data/train_denoised.xlsx', 493 | # target_file='../data/raw_data/train_augment_few_nlpcda.xlsx', 494 | # efficiency=0.8) 495 | # augment_for_synonyms(source_file='../data/raw_data/train_denoised.xlsx', 496 | # target_file='../data/raw_data/train_augment_synonyms_test2.xlsx') 497 | # multi_process_augment_from_simbert() 498 | # augment_from_simbert(source_file='../data/raw_data/train_denoised.xlsx', 499 | # target_file=f'../data/raw_data/train_augment_simbert.xlsx', 500 | # efficiency=0.95, 501 | # start_line=0, 502 | # end_line=5000, 503 | # gpu_id=8) 504 | augment_for_binary(source_file='../data/raw_data/train_denoised.xlsx', 505 | target_file='../data/raw_data/train_augment_binary.xlsx') 506 | -------------------------------------------------------------------------------- /code/doc2unix.py: -------------------------------------------------------------------------------- 1 | if __name__ == '__main__': 2 | with open('./run.sh', 'r') as f: 3 | text = f.read() 4 | text.replace('\r', '') 5 | 6 | with open('./run.sh', 'w') as f: 7 | f.write(text) 8 | -------------------------------------------------------------------------------- /code/docker_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from data_argument import * 5 | from statistic import * 6 | from utils import * 7 | 8 | def prepare_dir(): 9 | ''' 10 | 用于生成存放数据的文件夹 11 | @return: 12 | ''' 13 | if not os.path.exists('/data'): 14 | os.mkdir('/data') 15 | if not os.path.exists('/data/raw_data'): 16 | os.mkdir('/data/raw_data') 17 | if not os.path.exists('/data/process_data'): 18 | os.mkdir('/data/process_data') 19 | if not os.path.exists('/data/file'): 20 | os.mkdir('/data/file') 21 | if not os.path.exists('/data/dataset'): 22 | os.mkdir('/data/dataset') 23 | if not os.path.exists('/data/Log'): 24 | os.mkdir('/data/Log') 25 | if not os.path.exists('/data/trained_model'): 26 | os.mkdir('/data/trained_model') 27 | if not os.path.exists('/data/results'): 28 | os.mkdir('/data/results') 29 | 30 | 31 | def prepare_copy_file(): 32 | ''' 33 | 为统一路径,将天池数据/tcdata中的内容复制到/data/raw_data中 34 | @return: 35 | ''' 36 | if os.path.exists('/tcdata'): 37 | for file_name in os.listdir('/tcdata'): 38 | 
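# Copy every file from the Tianchi mount /tcdata into /data/raw_data so that
# all later steps read from a single path prefix. (On Python 3.8+ a rough
# equivalent would be shutil.copytree('/tcdata', '/data/raw_data',
# dirs_exist_ok=True); the loop below copies file by file instead.)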
shutil.copy(os.path.join('/tcdata', file_name), '/data/raw_data') 39 | 40 | # # 查看是否复制成功 41 | # raw_data = os.listdir('/data/raw_data') 42 | # for file_name in os.listdir('/tcdata'): 43 | # if file_name not in raw_data: 44 | # print('复制失败: ', file_name) 45 | 46 | 47 | def prepare_data(): 48 | ''' 49 | 一系列文件生成工作 50 | @return: 51 | ''' 52 | # 生成denoiseed.xlsx 53 | denoising(source_file='../data/raw_data/train.xlsx', 54 | target_file='../data/raw_data/train_denoised.xlsx') 55 | denoising(source_file='../tcdata/test2.xlsx', 56 | target_file='../data/raw_data/test_denoised.xlsx') 57 | # 生成统计文件:实体个数、同义词个数 58 | statistic_entity() 59 | statistic_synonyms() 60 | # entity_map() 61 | # 生成 数据增强 文件 62 | augment_for_few_data(source_file='../data/raw_data/train_denoised.xlsx', 63 | target_file='../data/raw_data/train_augment_few_nlpcda.xlsx') 64 | multi_process_augment_from_simbert() 65 | augment_for_synonyms(source_file='../data/raw_data/train_denoised.xlsx', 66 | target_file='../data/raw_data/train_augment_synonyms.xlsx') 67 | # 生成标注好的训练数据 68 | make_dataset(['../data/raw_data/train_denoised.xlsx'], 69 | target_file='../data/dataset/cls_labeled.txt', 70 | label_file='../data/dataset/cls_label2id.json', 71 | train=True) 72 | # 将增强文件合并 73 | make_dataset(['../data/raw_data/train_augment_few_nlpcda.xlsx', 74 | '../data/raw_data/train_augment_simbert.xlsx', 75 | '../data/raw_data/train_augment_synonyms.xlsx'], 76 | target_file='../data/dataset/augment3.txt', # 修改此处应该修改run.sh中的文件 77 | label_file=None, 78 | train=True) 79 | make_dataset(['../data/raw_data/test_denoised.xlsx'], 80 | target_file='../data/dataset/cls_unlabeled.txt', 81 | label_file=None, 82 | train=False) 83 | # ner数据增强 84 | augment_for_ner(source_file='../data/raw_data/train_denoised.xlsx', 85 | target_file='../data/raw_data/train_denoised_ner.xlsx') 86 | # 制作二分类训练数据 87 | augment_for_binary(source_file='../data/raw_data/train_denoised.xlsx', 88 | target_file='../data/raw_data/train_augment_binary.xlsx') 89 | make_dataset_for_binary(['../data/raw_data/train_denoised.xlsx'], 90 | target_file='../data/dataset/binary_labeled.txt') 91 | make_dataset_for_binary(['../data/raw_data/train_augment_binary.xlsx'], 92 | target_file='../data/dataset/binary_augment3.txt') 93 | # 制作rdf文件 94 | parse_triples_file('../data/raw_data/triples.txt') 95 | 96 | 97 | if __name__ == '__main__': 98 | print('docker运行开始,时间为', get_time_str()) 99 | print('---------------------开始process-----------------------') 100 | setup_seed(1) # todo 101 | prepare_dir() 102 | prepare_copy_file() 103 | prepare_data() 104 | print('---------------------process结束-----------------------') 105 | end = time.time() 106 | 107 | -------------------------------------------------------------------------------- /code/module.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: hezf 3 | @Time: 2021/6/3 19:39 4 | @desc: 5 | """ 6 | from sklearn.metrics import f1_score, jaccard_score, confusion_matrix 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch 10 | import numpy as np 11 | from typing import List 12 | 13 | 14 | class MutiTaskLoss(nn.Module): 15 | def __init__(self, use_efficiency=False): 16 | """ 17 | :param use_efficiency: 样本有效率标志,需要搭配起forward中的efficiency参数 18 | """ 19 | super(MutiTaskLoss, self).__init__() 20 | self.use_efficiency = use_efficiency 21 | if use_efficiency: 22 | self.ans_loss = nn.CrossEntropyLoss(reduction='none') 23 | else: 24 | self.ans_loss = nn.CrossEntropyLoss() 25 | self.prop_loss = 
FocalLoss(alpha=[0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996], 26 | multi_label=True) 27 | # self.prop_loss = nn.BCELoss() 28 | if use_efficiency: 29 | self.entity_loss = nn.CrossEntropyLoss(reduction='none') 30 | else: 31 | self.entity_loss = nn.CrossEntropyLoss() 32 | # 可学习的权重 33 | self.loss_weight = nn.Parameter(torch.ones(3), requires_grad=True) 34 | 35 | def forward(self, ans_true, ans_pred, prop_true, prop_pred, entity_true, entity_pred, efficiency=None): 36 | # 使用样本有效率 37 | if self.use_efficiency: 38 | assert efficiency is not None, 'efficiency is None' 39 | batch_size = ans_true.shape[0] 40 | loss1 = self.ans_loss(ans_pred, ans_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size 41 | loss2 = self.prop_loss(prop_pred, prop_true.float(), efficiency) 42 | loss3 = self.entity_loss(entity_pred, entity_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size 43 | else: 44 | loss1 = self.ans_loss(ans_pred, ans_true) 45 | loss2 = self.prop_loss(prop_pred, prop_true.float()) 46 | loss3 = self.entity_loss(entity_pred, entity_true) 47 | weight = self.loss_weight.softmax(dim=0) 48 | loss = weight[0] * loss1 + weight[1]*loss2 + weight[2]*loss3 49 | return loss 50 | 51 | 52 | class MutiTaskLossV1(nn.Module): 53 | def __init__(self, use_efficiency=False): 54 | """ 55 | :param use_efficiency: 样本有效率标志,需要搭配起forward中的efficiency参数 56 | """ 57 | super(MutiTaskLossV1, self).__init__() 58 | self.use_efficiency = use_efficiency 59 | if use_efficiency: 60 | self.ans_loss = nn.CrossEntropyLoss(reduction='none') 61 | self.entity_loss = nn.CrossEntropyLoss(reduction='none') 62 | self.method_loss = nn.CrossEntropyLoss(reduction='none') 63 | self.condition_loss = nn.CrossEntropyLoss(reduction='none') 64 | else: 65 | self.ans_loss = nn.CrossEntropyLoss() 66 | self.entity_loss = nn.CrossEntropyLoss() 67 | self.method_loss = nn.CrossEntropyLoss() 68 | self.condition_loss = nn.CrossEntropyLoss() 69 | self.prop_loss = FocalLoss(alpha=[0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996], 70 | multi_label=True) 71 | # self.prop_loss = nn.BCELoss() 72 | # 可学习的权重 73 | self.loss_weight = nn.Parameter(torch.ones(4), requires_grad=True) 74 | 75 | def forward(self, ans_true, ans_pred, prop_true, prop_pred, entity_true, entity_pred, method_true, method_pred, condition_true, condition_pred, efficiency=None): 76 | # 使用样本有效率 77 | if self.use_efficiency: 78 | assert efficiency is not None, 'efficiency is None' 79 | batch_size = ans_true.shape[0] 80 | loss1 = self.ans_loss(ans_pred, ans_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size 81 | loss2 = self.prop_loss(prop_pred, prop_true.float(), efficiency) 82 | loss3 = self.entity_loss(entity_pred, entity_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size 83 | loss4 = self.method_loss(method_pred, method_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size 84 | loss5 = self.condition_loss(condition_pred, condition_true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size 85 | else: 86 | loss1 = self.ans_loss(ans_pred, ans_true) 87 | loss2 = self.prop_loss(prop_pred, prop_true.float()) 88 | loss3 = self.entity_loss(entity_pred, entity_true) 89 | loss4 = self.method_loss(method_pred, method_true) 90 | loss5 = self.condition_loss(condition_pred, condition_true) 91 | weight = self.loss_weight.softmax(dim=0) 92 | 
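# The four entries of self.loss_weight are softmax-normalised so the task
# weights stay positive and sum to 1. Answer-type, property and entity losses
# each get their own weight, while the method and condition losses share the
# last weight as an average. When use_efficiency is set, each per-sample
# cross-entropy above is additionally scaled by that sample's 有效率 before
# being averaged over the batch.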
loss = weight[0] * loss1 + weight[1]*loss2 + weight[2]*loss3 + weight[3]*(loss4 + loss5)/2 93 | return loss 94 | 95 | 96 | class CrossEntropyWithEfficiency(nn.Module): 97 | def __init__(self): 98 | super(CrossEntropyWithEfficiency, self).__init__() 99 | self.loss = nn.CrossEntropyLoss(reduction='none') 100 | 101 | def forward(self, true, pred, efficiency): 102 | batch_size = true.shape[0] 103 | loss = self.loss(pred, true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0) / batch_size 104 | return loss 105 | 106 | # # rdrop 加入有效率前的备份 107 | # class RDropLoss(nn.Module): 108 | # def __init__(self, alpha=4): 109 | # super(RDropLoss, self).__init__() 110 | # self.ans_loss = nn.CrossEntropyLoss() 111 | # self.prop_loss = FocalLoss(alpha=[0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996], 112 | # multi_label=True) 113 | # self.entity_loss = nn.CrossEntropyLoss() 114 | # self.kl = nn.KLDivLoss(reduction='batchmean') 115 | # self.alpha = alpha 116 | # # 可学习的权重 117 | # self.loss_weight = nn.Parameter(torch.ones(3), requires_grad=True) 118 | # 119 | # def r_drop(self, loss_func, pred: List, true, multi_label=True): 120 | # loss_0 = loss_func(pred[0], true) 121 | # loss_1 = loss_func(pred[1], true) 122 | # if multi_label: 123 | # kl_loss = (F.kl_div(pred[0].log(), pred[1], reduction='batchmean') + F.kl_div(pred[1].log(), pred[0], reduction='batchmean')) / 2 124 | # else: 125 | # kl_loss = (F.kl_div(F.log_softmax(pred[0], -1), F.softmax(pred[1], -1), reduction='batchmean') + F.kl_div(F.log_softmax(pred[1], -1), F.softmax(pred[0], -1), reduction='batchmean')) / 2 126 | # return loss_0 + loss_1 + self.alpha * kl_loss 127 | # 128 | # def forward(self, ans_true, ans_pred: List, prop_true, prop_pred: List, entity_true, entity_pred: List): 129 | # loss1 = self.r_drop(self.ans_loss, ans_pred, ans_true, multi_label=False) 130 | # # loss1 = (self.ans_loss(ans_pred[0], ans_true) + self.ans_loss(ans_pred[1], ans_true))/2 # prop_only 131 | # loss2 = self.r_drop(self.prop_loss, prop_pred, prop_true.float(), multi_label=True) 132 | # loss3 = self.r_drop(self.entity_loss, entity_pred, entity_true, multi_label=False) 133 | # # loss3 = (self.entity_loss(entity_pred[0], entity_true) + self.entity_loss(entity_pred[1], entity_true))/2 # prop_only 134 | # weight = self.loss_weight.softmax(dim=0) 135 | # loss = weight[0] * loss1 + weight[1]*loss2 + weight[2]*loss3 136 | # return loss 137 | 138 | 139 | class RDropLoss(nn.Module): 140 | def __init__(self, alpha=4): 141 | super(RDropLoss, self).__init__() 142 | self.ans_loss = nn.CrossEntropyLoss(reduction='none') 143 | self.prop_loss = FocalLoss(alpha=[0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996], 144 | multi_label=True) 145 | self.entity_loss = nn.CrossEntropyLoss(reduction='none') 146 | self.kl = nn.KLDivLoss(reduction='batchmean') 147 | self.alpha = alpha 148 | # 可学习的权重 149 | self.loss_weight = nn.Parameter(torch.ones(3), requires_grad=True) 150 | 151 | def r_drop(self, loss_func, pred: List, true, efficiency, multi_label=True): 152 | if efficiency is None: 153 | efficiency = torch.tensor([1.0] * pred[0].shape[0]) 154 | # 多标签时,是用focalloss 155 | if multi_label: 156 | loss_0 = loss_func(pred[0], true, efficiency) 157 | loss_1 = loss_func(pred[1], true, efficiency) 158 | else: 159 | batch_size = pred[0].shape[0] 160 | loss_0 = loss_func(pred[0], 
true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size 161 | loss_1 = loss_func(pred[1], true).unsqueeze(0).mm(efficiency.unsqueeze(1)).squeeze(0)/batch_size 162 | if multi_label: 163 | kl_loss = (F.kl_div(pred[0].log(), pred[1], reduction='batchmean') + F.kl_div(pred[1].log(), pred[0], reduction='batchmean')) / 2 164 | else: 165 | kl_loss = (F.kl_div(F.log_softmax(pred[0], -1), F.softmax(pred[1], -1), reduction='batchmean') + F.kl_div(F.log_softmax(pred[1], -1), F.softmax(pred[0], -1), reduction='batchmean')) / 2 166 | return loss_0 + loss_1 + self.alpha * kl_loss 167 | 168 | def forward(self, ans_true, ans_pred: List, prop_true, prop_pred: List, entity_true, entity_pred: List, efficiency): 169 | loss1 = self.r_drop(self.ans_loss, ans_pred, ans_true, efficiency, multi_label=False) 170 | loss2 = self.r_drop(self.prop_loss, prop_pred, prop_true.float(), efficiency, multi_label=True) 171 | loss3 = self.r_drop(self.entity_loss, entity_pred, entity_true, efficiency, multi_label=False) 172 | weight = self.loss_weight.softmax(dim=0) 173 | loss = weight[0] * loss1 + weight[1]*loss2 + weight[2]*loss3 174 | return loss 175 | 176 | 177 | class MutiTaskLossFocal(nn.Module): 178 | def __init__(self): 179 | super(MutiTaskLossFocal, self).__init__() 180 | self.ans_loss = FocalLoss() 181 | self.prop_loss = FocalLoss(alpha=[0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996], 182 | multi_label=True) 183 | self.entity_loss = FocalLoss() 184 | # 可学习的权重 185 | self.loss_weight = nn.Parameter(torch.ones(3), requires_grad=True) 186 | 187 | def forward(self, ans_true, ans_pred, prop_true, prop_pred, entity_true, entity_pred): 188 | loss1 = self.ans_loss(ans_pred, ans_true) 189 | loss2 = self.prop_loss(prop_pred, prop_true.float()) 190 | loss3 = self.entity_loss(entity_pred, entity_true) 191 | weight = self.loss_weight.softmax(dim=0) 192 | loss = weight[0] * loss1 + weight[1]*loss2 + weight[2]*loss3 193 | return loss 194 | 195 | 196 | class Metric(object): 197 | def __init__(self): 198 | super(Metric, self).__init__() 199 | 200 | def calculate(self, ans_pred, ans_true, prop_pred, prop_true, entity_pred, entity_true): 201 | ans_f1 = f1_score(ans_true, ans_pred, average='micro') 202 | # ans_matrix = confusion_matrix(ans_true, ans_pred) 203 | prop_j = jaccard_score(prop_true, prop_pred, average='micro') 204 | entity_f1 = f1_score(entity_true, entity_pred, average='micro') 205 | return ans_f1, prop_j, entity_f1 206 | 207 | 208 | class Attention(nn.Module): 209 | def __init__(self, hidden_size): 210 | super(Attention, self).__init__() 211 | self.linear = nn.Linear(hidden_size, 1) 212 | 213 | def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor): 214 | x = self.linear(hidden_states).squeeze(-1) 215 | x = x.masked_fill(mask, -np.inf) 216 | attention_value = x.softmax(dim=-1).unsqueeze(1) 217 | x = torch.bmm(attention_value, hidden_states).squeeze(1) 218 | return x 219 | 220 | 221 | class FocalLoss(nn.Module): 222 | def __init__(self, gamma=2, alpha=1.0, size_average=True, multi_label=False): 223 | super(FocalLoss, self).__init__() 224 | self.gamma = gamma 225 | self.alpha = alpha 226 | if isinstance(self.alpha, list): 227 | self.alpha = torch.tensor(self.alpha) 228 | self.size_average = size_average 229 | self.multi_label = multi_label 230 | 231 | def forward(self, logits, labels, efficiency=None): 232 | """ 233 | logits: batch_size * n_class 234 | labels: batch_size 235 | """ 236 | batch_size, n_class = 
logits.shape[0], logits.shape[1] 237 | # 多分类分类 238 | if not self.multi_label: 239 | one_hots = torch.zeros([batch_size, n_class]).to(logits.device).scatter_(-1, labels.unsqueeze(-1), 1) 240 | p = torch.nn.functional.softmax(logits, dim=-1) 241 | log_p = torch.log(p) 242 | loss = - one_hots * (self.alpha * ((1 - p) ** self.gamma) * log_p) 243 | # 多标签分类 244 | else: 245 | p = logits 246 | pt = (labels - (1 - p)) * (2*labels-1) 247 | if isinstance(self.alpha, float): 248 | alpha_t = (labels - (1 - self.alpha)) * (2*labels-1) 249 | else: 250 | alpha_t = (labels - (1 - self.alpha.to(logits.device))) * (2*labels-1) 251 | loss = - alpha_t * ((1 - pt)**self.gamma) * torch.log(pt) 252 | # 加入有效率的计算 253 | if efficiency is not None: 254 | loss = torch.diag_embed(efficiency).mm(loss) 255 | if self.size_average: 256 | return loss.sum()/batch_size 257 | else: 258 | return loss.sum() 259 | 260 | 261 | class FGM(object): 262 | """ 263 | 对抗攻击 264 | """ 265 | def __init__(self, model: nn.Module, eps=1.): 266 | self.model = ( 267 | model.module if hasattr(model, "module") else model 268 | ) 269 | self.eps = eps 270 | self.backup = {} 271 | 272 | # only attack word embedding 273 | def attack(self, emb_name='word_embeddings'): 274 | for name, param in self.model.named_parameters(): 275 | if param.requires_grad and emb_name in name: 276 | self.backup[name] = param.data.clone() 277 | norm = torch.norm(param.grad) 278 | if norm and not torch.isnan(norm): 279 | r_at = self.eps * param.grad / norm 280 | param.data.add_(r_at) 281 | 282 | def restore(self, emb_name='word_embeddings'): 283 | for name, para in self.model.named_parameters(): 284 | if para.requires_grad and emb_name in name: 285 | assert name in self.backup 286 | para.data = self.backup[name] 287 | 288 | self.backup = {} 289 | -------------------------------------------------------------------------------- /code/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | mode='multi' 3 | debug_status='no' 4 | 5 | if test $debug_status = 'yes' 6 | then 7 | batch_size1=16 8 | batch_size2=16 9 | batch_size3=16 10 | ptm_epoch=1 11 | ptm_epoch_binary=1 12 | train_epoch=1 13 | train_epoch_binary=1 14 | pt_file=augment3.txt 15 | elif test $debug_status = 'no' 16 | then 17 | batch_size1=350 18 | batch_size2=256 19 | batch_size3=190 20 | ptm_epoch=14 21 | ptm_epoch_binary=5 22 | train_epoch=5 23 | train_epoch_binary=5 24 | pt_file=augment3.txt 25 | fi 26 | # TODO 注意回车标志只能是LF 27 | # 单模型 28 | if test $mode = 'single' 29 | then 30 | python docker_process.py 31 | # BERT等模型的batch_size=350,XLNet的batch_size=256 32 | # GPT2batch_size=180暂定,学习率1e-4 33 | # bert+rdrop的batch_size=180 34 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm roberta -d 0 -m roberta -b $batch_size1 -l 0.0002 35 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm roberta -d 0 -m roberta_model -b $batch_size1 -l 0.00005 -tm pretrained_roberta.pth 36 | python train_evaluate_constraint.py -m ner_model.pth -d 0 -tv bieo -k 1 37 | # 修改模型时,这里模型名也要修改 38 | python inference.py -cm roberta_model.pth -nm ner_model.pth -tv bieo -k 1 39 | # 多模型 40 | elif test $mode = 'multi' 41 | then 42 | python docker_process.py 43 | # bert 44 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm bert -d 0 -m bert -b $batch_size1 -l 0.0002 45 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm bert -d 0 -m bert_model -b $batch_size1 -l 0.00005 -tm pretrained_bert.pth 46 | # roberta 47 | python train.py -tt only_train -e 
$ptm_epoch -tf $pt_file -ptm roberta -d 0 -m roberta -b $batch_size1 -l 0.0002 48 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm roberta -d 0 -m roberta_model -b $batch_size1 -l 0.00005 -tm pretrained_roberta.pth 49 | # electra 50 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm electra -d 0 -m electra -b $batch_size1 -l 0.0002 51 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm electra -d 0 -m electra_model -b $batch_size1 -l 0.00005 -tm pretrained_electra.pth 52 | # xlnet 53 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm xlnet -d 0 -m xlnet -b $batch_size2 -l 0.0002 54 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm xlnet -d 0 -m xlnet_model -b $batch_size2 -l 0.00005 -tm pretrained_xlnet.pth 55 | # gpt2 56 | #python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm gpt2 -d 0 -m gpt2 -b $batch_size3 -l 0.0001 57 | #python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm gpt2 -d 0 -m gpt2_model -b $batch_size3 -l 0.00005 -tm pretrained_gpt2.pth 58 | # macbert 59 | python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm macbert -d 0 -m macbert -b $batch_size1 -l 0.0002 60 | python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm macbert -d 0 -m macbert_model -b $batch_size1 -l 0.00005 -tm pretrained_macbert.pth 61 | # albert 62 | #python train.py -tt only_train -e $ptm_epoch -tf $pt_file -ptm albert -d 0 -m albert -b $batch_size1 -l 0.001 63 | #python train.py -tt kfold -e $train_epoch -tf cls_labeled.txt -ptm albert -d 0 -m albert_model -b $batch_size1 -l 0.00025 -tm pretrained_albert.pth 64 | # 二分类模型 65 | python train.py -tt only_binary -e $ptm_epoch_binary -tf binary_augment3.txt -ptm bert -d 0 -m bert -b 32 -l 0.00005 66 | python train.py -tt binary -e $train_epoch_binary -tf binary_labeled.txt -ptm bert -d 0 -m bert_binary -b 32 -l 0.00001 -tm pretrained_binary_bert.pth 67 | # ner 68 | python train_evaluate_constraint.py -m ner_model.pth -d 0 -tv bieo -k 5 69 | # inference 70 | python inference.py -cm bert_model.pth -nm ner_model.pth -tv bieo -e 1 -de 0 -b 4 -xm xlnet_model.pth -em electra_model.pth -rm roberta_model.pth -mm macbert_model.pth -k 5 -wb 71 | elif test $mode = 'orig' 72 | then 73 | python docker_process.py 74 | # BERT等模型的batch_size=350,XLNet的batch_size=256 75 | python train.py -tt kfold -e 50 -tf cls_labeled.txt -ptm bert -d 0 -m bert_model -b 350 -l 0.0002 76 | # eh为标注增强,可设为0 77 | python train_evaluate_constraint.py -m ner_model.pth -d 0 -tv bieo -k 5 78 | # 修改模型时,这里模型名也要修改, eh为标注增强,可设为0 79 | python inference.py -cm bert_model.pth -nm ner_model.pth -tv bieo -dp ../data/dataset/cls_unlabeled.txt -k 5 80 | else 81 | echo "其他" 82 | fi 83 | -------------------------------------------------------------------------------- /code/statistic.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: hezf 3 | @Time: 2021/6/19 17:20 4 | @desc: 统计数据模块 5 | """ 6 | import os 7 | import pandas as pd 8 | import re 9 | import json 10 | 11 | 12 | def statistic_entity(labeled_file: str = '../data/raw_data/train.xlsx'): 13 | """ 14 | 统计”实体“分布 15 | :param labeled_file: 16 | :return: 17 | """ 18 | labeled_data = pd.read_excel(labeled_file) 19 | entity_list = list(labeled_data.loc[:, '实体']) 20 | entity_count = {} 21 | for entity in entity_list: 22 | if entity not in entity_count: 23 | entity_count[entity] = 0 24 | entity_count[entity] += 1 25 | sorted_dict = sorted(entity_count.items(), key=lambda item: (item[1], 
item[0]), reverse=True) 26 | # for entity, count in sorted_dict: 27 | # print('实体:{}, 数量:{}'.format(entity, count)) 28 | with open('../data/file/entity_count.json', 'w', encoding='utf-8') as f: 29 | json.dump(entity_count, f, ensure_ascii=False, indent=2) 30 | 31 | 32 | def statistic_property(labeled_file: str = '../data/raw_data/train.xlsx'): 33 | """ 34 | 统计”属性名“分布 35 | :param labeled_file: 36 | :return: 37 | """ 38 | labeled_data = pd.read_excel(labeled_file) 39 | property_list = list(labeled_data.loc[:, '属性名']) 40 | property_count = {} 41 | for prop in property_list: 42 | prop = prop.split('|') 43 | for p in prop: 44 | if p not in property_count: 45 | property_count[p] = 0 46 | property_count[p] += 1 47 | sorted_dict = sorted(property_count.items(), key=lambda item: (item[1], item[0]), reverse=True) 48 | for prop, count in sorted_dict: 49 | print('属性名:{}, 数量:{}'.format(prop, count)) 50 | with open('../data/file/prop_count.json', 'w', encoding='utf-8') as f: 51 | json.dump(property_count, f, ensure_ascii=False, indent=2) 52 | 53 | 54 | def statistic_frequent_char(label_file: str = '../data/dataset/cls_labeled.txt'): 55 | """ 56 | 统计频繁使用的字符,据此推断停用词 57 | :param label_file: 58 | :return: 59 | """ 60 | char_count = {} 61 | less_5 = set() 62 | with open(label_file, 'r', encoding='utf-8') as f: 63 | for line in f: 64 | line = line.strip() 65 | question = line.split('\t')[0] 66 | for c in question: 67 | if c not in char_count: 68 | char_count[c] = 0 69 | char_count[c] += 1 70 | sorted_dict = sorted(char_count.items(), key=lambda item: (item[1], item[0]), reverse=True) 71 | for c, count in sorted_dict: 72 | if count < 10: 73 | less_5.add(c) 74 | print('字符:{}, 数量:{}'.format(c, count)) 75 | print('少于5个字符的有:') 76 | for c in less_5: 77 | print(c) 78 | 79 | 80 | def get_all_service(data): 81 | """ 82 | 得到子业务名称 83 | @param data: excel读取的dataframe数据 84 | @return: 包含所有子业务名称的list 85 | """ 86 | service_name = set() 87 | for index, row in data.iterrows(): 88 | constraint_names = row['约束属性名'] 89 | constraint_values = row['约束属性值'] 90 | constraint_names_list = re.split(r'[|\|]', str(constraint_names)) 91 | constraint_values_list = re.split(r'[|\|]', str(constraint_values)) 92 | constraint_names_list = [name.strip() for name in constraint_names_list] 93 | constraint_values_list = [value.strip() for value in constraint_values_list] 94 | for i in range(len(constraint_names_list)): 95 | if constraint_names_list[i] == '子业务': 96 | service_name.add(constraint_values_list[i]) 97 | return list(service_name) 98 | 99 | 100 | def statistic_duplicate_constraint(labeled_file: str): 101 | """ 102 | 统计约束属性值的子业务中,值相同的有哪些 // 结果 10条,全是'合约版' 103 | @param labeled_file: 104 | @return: 105 | """ 106 | duplicate = 0 107 | 108 | labeled_data = pd.read_excel(labeled_file) 109 | con_names = labeled_data['约束属性名'].tolist() 110 | con_values = labeled_data['约束属性值'].tolist() 111 | con_names = [re.split(r'[|\|]', str(item)) for item in con_names] 112 | con_values = [re.split(r'[|\|]', str(item)) for item in con_values] 113 | for index in range(len(con_names)): 114 | services = [] 115 | con_name = con_names[index] 116 | con_value = con_values[index] 117 | for name_id in range(len(con_name)): 118 | if con_name[name_id] == '子业务': 119 | if con_value[name_id] in services: 120 | print(services) 121 | duplicate += 1 122 | services.append(con_value[name_id]) 123 | print(duplicate) 124 | 125 | 126 | def statistic_synonyms(): 127 | """ 128 | 统计包含的同义词,英文全小写。并将清洗后的同义词写入json文件中 129 | :return: 130 | """ 131 | def in_ner_set(word): 132 | for ner in 
ner_set: 133 | if ner in word: 134 | return True 135 | return False 136 | ner_set = {'咪咕直播流量包', '上网版', '成员', '咪咕', '热剧vip', 'plus版', 'pptv', '优酷', '百度', '腾讯视频', '2020版', 137 | '大陆及港澳台版', '芒果', '乐享版', '小额版', '年包', '体验版-12个月', '电影vip', '合约版', '优酷会员', '王者荣耀', 138 | '长期版', '宝藏版', '免费版', '流量套餐', '乐视会员', '2019版', '基础版', '月包', '全球通版', '畅享包', '腾讯', 139 | '爱奇艺会员', '喜马拉雅', '普通版', '流量包', '芒果tv', '百度包', '24月方案', '2018版', '个人版', '半年包', 140 | '咪咕流量包', '爱奇艺', '阶梯版', '12月方案', '网易', '家庭版', '198', '5', '12', '160', '20', '220', '398', '15', 141 | '200', '24', '30', '3', '288', '18', '238', '120', 142 | '128', '60', '9', '10', '300', '188', '59', '68', '8', '38', '2', '80', '680', '70', '158', '1', 143 | '380', '298', '11', '65', '40', '19', '23', '29', '6', '99', '500', '22', '49', '100', '40', '700', '110', 144 | '20', '100', '500', '1', '30', '300'} 145 | # 有实体冲突的同义词:(语音留言、家庭亲情号、手机视频流量包) 146 | conflict = ('语音留言', '家庭亲情号', '手机视频流量包', '流量包') 147 | synonym_list = [] 148 | synonym2entity = dict() 149 | for_cls = True 150 | with open('../data/raw_data/synonyms.txt', 'r', encoding='utf-8') as f: 151 | for line in f: 152 | blocks = line.strip().lower().split(' ') 153 | entity, temp_synonym_list = blocks[0], blocks[1].split('|') 154 | if for_cls or not in_ner_set(entity): 155 | for s in temp_synonym_list: 156 | s = s.strip() 157 | if s != '无' and s != '' and s not in conflict: 158 | # 为分类模型考虑的同义词增强 159 | if for_cls or not in_ner_set(s): 160 | synonym2entity[s] = entity 161 | synonym_list.append(s) 162 | synonym_list.sort(key=lambda item: len(item), reverse=True) 163 | if not for_cls: 164 | remove_synonym = list() 165 | for i, synonym in enumerate(synonym_list): 166 | for j in range(i): 167 | if synonym in synonym_list[j] and synonym2entity[synonym] != synonym2entity[synonym_list[j]]: 168 | # 冲突则移除该词,降低误差 169 | print('冲突:', synonym, synonym_list[j]) 170 | remove_synonym += [synonym, synonym_list[j]] 171 | # print('需要移除的同义词有:', set(remove_synonym)) 172 | for r_s in set(remove_synonym): 173 | synonym2entity.pop(r_s) 174 | entity2synonym = {} 175 | for s, e in synonym2entity.items(): 176 | if e not in entity2synonym: 177 | entity2synonym[e] = [] 178 | entity2synonym[e].append(s) 179 | with open('../data/file/synonyms{}.json'.format('' if for_cls else '_ner'), 'w', encoding='utf-8') as f: 180 | json.dump({'entity2synonym': entity2synonym, 'synonym2entity': synonym2entity}, 181 | f, 182 | ensure_ascii=False, 183 | indent=2) 184 | print('同义词写入成功...') 185 | 186 | 187 | def statistic_prop_labels(): 188 | """ 189 | 统计属性名的占样本的比例 190 | [0.241, 0.32, 0.0542, 0.0564, 0.0384, 0.085, 0.0382, 0.129, 0.1276, 0.006, 0.014, 0.006, 0.0014, 0.007, 0.009, 0.0032, 0.0014, 0.001, 0.0004] 191 | [0.759, 0.68, 0.9458, 0.9436, 0.9616, 0.915, 0.9618, 0.871, 0.8724, 0.994, 0.986, 0.994, 0.9986, 0.993, 0.991, 0.9968, 0.9986, 0.999, 0.9996] 192 | 193 | [0.75, 0.91, 0.87, 0.67, 0.99, 0.94, 0.87, 0.96, 0.99, 0.99, 0.99, 0.96, 0.94, 0.98] 194 | """ 195 | with open('../data/dataset/cls_label2id_fewer.json', 'r', encoding='utf-8') as f: 196 | label2id = json.load(f)['main_property'] 197 | count = 0 198 | prop_count = [0] * len(label2id) 199 | alpha_count = [] 200 | with open('../data/dataset/cls_labeled_fewer.txt', 'r', encoding='utf-8') as f: 201 | for line in f: 202 | count += 1 203 | blocks = line.strip().split('\t') 204 | props = blocks[2].split('|') 205 | for p in props: 206 | if p not in label2id: 207 | print(line) 208 | prop_count[label2id[p]] += 1 209 | for i in range(len(prop_count)): 210 | prop_count[i] /= count 211 | alpha_count.append(1 - 
prop_count[i]) 212 | print(prop_count) 213 | print(alpha_count) 214 | 215 | 216 | def statistic_service(labeled_file: str = '../data/raw_data/train_denoised.xlsx'): 217 | """ 218 | 219 | :param labeled_file: 220 | :return: 221 | """ 222 | labeled_data = pd.read_excel(labeled_file) 223 | name_list = list(labeled_data.loc[:, '约束属性名']) 224 | value_list = list(labeled_data.loc[:, '约束属性值']) 225 | service = set() 226 | price = set() 227 | flow = set() 228 | for i, name in enumerate(name_list): 229 | n_list = str(name).split('|') 230 | v_list = str(value_list[i]).split('|') 231 | for j, n in enumerate(n_list): 232 | if n == '子业务': 233 | service.add(v_list[j]) 234 | elif n == '流量': 235 | flow.add(v_list[j]) 236 | elif n == '价格': 237 | price.add(v_list[j]) 238 | print('所有的子业务:', service) 239 | print('所有的价格:', price) 240 | print('所有的流量:', flow) 241 | 242 | 243 | def statistic_wrong_rdf(): 244 | ''' 245 | 向kg输入train.xlsx前面的字段,输出预测结果,并和给出的答案做对比 246 | @return: train_wrong_triple.xlsx : 保存对比结果 247 | ''' 248 | print('正在获取rdf结果并对比') 249 | from triples import KnowledgeGraph 250 | 251 | kg = KnowledgeGraph('../data/process_data/triples.rdf') 252 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx') 253 | df.fillna('') 254 | id_list = [] 255 | ans_list = [] 256 | for iter, row in df.iterrows(): 257 | ans_true = list(set(row['答案'].split('|'))) 258 | question = row['用户问题'] 259 | ans_type = row['答案类型'] 260 | entity = row['实体'] 261 | main_property = row['属性名'].split('|') 262 | operator = row['约束算子'] 263 | if operator != 'min' and operator != 'max': 264 | operator == 'other' 265 | sub_properties = [] 266 | cons_names = str(row['约束属性名']).split('|') 267 | cons_values = str(row['约束属性值']).split('|') 268 | if cons_names == ['nan']: cons_names = [] 269 | for index in range(len(cons_names)): 270 | sub_properties.append([cons_names[index], cons_values[index]]) 271 | ans = kg.fetch_ans(question, ans_type, entity, main_property, operator, sub_properties) 272 | 273 | def is_same(ans, ans_true): 274 | for an in ans: 275 | if an in ans_true: 276 | ans_true.remove(an) 277 | else: 278 | return False 279 | if len(ans_true) != 0: 280 | return False 281 | return True 282 | 283 | if not is_same(ans, ans_true): 284 | id_list.append(iter) 285 | ans_list.append(ans) 286 | print(id_list) 287 | df_save = df.iloc[id_list, [0, 1, 2, 3, 4, 5, 6, 7]] 288 | df_save['预测'] = ans_list 289 | print(df_save) 290 | df_save.to_excel('/data/huangbo/project/Tianchi_nlp_git/data/raw_data/train_wrong_triple.xlsx') 291 | 292 | 293 | def statistic_wrong_cons(): 294 | ''' 295 | 找出属性句中(min max不要,并且只要'档位介绍-xx')中原句中出现但约束中却没出现的价格、子业务、流量 296 | @return: 297 | ''' 298 | print('正在获取rdf结果并对比') 299 | from triples import KnowledgeGraph 300 | 301 | kg = KnowledgeGraph('../data/process_data/triples.rdf') 302 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx') 303 | df.fillna('') 304 | id_list = set() 305 | ans_list = [] 306 | for iter, row in df.iterrows(): 307 | ans_true = list(set(row['答案'].split('|'))) 308 | question = row['用户问题'] 309 | ans_type = row['答案类型'] 310 | # 只对属性值的句子做处理 311 | if ans_type != '属性值': 312 | continue 313 | entity = row['实体'] 314 | main_property = row['属性名'].split('|') 315 | # 排除属性中没有'-'的情况,只要'档位介绍-xx'的情况 316 | if '-' not in main_property[0]: 317 | continue 318 | operator = row['约束算子'] 319 | # 排除operator为min或max的情况 320 | if operator != 'min' and operator != 'max': 321 | operator == 'other' 322 | else: 323 | continue 324 | sub_properties = {} 325 | cons_names = str(row['约束属性名']).split('|') 326 | cons_values = 
str(row['约束属性值']).split('|') 327 | if cons_names == ['nan']: cons_names = [] 328 | for index in range(len(cons_names)): 329 | if cons_names[index] not in sub_properties: 330 | sub_properties[cons_names[index]] = [] 331 | if cons_names[index] == '子业务': 332 | sub_properties[cons_names[index]].append(cons_values[index]) 333 | else: 334 | sub_properties[cons_names[index]].append(int(cons_values[index])) 335 | price_ans, flow_ans, service_ans = kg.fetch_wrong_ans(question, ans_type, entity, main_property, operator, []) 336 | rdf_properties = {} 337 | rdf_properties['价格'] = price_ans 338 | rdf_properties['流量'] = flow_ans 339 | rdf_properties['子业务'] = service_ans 340 | compare_result = [] 341 | for name, values in rdf_properties.items(): 342 | for value in values: 343 | if name != '子业务': value = int(value) 344 | if name == '流量' and (value > 99 and value % 1024 == 0): 345 | value = int(value // 1024) 346 | if str(value) in question: 347 | if name in sub_properties: 348 | if value in sub_properties[name]: 349 | continue 350 | if value == '年包' and '半年包' in sub_properties[name]: 351 | continue 352 | if value == '百度' and '百度包' in sub_properties[name]: 353 | continue 354 | elif str(value) in entity: 355 | continue 356 | compare_result.append(name + '_' + str(value)) 357 | id_list.add(iter) 358 | if compare_result != []: 359 | ans_list.append(compare_result) 360 | 361 | index = list(id_list) 362 | index.sort() 363 | df_save = df.iloc[index, [0, 1, 2, 3, 4, 5, 6, 7]] 364 | df_save['预测'] = ans_list 365 | df_save.to_excel('../data/raw_data/train_wrong_cons.xlsx') 366 | 367 | 368 | def statistic_wrong_cons_bieo(): 369 | ''' 370 | 找出属性句中(min max不要,并且只要'档位介绍-xx')中原句中出现但约束中却没出现的价格、子业务、流量 371 | 根据bieo的结果进行筛选,补全bieo 372 | @return: 373 | ''' 374 | print('正在通过rdf补充信息') 375 | from triples import KnowledgeGraph 376 | from data import create_test_BIEO, create_test_BIO 377 | 378 | kg = KnowledgeGraph('../data/process_data/triples.rdf') 379 | with open('../data/file/train_bieo.json') as f: 380 | bieo_dict = json.load(f) 381 | # bieo_dict = create_test_BIEO('../data/raw_data/train_denoised_desyn.xlsx', False) 382 | type_map = {'价格': 'PRICE', '流量': 'FLOW', '子业务': 'SERVICE'} 383 | 384 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx') 385 | df.fillna('') 386 | id_list = set() 387 | ans_list = [] 388 | for iter, row in df.iterrows(): 389 | ans_true = list(set(row['答案'].split('|'))) 390 | question = row['用户问题'] 391 | ans_type = row['答案类型'] 392 | # 只对属性值的句子做处理 393 | if ans_type != '属性值': 394 | pass 395 | #continue 396 | entity = row['实体'] 397 | main_property = row['属性名'].split('|') 398 | # 排除属性中没有'-'的情况,只要'档位介绍-xx'的情况 399 | if '-' not in main_property[0]: 400 | # continue 401 | pass 402 | operator = row['约束算子'] 403 | # 排除operator为min或max的情况 404 | if operator != 'min' and operator != 'max': 405 | operator == 'other' 406 | else: 407 | continue 408 | anno = bieo_dict[question] 409 | sub_properties = {} 410 | cons_names = str(row['约束属性名']).split('|') 411 | cons_values = str(row['约束属性值']).split('|') 412 | if cons_names == ['nan']: cons_names = [] 413 | for index in range(len(cons_names)): 414 | if cons_names[index] not in sub_properties: 415 | sub_properties[cons_names[index]] = [] 416 | if cons_names[index] == '子业务': 417 | sub_properties[cons_names[index]].append(cons_values[index]) 418 | else: 419 | sub_properties[cons_names[index]].append(int(cons_values[index])) 420 | price_ans, flow_ans, service_ans = kg.fetch_wrong_ans(question, ans_type, entity, main_property, operator, []) 421 | rdf_properties = {} 422 | 
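# price_ans / flow_ans / service_ans hold the candidate 价格 / 流量 / 子业务
# values the knowledge graph returns for this entity. Each candidate that
# literally appears in the question but is missing from both the labelled
# constraints and the current BIEO annotation gets patched into `anno` below
# (an S- tag for single-character values, B-/I-/E- otherwise) and is recorded
# in compare_result so the supplemented rows can be exported for inspection.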
rdf_properties['价格'] = price_ans 423 | rdf_properties['流量'] = flow_ans 424 | rdf_properties['子业务'] = service_ans 425 | compare_result = [] 426 | for name, values in rdf_properties.items(): 427 | for value in values: 428 | if name != '子业务': value = int(value) 429 | if name == '流量' and (value > 99 and value%1024 == 0): 430 | value = int(value//1024) 431 | if str(value) in question: 432 | if name in sub_properties and value in sub_properties[name]: 433 | continue 434 | if value == '百度' and '百度包' in sub_properties['子业务']: 435 | continue 436 | question_index = 0 437 | if str(value) in entity and (question.count(str(value)) == 1 or str(value) == '1'): 438 | continue 439 | while question[question_index:].find(str(value)) != -1: 440 | temp_index = question_index + question[question_index:].find(str(value)) 441 | question_index = min(len(question), temp_index + len(str(value))) 442 | if anno[temp_index] == 'O': 443 | if name == '流量': 444 | if question_index < len(question): 445 | if question[question_index] == '元' or question[question_index].isnumeric(): 446 | continue 447 | if name == '价格': 448 | if question_index < len(question): 449 | if question[question_index] == 'g' or question[question_index].isnumeric(): 450 | continue 451 | if len(str(value)) == 1: 452 | anno[temp_index] = 'S-' + type_map[name] 453 | else: 454 | for temp_i in range(temp_index + 1, temp_index + len(str(value)) - 1): 455 | anno[temp_i] = 'I-' + type_map[name] 456 | anno[temp_index] = 'B-' + type_map[name] 457 | anno[temp_index + len(str(value)) - 1] = 'E-' + type_map[name] 458 | compare_result.append(name + '_' + str(value)) 459 | bieo_dict[question] = anno 460 | id_list.add(iter) 461 | 462 | if compare_result != []: 463 | ans_list.append(compare_result) 464 | 465 | index = list(id_list) 466 | index.sort() 467 | df_save = df.iloc[index, [0, 1, 2, 3, 4, 5, 6, 7]] 468 | df_save['预测'] = ans_list 469 | df_save.to_excel('../data/raw_data/train_wrong_cons.xlsx') 470 | with open(r'../data/file/train_bieo_enhence.json', 'w') as f: 471 | json.dump(bieo_dict, f, indent=2, ensure_ascii=False) 472 | 473 | 474 | def statistic_sym_in_question(): 475 | from data_argument import read_synonyms 476 | synonyms = read_synonyms() 477 | df = pd.read_excel('../data/raw_data/ensemble_bert_aug2_use_efficiency_2021-07-20-08-29-22_seed1_fi0_gpu2080_be81_0.95411990.pth.xlsx') 478 | # df = pd.read_excel( 479 | # '../data/raw_data/train_denoised.xlsx') 480 | c = 0 481 | no = set() 482 | yes = set() 483 | for iter, row in df.iterrows(): 484 | entity = row['实体'] 485 | question = row['用户问题'] 486 | con_name = row['约束属性名'] 487 | con_value = row['约束属性值'] 488 | con_names = re.split(r'[|\|]', str(con_name)) 489 | con_values = re.split(r'[|\|]', str(con_value)) 490 | if entity not in synonyms: continue 491 | for synonym in synonyms[entity]: 492 | if synonym in question: 493 | for value in con_values: 494 | if value in synonym: 495 | no.add(synonym) 496 | if synonym not in no: 497 | yes.add(synonym) 498 | print(no) 499 | print(yes) 500 | for iter, row in df.iterrows(): 501 | entity = row['实体'] 502 | question = row['用户问题'] 503 | con_name = row['约束属性名'] 504 | con_value = row['约束属性值'] 505 | con_names = re.split(r'[|\|]', str(con_name)) 506 | con_values = re.split(r'[|\|]', str(con_value)) 507 | if entity not in synonyms: continue 508 | new_question = question 509 | for synonym in synonyms[entity]: 510 | if synonym in question and (synonym not in no): 511 | new_question = question.replace(synonym, entity) 512 | row['用户问题'] = new_question 513 | # 
df.to_excel('../data/raw_data/train_denoised_desyn.xlsx') 514 | 515 | 516 | df = pd.read_excel( 517 | '../data/raw_data/train_denoised.xlsx') 518 | c = 0 519 | no = set() 520 | yes = set() 521 | for iter, row in df.iterrows(): 522 | entity = row['实体'] 523 | question = row['用户问题'] 524 | con_name = row['约束属性名'] 525 | con_value = row['约束属性值'] 526 | con_names = re.split(r'[|\|]', str(con_name)) 527 | con_values = re.split(r'[|\|]', str(con_value)) 528 | if entity not in synonyms: continue 529 | for synonym in synonyms[entity]: 530 | if synonym in question: 531 | for value in con_values: 532 | if value in synonym: 533 | no.add(synonym) 534 | if synonym not in no: 535 | yes.add(synonym) 536 | print(no) 537 | print(yes) 538 | for iter, row in df.iterrows(): 539 | entity = row['实体'] 540 | question = row['用户问题'] 541 | con_name = row['约束属性名'] 542 | con_value = row['约束属性值'] 543 | con_names = re.split(r'[|\|]', str(con_name)) 544 | con_values = re.split(r'[|\|]', str(con_value)) 545 | if entity not in synonyms: continue 546 | new_question = question 547 | for synonym in synonyms[entity]: 548 | if synonym in question and (synonym not in no): 549 | new_question = question.replace(synonym, entity) 550 | row['用户问题'] = new_question 551 | 552 | 553 | # def compare_result(result1_file, result2_file): 554 | # """ 555 | # 556 | # :param result1_file: 原结果 557 | # :param result2_file: 新结果 558 | # :return: 559 | # """ 560 | # result_path = '../data/results/' 561 | # result1 = json.load(open(os.path.join(result_path, result1_file), 'r')) 562 | # result2 = json.load(open(os.path.join(result_path, result2_file), 'r')) 563 | # for idx, value in result1['model_result'].items(): 564 | # if value != result2['model_result'][idx]: 565 | # print('结果1:', value, '\n结果2:', result2['model_result'][idx]) 566 | # print('\n') 567 | 568 | 569 | def string2result(txt): 570 | with open(txt, 'r') as f: 571 | line = f.read() 572 | result_dict = {} 573 | str_list = line.split('+') 574 | for str in str_list[:-1]: 575 | pro_list = str.split('=') 576 | id = pro_list[0] 577 | result_dict[id] = {} 578 | result_dict[id]['question'] = pro_list[1] 579 | result_dict[id]['ans_type'] = pro_list[2] 580 | result_dict[id]['entity'] = pro_list[3] 581 | result_dict[id]['main_property'] = pro_list[4].split('x') 582 | result_dict[id]['operator'] = pro_list[5] 583 | result_dict[id]['sub_properties'] = [] 584 | if pro_list[6] != '': 585 | sub_list = pro_list[6].split('x') 586 | for sub in sub_list: 587 | name = sub.split('_')[0] 588 | value = sub.split('_')[1] 589 | if name == '流量': value = int(value) 590 | if pro_list[5] == 'max' or pro_list[5] == 'min': 591 | value = int(value) 592 | result_dict[id]['sub_properties'].append((name, value)) 593 | 594 | with open('../data/results/result_docker.json', 'w', encoding='utf-8') as f: 595 | json.dump({'model_result':result_dict}, f, ensure_ascii=False, indent=2) 596 | 597 | 598 | def compare_result(file1, file2): 599 | with open('../data/results/' + file1, 'r') as f: 600 | base = json.load(f)['model_result'] 601 | with open('../data/results/' + file2, 'r') as f: 602 | result = json.load(f)['model_result'] 603 | 604 | c = 0 605 | for index, dic in base.items(): 606 | dic1 = result[index] 607 | # if dic1['sub_properties'] != dic['sub_properties']: 608 | # if dic1['entity'] != dic['entity']: 609 | # if dic1['main_property'] != dic['main_property']: 610 | # if dic1['ans_type'] != dic['ans_type']: 611 | if dic1['sub_properties'] != dic['sub_properties'] or dic1['entity'] != dic['entity'] or set(dic1['main_property']) 
!= set(dic['main_property']) or dic1['ans_type'] != dic['ans_type']: 612 | c += 1 613 | print(index) 614 | print('file1:', dic) 615 | print('file2:', dic1) 616 | print() 617 | print(c) 618 | 619 | 620 | def static_confusion_entity(): 621 | with open('../data/dataset/cls_label2id.json', 'r', encoding='utf-8') as f: 622 | label2id = json.load(f) 623 | entities = list(label2id['entity']) 624 | for i in range(len(entities)): 625 | entity1 = entities[i] 626 | for j in range(len(entities)): 627 | entity2 = entities[j] 628 | if i != j and entity1.find(entity2) != -1: 629 | print(entity1, 'contains', entity2) 630 | 631 | 632 | def static_no_answer(file): 633 | """ 634 | 统计答案中没有属性名的样本 635 | :param file: 636 | :return: 637 | """ 638 | with open('../data/results/' + file, 'r') as f: 639 | result_dict = json.load(f) 640 | result = result_dict['result'] 641 | model_result = result_dict['model_result'] 642 | for i in result: 643 | main_property = model_result[i]['main_property'] 644 | for prop in main_property: 645 | if '-' in prop: 646 | relation = prop.split('-')[1] 647 | else: 648 | relation = prop 649 | if relation not in ('流量', '价格', '子业务', '上线时间', '语音时长', '带宽', '内含其它服务') and model_result[i]['ans_type'] != '比较句': 650 | if relation not in result[i]: 651 | print('{}\t{}'.format(model_result[i], result[i])) 652 | 653 | 654 | if __name__ == '__main__': 655 | # statistic_entity(labeled_file='../data/raw_data/train_denoised.xlsx') 656 | # statistic_property() 657 | # statistic_synonyms() 658 | # statistic_prop_labels() 659 | # statistic_service() 660 | # compare_result(result1_file='drop_1_2021-07-28-02-58-01_seed1_gpu2080_be34_0.95407602.pth1.json', 661 | # result2_file='drop_1_2021-07-28-02-58-01_seed1_gpu2080_be34_0.95407602.pth.json') 662 | # statistic_prop_labels() 663 | # compare_result(result1_file='bert_pt_augment2_2021-07-30-05-44-08_seed1_gpu2080_be6_0.99729231.pth.json', 664 | # result2_file='ensemble_bert_aug2_use_efficiency_2021-07-20-08-29-22_seed1_fi0_gpu2080_be81_0.95411990.pth.json') 665 | string2result('../data/results/result_text.txt') 666 | compare_result('9677_38.json', '9682.json') 667 | -------------------------------------------------------------------------------- /code/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: hezf 3 | @Time: 2021/7/3 19:43 4 | @desc: 5 | """ 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | # from nlpcda import Simbert, Similarword, RandomDeleteChar, Homophone, CharPositionExchange 9 | import time 10 | from multiprocessing import Process 11 | from data_argument import augment_from_simbert 12 | import pandas as pd 13 | import json 14 | 15 | # ---simbert--- 16 | # config = { 17 | # 'model_path': '/data2/hezhenfeng/other_model_files/chinese_simbert_L-6_H-384_A-12', 18 | # 'CUDA_VISIBLE_DEVICES': '6', 19 | # 'max_len': 40, 20 | # 'seed': 1, 21 | # 'device': 'gpu' 22 | # } 23 | # 24 | # simbert = Simbert(config=config) 25 | # sent_list = ['9元百度专属定向流量包如何取消', 26 | # '你告诉我7天5g视频会员流量包怎么开通,多少钱', 27 | # '您好:怎么开通70爱奇艺', 28 | # 'plus会员领取的权益可以取消吗', 29 | # '通州卡如何取消'] 30 | # 31 | # for sent in sent_list: 32 | # synonyms = simbert.replace(sent=sent, create_num=5) 33 | # print(synonyms) 34 | 35 | # ------------nlpcda 一般增强------------- 36 | # start = time.time() 37 | # sent_list = ['9元百度专属定向流量包如何取消', 38 | # '你告诉我7天5g视频会员流量包怎么开通,多少钱', 39 | # '您好:怎么开通70爱奇艺', 40 | # 'plus会员领取的权益可以取消吗', 41 | # '通州卡如何取消'] 42 | # 43 | # smw = CharPositionExchange(create_num=3, change_rate=0.01) 44 | # for sent in 
sent_list: 45 | # rs1 = smw.replace(sent) 46 | # print(rs1) 47 | # end = time.time() 48 | # print(end-start) 49 | 50 | # ---------textda--------------- 51 | # from utils import setup_seed 52 | # from textda.data_expansion import data_expansion 53 | # import random 54 | # 55 | # if __name__ == '__main__': 56 | # # setup_seed(1) 57 | # random.seed(1) 58 | # print(data_expansion('这是一句测试的句子。')) 59 | 60 | # ------------多进程------------ 61 | 62 | # def function1(id): # 这里是子进程 63 | # augment_from_simbert(source_file='../data/raw_data/train_denoised.xlsx', 64 | # target_file='../data/raw_data/train_augment_simbert.xlsx') 65 | # 66 | # 67 | # def run_process(): # 这里是主进程 68 | # from multiprocessing import Process 69 | # process = [Process(target=function1, args=(1,)), 70 | # Process(target=function1, args=(2,)), ] 71 | # [p.start() for p in process] # 开启了两个进程 72 | # [p.join() for p in process] # 等待两个进程依次结束 73 | 74 | 75 | # ------------标注文件转化成预测样本的格式--------------- 76 | def to_predict_file(): 77 | data = [] 78 | with open('../data/dataset/cls_labeled.txt', 'r', encoding='utf-8') as f: 79 | for line in f: 80 | data.append(line.strip().split('\t')[0]) 81 | with open('../data/dataset/labeled_predict.txt', 'w', encoding='utf-8') as f: 82 | for line in data: 83 | f.write(line+'\n') 84 | 85 | 86 | def xlsx2json(): 87 | data = {} 88 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx') 89 | for i in range(len(df)): 90 | line = df.loc[i] 91 | data[str(i)] = { 92 | 'question': line['用户问题'], 93 | 'ans_type': line['答案类型'], 94 | 'entity': line['实体'], 95 | 'main_property': line['属性名']} 96 | with open('../data/results/train_result.json', 'w', encoding='utf-8') as f: 97 | json.dump({'model_result': data}, f, indent=2, ensure_ascii=False) 98 | 99 | 100 | if __name__ == '__main__': 101 | # xlsx2json() 102 | from tqdm import tqdm 103 | from utils import judge_cancel 104 | yes = 0 105 | # with open('../data/dataset/cls_labeled.txt', 'r', encoding='utf-8') as f: 106 | # for line in tqdm(f): 107 | # instance = line.strip().split('\t') 108 | # if judge_cancel(instance[0]): 109 | # if '取消' not in instance[2]: 110 | # print(instance) 111 | # else: 112 | # yes += 1 113 | # print(yes) 114 | with open('../data/dataset/cls_unlabeled2.txt', 'r', encoding='utf-8') as f: 115 | for line in tqdm(f): 116 | instance = line.strip().split('\t') 117 | if judge_cancel(instance[0]): 118 | print(instance) 119 | -------------------------------------------------------------------------------- /code/train_evaluate_constraint.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import re 3 | import json 4 | import os 5 | import random 6 | import time 7 | import torch 8 | from transformers import BertModel, BertTokenizer 9 | 10 | from model import Metrics, BILSTM_CRF_Model, BERT_BILSTM_CRF_Model 11 | from data import * 12 | from utils import setup_seed, load_model 13 | from statistic import get_all_service, statistic_wrong_cons_bieo 14 | import argparse 15 | 16 | 17 | def train(excel_path, model_name, tag_version='bio'): 18 | 19 | def build_map(lists): 20 | maps = {} 21 | for list_ in lists: 22 | for e in list_: 23 | if e == '': continue 24 | if e not in maps: 25 | maps[e] = len(maps) 26 | maps[''] = len(maps) 27 | maps[''] = len(maps) 28 | maps[''] = len(maps) 29 | maps[''] = len(maps) 30 | return maps 31 | 32 | # data 33 | print("读取数据...") 34 | with open(r'../data/file/train_{}.json'.format(tag_version)) as f: 35 | train_dict = json.load(f) 36 | # with 
open(r'../data/file/train_dev_ids.json') as f: 37 | # id_dict = json.load(f) 38 | raw_data = pd.read_excel(excel_path) 39 | question_list = raw_data['用户问题'].tolist() 40 | question_list = [question.replace(' ', '') for question in question_list] 41 | for i in range(len(question_list)): 42 | question = question_list[i] 43 | question = question.replace('三十', '30') 44 | question = question.replace('四十', '40') 45 | question = question.replace('十块', '10块') 46 | question = question.replace('六块', '6块') 47 | question = question.replace('一个月', '1个月') 48 | question = question.replace('2O', '20') 49 | question_list[i] = question 50 | # 按照划分好的id划分数据 51 | # train_list = [question_list[i] for i in id_dict['train_ids']] 52 | # dev_list = [question_list[i] for i in id_dict['dev_ids']] 53 | # test_list = [question_list[i] for i in id_dict['dev_ids']] 54 | # 随机划分数据 55 | random.shuffle(question_list) 56 | train_list = question_list[:int(0.8*len(question_list))] 57 | dev_list = question_list[int(0.8*len(question_list)): int(0.9*len(question_list))] 58 | test_list = question_list[int(0.9*len(question_list)):] 59 | 60 | train_word_lists, train_tag_lists = process_data(train_list, train_dict) 61 | dev_word_lists, dev_tag_lists = process_data(dev_list, train_dict) 62 | test_word_lists, test_tag_lists = process_data(test_list, train_dict, test=True) 63 | # 生成word2id 和 tag2id 并保存 64 | word2id = build_map(train_word_lists) 65 | tag2id = build_map(train_tag_lists) 66 | with open(r'../data/file/word2id.json', 'w') as f: json.dump(word2id, f) 67 | with open(r'../data/file/tag2id_{}.json'.format(tag_version), 'w') as f: json.dump(tag2id, f) 68 | # train 69 | print("正在训练评估Bi-LSTM+CRF模型...") 70 | start = time.time() 71 | vocab_size = len(word2id) 72 | out_size = len(tag2id) 73 | bilstmcrf_model = BILSTM_CRF_Model(vocab_size, out_size, gpu_id=args.cuda) 74 | bilstmcrf_model.train(train_word_lists, train_tag_lists, 75 | dev_word_lists, dev_tag_lists, word2id, tag2id, debug=args.debug) 76 | torch.save(bilstmcrf_model, '../data/trained_model/{}'.format(model_name)) 77 | print("训练完毕,共用时{}秒.".format(int(time.time() - start))) 78 | print("评估{}模型中...".format('bilstm-crf')) 79 | pred_tag_lists, test_tag_lists = bilstmcrf_model.test( 80 | test_word_lists, test_tag_lists, word2id, tag2id) 81 | 82 | print(len(test_tag_lists)) 83 | print(len(pred_tag_lists)) 84 | if args.debug=='1': 85 | metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=False) 86 | metrics.report_scores() 87 | metrics.report_confusion_matrix() 88 | 89 | 90 | def train_kfold(excel_path, model_name, kfold, tag_version='bio'): 91 | 92 | def build_map(lists): 93 | maps = {} 94 | for list_ in lists: 95 | for e in list_: 96 | if e == '': continue 97 | if e not in maps: 98 | maps[e] = len(maps) 99 | maps[''] = len(maps) 100 | maps[''] = len(maps) 101 | maps[''] = len(maps) 102 | maps[''] = len(maps) 103 | return maps 104 | 105 | # data 106 | # print("读取数据...") 107 | with open(r'../data/file/train_{}.json'.format(tag_version)) as f: 108 | train_dict = json.load(f) 109 | raw_data = pd.read_excel(excel_path) 110 | question_list = raw_data['用户问题'].tolist() 111 | question_list = [question.replace(' ', '') for question in question_list] 112 | for i in range(len(question_list)): 113 | question = question_list[i] 114 | question = question.replace('三十', '30') 115 | question = question.replace('四十', '40') 116 | question = question.replace('十块', '10块') 117 | question = question.replace('六块', '6块') 118 | question = question.replace('一个月', '1个月') 119 | question = 
question.replace('2O', '20') 120 | question_list[i] = question 121 | # 按照划分好的id划分数据 122 | # train_list = [question_list[i] for i in id_dict['train_ids']] 123 | # dev_list = [question_list[i] for i in id_dict['dev_ids']] 124 | # test_list = [question_list[i] for i in id_dict['dev_ids']] 125 | # 随机划分数据 126 | random.shuffle(question_list) 127 | block_len = len(question_list)//kfold 128 | data_blocks = [question_list[i*block_len: (i+1)*block_len] for i in range(kfold-1)] 129 | data_blocks.append(question_list[(kfold-1)*block_len:]) 130 | # 使用所有训练数据生成word2id tag2id 131 | all_train_list = [] 132 | for block in data_blocks: 133 | all_train_list += block 134 | all_train_word_lists, all_train_tag_lists = process_data(all_train_list, train_dict) 135 | all_word2id = build_map(all_train_word_lists) 136 | all_tag2id = build_map(all_train_tag_lists) 137 | with open(r'../data/file/word2id_{}.json'.format('all'), 'w') as f: json.dump(all_word2id, f) 138 | with open(r'../data/file/tag2id_{}_{}.json'.format(tag_version, 'all'), 'w') as f: json.dump(all_tag2id, f) 139 | # 开始kfold训练 140 | for i, block in enumerate(data_blocks): 141 | train_list, dev_list = [], block 142 | for _ in range(kfold): 143 | if _ != i: 144 | train_list += data_blocks[_] 145 | train_word_lists, train_tag_lists = process_data(train_list, train_dict) 146 | dev_word_lists, dev_tag_lists = process_data(dev_list, train_dict) 147 | # 生成word2id 和 tag2id 并保存 148 | word2id = build_map(train_word_lists) 149 | tag2id = build_map(train_tag_lists) 150 | with open(r'../data/file/word2id_{}.json'.format(i), 'w') as f: json.dump(word2id, f) 151 | with open(r'../data/file/tag2id_{}_{}.json'.format(tag_version, i), 'w') as f: json.dump(tag2id, f) 152 | # train 153 | print("正在训练评估Bi-LSTM+CRF模型...") 154 | start = time.time() 155 | vocab_size = len(all_word2id) 156 | out_size = len(all_tag2id) 157 | bilstmcrf_model = BILSTM_CRF_Model(vocab_size, out_size, gpu_id=args.cuda) 158 | bilstmcrf_model.train(train_word_lists, train_tag_lists, 159 | dev_word_lists, dev_tag_lists, all_word2id, all_tag2id, debug=args.debug) 160 | torch.save(bilstmcrf_model, '../data/trained_model/{}'.format(model_name.replace('.pth', '_{}'.format(i)+'.pth'))) 161 | print("训练完毕,共用时{}秒.".format(int(time.time() - start))) 162 | if args.debug == '1': 163 | print("评估{}模型中...".format('bilstm-crf')) 164 | pred_tag_lists, dev_tag_lists = bilstmcrf_model.test( 165 | dev_word_lists, dev_tag_lists, all_word2id, all_tag2id) 166 | 167 | 168 | print(len(dev_tag_lists)) 169 | print(len(pred_tag_lists)) 170 | metrics = Metrics(dev_tag_lists, pred_tag_lists, remove_O=False) 171 | metrics.report_scores() 172 | metrics.report_confusion_matrix() 173 | 174 | 175 | def train2(excel_path, model_name, tag_version='bio'): 176 | 177 | def build_map(lists): 178 | maps = {} 179 | for list_ in lists: 180 | for e in list_: 181 | if e == '': continue 182 | if e not in maps: 183 | maps[e] = len(maps) 184 | maps[''] = len(maps) 185 | maps[''] = len(maps) 186 | maps[''] = len(maps) 187 | maps[''] = len(maps) 188 | return maps 189 | 190 | # data 191 | print("读取数据...") 192 | with open(r'../data/file/train_{}.json'.format(tag_version)) as f: 193 | train_dict = json.load(f) 194 | with open(r'../data/file/train_dev_ids.json') as f: 195 | id_dict = json.load(f) 196 | raw_data = pd.read_excel(excel_path) 197 | question_list = raw_data['用户问题'].tolist() 198 | question_list = [question.replace(' ', '') for question in question_list] 199 | for i in range(len(question_list)): 200 | question = question_list[i] 201 | question = 
question.replace('三十', '30') 202 | question = question.replace('四十', '40') 203 | question = question.replace('十块', '10块') 204 | question = question.replace('六块', '6块') 205 | question = question.replace('一个月', '1个月') 206 | question = question.replace('2O', '20') 207 | question_list[i] = question 208 | # 按照划分好的id划分数据 209 | # train_list = [question_list[i] for i in id_dict['train_ids']] 210 | # dev_list = [question_list[i] for i in id_dict['dev_ids']] 211 | # test_list = [question_list[i] for i in id_dict['dev_ids']] 212 | # 随机划分数据 213 | random.shuffle(question_list) 214 | train_list = question_list[:int(0.8*len(question_list))] 215 | dev_list = question_list[int(0.8*len(question_list)): int(0.9*len(question_list))] 216 | test_list = question_list[int(0.9*len(question_list)):] 217 | 218 | train_word_lists, train_tag_lists = process_data(train_list, train_dict) 219 | dev_word_lists, dev_tag_lists = process_data(dev_list, train_dict) 220 | test_word_lists, test_tag_lists = process_data(test_list, train_dict, test=True) 221 | # bert 222 | BERT_PRETRAINED_PATH = '../data/bert_pretrained_model' 223 | word2id = {} 224 | with open(os.path.join(BERT_PRETRAINED_PATH, 'vocab.txt'), 'r') as f: 225 | count = 0 226 | for line in f: 227 | word2id[line.split('\n')[0]] = count 228 | count += 1 229 | bert_model = BertModel.from_pretrained(BERT_PRETRAINED_PATH) 230 | # 生成word2id 和 tag2id 并保存 231 | tag2id = build_map(train_tag_lists) 232 | with open(r'../data/file/word2id_bert.json', 'w') as f: json.dump(word2id, f) 233 | with open(r'../data/file/tag2id_{}.json'.format(tag_version), 'w') as f: json.dump(tag2id, f) 234 | 235 | # train 236 | print("正在训练评估Bert-Bi-LSTM+CRF模型...") 237 | start = time.time() 238 | vocab_size = len(word2id) 239 | out_size = len(tag2id) 240 | bilstmcrf_model = BERT_BILSTM_CRF_Model(vocab_size, out_size, bert_model, gpu_id=args.cuda) 241 | bilstmcrf_model.train(train_word_lists, train_tag_lists, 242 | dev_word_lists, dev_tag_lists, word2id, tag2id) 243 | torch.save(bilstmcrf_model, '../data/trained_model/{}'.format(model_name)) 244 | print("训练完毕,共用时{}秒.".format(int(time.time() - start))) 245 | print("评估{}模型中...".format('bert-bilstm-crf')) 246 | pred_tag_lists, test_tag_lists = bilstmcrf_model.test( 247 | test_word_lists, test_tag_lists, word2id, tag2id) 248 | 249 | print(len(test_tag_lists)) 250 | print(len(pred_tag_lists)) 251 | metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=False) 252 | metrics.report_scores() 253 | metrics.report_confusion_matrix() 254 | 255 | 256 | def evaluate(excel_path, word2id_version='word2id.json', model_name='bilstmcrf.pth', tag_version='bio'): 257 | with open(r'../data/file/train_{}.json'.format(tag_version), 'r') as f: 258 | train_dict = json.load(f) 259 | with open(r'../data/file/{}'.format(word2id_version), 'r') as f: word2id = json.load(f) 260 | with open(r'../data/file/tag2id_{}.json'.format(tag_version), 'r') as f: tag2id = json.load(f) 261 | # 得到训练数据 262 | raw_data = pd.read_excel(excel_path) 263 | question_list = raw_data['用户问题'].tolist() 264 | question_list = [question.replace(' ', '') for question in question_list] 265 | for i in range(len(question_list)): 266 | question = question_list[i] 267 | question = question.replace('三十', '30') 268 | question = question.replace('四十', '40') 269 | question = question.replace('十块', '10块') 270 | question = question.replace('六块', '6块') 271 | question = question.replace('一个月', '1个月') 272 | question = question.replace('2O', '20') 273 | question_list[i] = question 274 | index_list = 
list(range(len(question_list))) 275 | random.shuffle(index_list) 276 | question_list = [question_list[i] for i in index_list] 277 | test_word_lists, test_tag_lists = process_data(question_list[:], train_dict) 278 | # 得到原来的约束属性名和约束值 279 | con_names = raw_data['约束属性名'].tolist() 280 | con_values = raw_data['约束属性值'].tolist() 281 | con_names = [re.split(r'[|\|]', str(con_names[i])) for i in index_list] 282 | con_values = [re.split(r'[|\|]', str(con_values[i])) for i in index_list] 283 | con_dict_list = [] 284 | for i in range(len(con_names)): 285 | con_dict = {} 286 | con_name = con_names[i] 287 | con_value = con_values[i] 288 | for j in range(len(con_name)): 289 | name = con_name[j].strip() 290 | value = con_value[j].strip() 291 | if name == 'nan' or name == '有效期': continue 292 | if name == '价格' and value == '1' and '最' in question_list[i]: continue 293 | if name == '流量' and value == '1' and '最' in question_list[i]: continue 294 | if name not in con_dict: con_dict[name] = [] 295 | con_dict[name].append(value) 296 | con_dict_list.append(con_dict) 297 | 298 | service_names = get_all_service(raw_data) 299 | 300 | bilstmcrf_model = load_model(r'../data/trained_model/{}'.format(model_name), map_location=lambda storage, loc: storage.cuda('cuda:{}'.format(args.cuda))) 301 | bilstmcrf_model.device = torch.device('cuda:{}'.format(args.cuda)) 302 | pred_tag_lists, test_tag_lists = bilstmcrf_model.test( 303 | test_word_lists, test_tag_lists, word2id, tag2id) 304 | c = 0 305 | for index in range(len(pred_tag_lists)): 306 | pre_dict = get_anno_dict(question_list[index], pred_tag_lists[index], service_names) 307 | # 对比 官方标注 和 非增强预测标注 308 | if set(pre_dict) != set(con_dict_list[index]): 309 | c+=1 310 | print(question_list[index]) 311 | print('true:',test_tag_lists[index]) 312 | print('pred:',pred_tag_lists[index]) 313 | print('true:',con_dict_list[index]) 314 | print('pred:',pre_dict) 315 | print(c) 316 | 317 | 318 | def evaluate_kfold(excel_path, kfold, word2id_version='word2id.json', model_name='bilstmcrf.pth', tag_version='bio'): 319 | for i in range(kfold): 320 | with open(r'../data/file/train_{}.json'.format(tag_version), 'r') as f: 321 | train_dict = json.load(f) 322 | with open(r'../data/file/{}'.format(word2id_version.replace('.json', '_{}.json'.format('all'))), 'r') as f: word2id = json.load(f) 323 | with open(r'../data/file/tag2id_{}_{}.json'.format(tag_version, 'all'), 'r') as f: tag2id = json.load(f) 324 | # 得到训练数据 325 | raw_data = pd.read_excel(excel_path) 326 | question_list = raw_data['用户问题'].tolist() 327 | question_list = [question.replace(' ', '') for question in question_list] 328 | for i in range(len(question_list)): 329 | question = question_list[i] 330 | question = question.replace('三十', '30') 331 | question = question.replace('四十', '40') 332 | question = question.replace('十块', '10块') 333 | question = question.replace('六块', '6块') 334 | question = question.replace('一个月', '1个月') 335 | question = question.replace('2O', '20') 336 | question_list[i] = question 337 | index_list = list(range(len(question_list))) 338 | random.shuffle(index_list) 339 | question_list = [question_list[i] for i in index_list] 340 | test_word_lists, test_tag_lists = process_data(question_list[:], train_dict) 341 | # 得到原来的约束属性名和约束值 342 | con_names = raw_data['约束属性名'].tolist() 343 | con_values = raw_data['约束属性值'].tolist() 344 | con_names = [re.split(r'[|\|]', str(con_names[i])) for i in index_list] 345 | con_values = [re.split(r'[|\|]', str(con_values[i])) for i in index_list] 346 | con_dict_list = [] 347 | for i in 
range(len(con_names)): 348 | con_dict = {} 349 | con_name = con_names[i] 350 | con_value = con_values[i] 351 | for j in range(len(con_name)): 352 | name = con_name[j].strip() 353 | value = con_value[j].strip() 354 | if name == 'nan' or name == '有效期': continue 355 | if name == '价格' and value == '1' and '最' in question_list[i]: continue 356 | if name == '流量' and value == '1' and '最' in question_list[i]: continue 357 | if name not in con_dict: con_dict[name] = [] 358 | con_dict[name].append(value) 359 | con_dict_list.append(con_dict) 360 | 361 | service_names = get_all_service(raw_data) 362 | 363 | bilstmcrf_model = load_model(r'../data/trained_model/{}'.format(model_name.replace('.pth', '_{}.pth'.format(i))), map_location=lambda storage, loc: storage.cuda('cuda:{}'.format(args.cuda))) 364 | bilstmcrf_model.device = torch.device('cuda:{}'.format(args.cuda)) 365 | pred_tag_lists, test_tag_lists = bilstmcrf_model.test( 366 | test_word_lists, test_tag_lists, word2id, tag2id) 367 | # 使用了 enhence 又需要对比增强前的 368 | c = 0 369 | for index in range(len(pred_tag_lists)): 370 | pre_dict = get_anno_dict(question_list[index], pred_tag_lists[index], service_names) 371 | test_dict = get_anno_dict(question_list[index], test_tag_lists[index][:-1], service_names) 372 | # 对比 官方标注 和 非增强预测标注 373 | if set(pre_dict) != set(con_dict_list[index]): 374 | c+=1 375 | print(question_list[index]) 376 | print('true:',test_tag_lists[index]) 377 | print('pred:',pred_tag_lists[index]) 378 | print('true:',con_dict_list[index]) 379 | print('pred:',pre_dict) 380 | print(c) 381 | 382 | 383 | if __name__ == '__main__': 384 | print('开始train_ner') 385 | setup_seed(1) 386 | excel_path = r'../data/raw_data/train_denoised_ner.xlsx' 387 | # excel_path = r'../data/raw_data/train_syn.xlsx' 388 | # os.environ["CUDA_VISIBLE_DEVICES"] = "5" 389 | # create_test_BIEO(excel_path) 390 | parser = argparse.ArgumentParser() 391 | parser.add_argument('--train_type', '-tt', default='lstm', type=str) 392 | parser.add_argument('--tag_version', '-tv', default='bieo', type=str) 393 | parser.add_argument('--model_name', '-m', default='ner_model_test.pth', type=str) 394 | parser.add_argument('--cuda', '-c', default=0) 395 | parser.add_argument('--debug', '-d', default='1', choices=['0', '1']) 396 | parser.add_argument('--kfold', '-k', default=5, type=int) 397 | # parser.add_argument('--kfold', '-k', action='store_true') 398 | args = parser.parse_args() 399 | a = args.kfold 400 | if args.tag_version == 'bio': 401 | create_test_BIO(excel_path) 402 | else: 403 | create_test_BIEO(excel_path) 404 | # statistic_wrong_cons_bieo() 405 | if args.train_type == 'lstm': 406 | if args.kfold == 1: 407 | train(excel_path, model_name=args.model_name, tag_version=args.tag_version) 408 | if args.debug == '1': 409 | evaluate(excel_path, word2id_version='word2id.json', model_name=args.model_name, tag_version=args.tag_version) 410 | else: 411 | train_kfold(excel_path, model_name=args.model_name, kfold=args.kfold, tag_version=args.tag_version) 412 | if args.debug == '1': 413 | evaluate_kfold(excel_path, kfold=args.kfold, word2id_version='word2id.json', model_name=args.model_name, 414 | tag_version=args.tag_version) 415 | else: 416 | train2(excel_path, model_name=args.model_name, tag_version=args.tag_version) 417 | if args.debug == '1': 418 | evaluate(excel_path, word2id_version='word2id_bert.json', model_name=args.model_name, tag_version=args.tag_version) -------------------------------------------------------------------------------- /code/triples.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | @Author: hezf 3 | @Time: 2021/6/8 11:08 4 | @desc: 负责RDF图的构建以及查询 5 | """ 6 | from rdflib import Graph, Namespace 7 | from utils import to_sparql 8 | from typing import List, Dict, Tuple 9 | import re 10 | 11 | 12 | class KnowledgeGraph(object): 13 | def __init__(self, rdf_file): 14 | super(KnowledgeGraph, self).__init__() 15 | self.rdf_file = rdf_file 16 | self.graph = Graph() 17 | self.init() 18 | 19 | def init(self): 20 | # print('parsing file...') 21 | self.graph.parse(self.rdf_file, format='n3') 22 | 23 | def query(self, q): 24 | """ 25 | 在图谱中查询答案 26 | :param q: sparql语句 27 | :return: List[str] 28 | """ 29 | answers = [] 30 | results = self.graph.query(q).bindings 31 | for result in results: 32 | for value in result.values(): 33 | value = value.toPython() 34 | value = value[value.rfind('/')+1:] 35 | answers.append(value) 36 | return answers 37 | 38 | def fetch_ans(self, question: str, ans_type: str, entity: str, main_property: List, operator: str, 39 | sub_properties: List[Tuple] = None): 40 | """ 41 | 获取最终答案的函数 42 | :param question: 问题本身 43 | :param ans_type: 答案类型【比较句、并列句、属性值】 44 | :param entity: 实体 45 | :param main_property: 属性名["档位介绍-流量"或者"XX规则"],并列句有两个属性名,而其他句式只有一个属性名 46 | :param sub_properties: 约束属性键值对【[(key, value), (key, value), ...]】。不用字典是因为比较句中key有可能相同 47 | :param operator: 约束算子【min max other】 48 | :return: 49 | """ 50 | # 问句特色: 51 | # 1 比较句:属性名只要一个,比较句的约束属性名肯定为两个;句式一般为:是否和哪个两种; 52 | # 2 并列句:有多个属性名,少量情况含有min,max算子,只要将属性名分批答案取出即可 53 | # 3 属性句:最普通的情况 54 | try: 55 | for i in range(len(sub_properties)): 56 | if sub_properties[i][0] == '流量': 57 | value = int(sub_properties[i][1]) 58 | if 1 < value < 55 and entity != '流量加油包': 59 | value *= 1024 60 | sub_properties[i] = ('流量', value) 61 | # print('修改流量的句子是:', question) 62 | if ans_type == '比较句': 63 | if len(sub_properties) == 0: 64 | return [] 65 | # 处理当预测的约束属性名个数大于两个时的情况 66 | if len(sub_properties) > 2: 67 | key_count = {} 68 | for k, v in sub_properties: 69 | if k not in key_count: 70 | key_count[k] = 0 71 | key_count[k] += 1 72 | real_k = None 73 | for k, count in key_count.items(): 74 | if count >= 2: 75 | real_k = k 76 | break 77 | if real_k is not None: 78 | for k, v in sub_properties: 79 | if k != real_k: 80 | sub_properties.remove((k, v)) 81 | q0 = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[sub_properties[0]]) 82 | ans0 = self.query(q0) 83 | if len(sub_properties) >= 2: 84 | q1 = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[sub_properties[1]]) 85 | ans1 = self.query(q1) 86 | else: 87 | ans1 = None 88 | if len(ans0) > 0: 89 | ans0 = ans0[0] 90 | else: 91 | ans0 = None 92 | if ans1 is not None and len(ans1) > 0: 93 | ans1 = ans1[0] 94 | else: 95 | ans1 = None 96 | keywords = ['哪', '那个'] 97 | flag = False 98 | for kw in keywords: 99 | if kw in question: 100 | flag = True 101 | break 102 | # "哪个"类型的问题 103 | if flag: 104 | if ans1 is None: 105 | ans = [sub_properties[0][1]] 106 | else: 107 | if ans0 == ans1: 108 | ans = [sub_properties[0][1], sub_properties[1][1]] 109 | else: 110 | bigger_keywords = ['多', '贵'] 111 | smaller_keywords = ['便宜', '优惠', '少', '实惠'] 112 | big_flag = False 113 | for bkw in bigger_keywords: 114 | if bkw in question: 115 | big_flag = True 116 | break 117 | if big_flag: 118 | if int(ans0) > int(ans1): 119 | ans = [sub_properties[0][1]] 120 | else: 121 | ans = [sub_properties[1][1]] 122 | else: 123 | if int(ans0) < int(ans1): 124 | ans = [sub_properties[0][1]] 
125 | else: 126 | ans = [sub_properties[1][1]] 127 | # "是否"类型的问题 128 | else: 129 | if ans1 is None: 130 | ans = ['no'] 131 | else: 132 | if ans0 == ans1: 133 | ans = ['yes'] 134 | else: 135 | ans = ['no'] 136 | # equal_kw = ['一样', '相同', '等价', '等同', '相等'] 137 | # eq_q = False 138 | # for k in equal_kw: 139 | # if question.find(k) != -1: 140 | # eq_q = True 141 | # break 142 | # # 相等问题 143 | # if eq_q: 144 | # if ans0 == ans1: 145 | # ans = ['yes'] 146 | # else: 147 | # ans = ['no'] 148 | # # 区别问题 149 | # else: 150 | # if ans0 != ans1: 151 | # ans = ['yes'] 152 | # else: 153 | # ans = ['no'] 154 | # 并列句 155 | elif ans_type == '并列句': 156 | ans = [] 157 | if operator == 'min': 158 | if len(sub_properties) == 0: 159 | return [] 160 | # 分两段获取答案 161 | key = sub_properties[0][0] 162 | query_str = to_sparql(entity=entity, main_property='档位介绍-'+key) 163 | temp_ans = self.query(query_str) 164 | temp_ans = [int(i) for i in temp_ans] 165 | if len(temp_ans) == 0: 166 | return [] 167 | first_step_ans = min(temp_ans) 168 | for m_p in main_property: 169 | query_str = to_sparql(entity=entity, main_property=m_p, sub_properties=[(key, first_step_ans)]) 170 | ans += self.query(query_str) 171 | elif operator == 'max': 172 | if len(sub_properties) == 0: 173 | return [] 174 | key = sub_properties[0][0] 175 | query_str = to_sparql(entity=entity, main_property='档位介绍-'+key) 176 | temp_ans = self.query(query_str) 177 | temp_ans = [int(i) for i in temp_ans] 178 | if len(temp_ans) == 0: 179 | return [] 180 | first_step_ans = max(temp_ans) 181 | for m_p in main_property: 182 | query_str = to_sparql(entity=entity, main_property=m_p, sub_properties=[(key, first_step_ans)]) 183 | ans += self.query(query_str) 184 | # =!= 185 | else: 186 | for m_p in main_property: 187 | query_str = to_sparql(entity=entity, main_property=m_p, sub_properties=sub_properties) 188 | ans += self.query(query_str) 189 | # 属性值 190 | elif ans_type == '属性值': 191 | # 答案只有一个,只需要选择一个最小(大)的即可。分两阶段获取答案 192 | if operator == 'min': 193 | if len(sub_properties) == 0: 194 | # 直接来近似得到 195 | temp_ans = self.query(to_sparql(entity=entity, main_property=main_property[0])) 196 | temp_ans = [int(i) for i in temp_ans] 197 | ans = min(temp_ans) 198 | return [ans] 199 | else: 200 | key = sub_properties[0][0] 201 | query_str = to_sparql(entity=entity, main_property=main_property[0].split('-')[0]+'-'+key) 202 | temp_ans = self.query(query_str) 203 | temp_ans = [int(i) for i in temp_ans] 204 | first_step_ans = min(temp_ans) 205 | query_str = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[(key, first_step_ans)]) 206 | ans = self.query(query_str) 207 | elif operator == 'max': 208 | if len(sub_properties) == 0: 209 | query_str = to_sparql(entity=entity, main_property=main_property[0]) 210 | temp_ans = self.query(query_str) 211 | temp_ans = [int(i) for i in temp_ans] 212 | ans = max(temp_ans) 213 | return [ans] 214 | else: 215 | key = sub_properties[0][0] 216 | query_str = to_sparql(entity=entity, main_property=main_property[0].split('-')[0]+'-'+key) 217 | temp_ans = self.query(query_str) 218 | # 修复字符串'700'大于'10000'的bug 219 | temp_ans = [int(i) for i in temp_ans] 220 | first_step_ans = max(temp_ans) 221 | query_str = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[(key, first_step_ans)]) 222 | ans = self.query(query_str) 223 | else: 224 | query_str = to_sparql(entity=entity, main_property=main_property[0], sub_properties=sub_properties) 225 | ans = self.query(query_str) 226 | else: 227 | ans = [] 228 | 
print('-------------------{}:乱码对结果造成影响---------------------'.format(ans_type)) 229 | ans = list(set(ans)) 230 | return ans 231 | except Exception as e: 232 | print('---------------查找知识图谱时发生异常:{}------------------------'.format(e)) 233 | print('当前参数为,问题:{}, 答案类型:{}, 实体:{}, 属性名{}, 算子:{}, 约束属性:{}'.format(question, ans_type, entity, main_property, operator, sub_properties)) 234 | return [] 235 | 236 | def fetch_wrong_ans(self, question: str, ans_type: str, entity: str, main_property: List, operator: str, 237 | sub_properties: List[Tuple] = None): 238 | """ 239 | 在没有约束传入时,查看对应的子业务、价格、流量有哪些,从而可以与原句对比,看是否有遗漏的约束标注 240 | :param question: 问题本身 241 | :param ans_type: 答案类型【比较句、并列句、属性值】 242 | :param entity: 实体 243 | :param main_property: 属性名["档位介绍-流量"或者"XX规则"],并列句有两个属性名,而其他句式只有一个属性名 244 | :param sub_properties: 约束属性键值对【[(key, value), (key, value), ...]】。不用字典是因为比较句中key有可能相同 245 | :param operator: 约束算子【min max other】 246 | :return: 247 | """ 248 | # 问句特色: 249 | # 1 比较句:属性名只要一个,比较句的约束属性名肯定为两个;句式一般为:是否和哪个两种; 250 | # 2 并列句:有多个属性名,少量情况含有min,max算子,只要将属性名分批答案取出即可 251 | # 3 属性句:最普通的情况 252 | for i in range(len(sub_properties)): 253 | if sub_properties[i][0] == '流量': 254 | value = int(sub_properties[i][1]) 255 | if 1 < value < 55: 256 | value *= 1024 257 | sub_properties[i] = ('流量', value) 258 | # print('修改流量的句子是:', question) 259 | ans, query_str = [], '' 260 | # 比较句不需要判断,认为约束的标注必然是对的 261 | if ans_type == '比较句': 262 | # pass 263 | # 判断句仍然需要标注。。。 264 | price_property = '档位介绍-价格' 265 | flow_property = '档位介绍-流量' 266 | service_property = '档位介绍-子业务' 267 | query_str = to_sparql(entity=entity, main_property=price_property, sub_properties=[]) 268 | price_ans = self.query(query_str) 269 | query_str = to_sparql(entity=entity, main_property=flow_property, sub_properties=[]) 270 | flow_ans = self.query(query_str) 271 | query_str = to_sparql(entity=entity, main_property=service_property, sub_properties=[]) 272 | service_ans = self.query(query_str) 273 | # 并列句不需要判断,认为约束一般都是没有的 274 | elif ans_type == '并列句': 275 | # pass 276 | # 并列句仍然需要标注 277 | price_property = '档位介绍-价格' 278 | flow_property = '档位介绍-流量' 279 | service_property = '档位介绍-子业务' 280 | query_str = to_sparql(entity=entity, main_property=price_property, sub_properties=[]) 281 | price_ans = self.query(query_str) 282 | query_str = to_sparql(entity=entity, main_property=flow_property, sub_properties=[]) 283 | flow_ans = self.query(query_str) 284 | query_str = to_sparql(entity=entity, main_property=service_property, sub_properties=[]) 285 | service_ans = self.query(query_str) 286 | # 属性句 287 | else: 288 | # 答案只有一个,只需要选择一个最小(大)的即可。分两阶段获取答案 289 | if operator == 'min': 290 | if len(sub_properties) == 0: 291 | # 直接来近似得到 292 | temp_ans = self.query(to_sparql(entity=entity, main_property=main_property[0])) 293 | temp_ans = [int(i) for i in temp_ans] 294 | ans = min(temp_ans) 295 | return [ans] 296 | else: 297 | key = sub_properties[0][0] 298 | query_str = to_sparql(entity=entity, main_property=main_property[0].split('-')[0]+'-'+key) 299 | temp_ans = self.query(query_str) 300 | temp_ans = [int(i) for i in temp_ans] 301 | first_step_ans = min(temp_ans) 302 | query_str = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[(key, first_step_ans)]) 303 | ans = self.query(query_str) 304 | elif operator == 'max': 305 | if len(sub_properties) == 0: 306 | query_str = to_sparql(entity=entity, main_property=main_property[0]) 307 | temp_ans = self.query(query_str) 308 | temp_ans = [int(i) for i in temp_ans] 309 | ans = max(temp_ans) 310 | return [ans] 311 | else: 312 | key = 
sub_properties[0][0] 313 | query_str = to_sparql(entity=entity, main_property=main_property[0].split('-')[0]+'-'+key) 314 | temp_ans = self.query(query_str) 315 | # 修复字符串'700'大于'10000'的bug 316 | temp_ans = [int(i) for i in temp_ans] 317 | first_step_ans = max(temp_ans) 318 | query_str = to_sparql(entity=entity, main_property=main_property[0], sub_properties=[(key, first_step_ans)]) 319 | ans = self.query(query_str) 320 | else: 321 | # if '-' in main_property[0]: 322 | price_property = '档位介绍-价格' 323 | flow_property = '档位介绍-流量' 324 | service_property = '档位介绍-子业务' 325 | query_str = to_sparql(entity=entity, main_property=price_property, sub_properties=[]) 326 | price_ans = self.query(query_str) 327 | query_str = to_sparql(entity=entity, main_property=flow_property, sub_properties=[]) 328 | flow_ans = self.query(query_str) 329 | query_str = to_sparql(entity=entity, main_property=service_property, sub_properties=[]) 330 | service_ans = self.query(query_str) 331 | price_ans = list(set(price_ans)) 332 | flow_ans = list(set(flow_ans)) 333 | service_ans = list(set(service_ans)) 334 | return (price_ans, flow_ans, service_ans) 335 | 336 | if __name__ == '__main__': 337 | kg = KnowledgeGraph('../data/process_data/triples.rdf') 338 | # q = 'select ?ans where { ?instance. ?instance ?ans}' 339 | # ans = kg.query(q) 340 | ans = kg.fetch_ans(**{'question': '你好!花季守护的软件,需要在孩子的手机里安装东西吗?', 'ans_type': '属性值', 'entity': '花季守护业务', 'main_property': ['档位介绍-使用方法'], 'operator': 'other', 'sub_properties': []}) 341 | print(ans) 342 | 343 | import pandas as pd 344 | df = pd.read_excel('../data/raw_data/train_denoised.xlsx') 345 | df.fillna('') 346 | id_list = [] 347 | ans_list = [] 348 | for iter, row in df.iterrows(): 349 | ans_true = list(set(row['答案'].split('|'))) 350 | question = row['用户问题'] 351 | ans_type = row['答案类型'] 352 | entity = row['实体'] 353 | main_property = row['属性名'].split('|') 354 | operator = row['约束算子'] 355 | if operator != 'min' and operator != 'max': 356 | operator == 'other' 357 | sub_properties = [] 358 | cons_names = str(row['约束属性名']).split('|') 359 | cons_values = str(row['约束属性值']).split('|') 360 | if cons_names == ['nan']: cons_names = [] 361 | for index in range(len(cons_names)): 362 | sub_properties.append([cons_names[index], cons_values[index]]) 363 | ans = kg.fetch_ans(question, ans_type, entity, main_property, operator, sub_properties) 364 | 365 | def is_same(ans, ans_true): 366 | for an in ans: 367 | if an in ans_true: 368 | ans_true.remove(an) 369 | else: 370 | return False 371 | if len(ans_true) != 0: 372 | return False 373 | return True 374 | 375 | 376 | if not is_same(ans, ans_true): 377 | id_list.append(iter) 378 | ans_list.append(ans) 379 | print(id_list) 380 | df_save = df.iloc[id_list, [0, 1, 2, 3, 4, 5, 6, 7]] 381 | df_save['预测'] = ans_list 382 | print(df_save) 383 | df_save.to_excel('/data/huangbo/project/Tianchi_nlp_git/data/raw_data/train_wrong_triple.xlsx', index=False) 384 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import * 2 | from rdflib import Graph, Namespace 3 | import rdflib 4 | import copy 5 | from typing import List, Dict, Tuple 6 | import pandas as pd 7 | import numpy as np 8 | import torch 9 | import os 10 | import random 11 | import json 12 | import time 13 | import logging 14 | from datetime import timedelta 15 | 16 | 17 | class TrainingConfig(object): 18 | """ 19 | BiLSTM模型的训练参数 20 | """ 21 | batch_size = 
64 22 | # 学习速率 23 | lr = 0.001 24 | epoches = 50 25 | print_step = 10 26 | 27 | 28 | class LSTMConfig(object): 29 | """ 30 | BiLSTM模型中LSTM模块的参数 31 | """ 32 | emb_size = 128 # 词向量的维数 33 | hidden_size = 128 # lstm隐向量的维数 34 | 35 | 36 | class BertTrainingConfig(object): 37 | """ 38 | BiLSTM模型的训练参数 39 | """ 40 | batch_size = 64 41 | # 学习速率 42 | lr = 0.00001 43 | other_lr = 0.0001 44 | epoches = 50 45 | print_step = 5 46 | 47 | 48 | class BERTLSTMConfig(object): 49 | """ 50 | BERT+BiLSTM模型中LSTM模块的参数 51 | """ 52 | emb_size = 768 # 词向量的维数 53 | hidden_size = 512 # lstm隐向量的维数 54 | 55 | 56 | def score_evalution(answers, predictions): 57 | """ 58 | 用avr_F1评价预测结果 59 | @param answers: {'0':'asd|sdf', '1':'qwe|wer' } 60 | @param predictions: {'0':'asd|sdf', '1':'qwe|wer' } 61 | @return: avr_F1 62 | """ 63 | avr_F1 = 0 64 | for index, answer in answers.items(): 65 | prediction = predictions[index] 66 | answer_list = answer.split('|') 67 | answer_set = set() 68 | for item in answer_list: 69 | answer_set.add(item) 70 | 71 | prediction_list = prediction.split('|') 72 | prediction_set = set() 73 | for item in prediction_list: 74 | prediction_set.add(item) 75 | 76 | intersection_set = answer_set.intersection(prediction_set) 77 | 78 | A = len(answer_set) 79 | G = len(prediction_set) 80 | if G==0 or len(intersection_set) == 0: 81 | avr_F1 += 0 82 | continue 83 | P = len(intersection_set)/(A * 1.0) 84 | R = len(intersection_set)/(G * 1.0) 85 | avr_F1 += (2 * P * R)/(P + R) 86 | avr_F1 /= len(answers) 87 | return avr_F1 88 | 89 | def cal_lstm_crf_loss(crf_scores, targets, tag2id): 90 | """计算双向LSTM-CRF模型的损失 91 | 该损失函数的计算可以参考:https://arxiv.org/pdf/1603.01360.pdf 92 | """ 93 | targets_copy = copy.deepcopy(targets) 94 | pad_id = tag2id.get('') 95 | start_id = tag2id.get('') 96 | end_id = tag2id.get('') 97 | 98 | device = crf_scores.device 99 | 100 | # targets_copy:[B, L] crf_scores:[B, L, T, T] 101 | batch_size, max_len = targets_copy.size() 102 | target_size = len(tag2id) 103 | 104 | # mask = 1 - ((targets_copy == pad_id) + (targets_copy == end_id)) # [B, L] 105 | mask = (targets_copy != pad_id) 106 | lengths = mask.sum(dim=1) 107 | targets_copy = indexed(targets_copy, target_size, start_id) 108 | 109 | # # 计算Golden scores方法1 110 | # import pdb 111 | # pdb.set_trace() 112 | targets_copy = targets_copy.masked_select(mask) # [real_L] 113 | 114 | flatten_scores = crf_scores.masked_select( 115 | mask.view(batch_size, max_len, 1, 1).expand_as(crf_scores) 116 | ).view(-1, target_size*target_size).contiguous() 117 | 118 | golden_scores = flatten_scores.gather( 119 | dim=1, index=targets_copy.unsqueeze(1)).sum() 120 | 121 | # 计算golden_scores方法2:利用pack_padded_sequence函数 122 | # targets_copy[targets_copy == end_id] = pad_id 123 | # scores_at_targets = torch.gather( 124 | # crf_scores.view(batch_size, max_len, -1), 2, targets_copy.unsqueeze(2)).squeeze(2) 125 | # scores_at_targets, _ = pack_padded_sequence( 126 | # scores_at_targets, lengths-1, batch_first=True 127 | # ) 128 | # golden_scores = scores_at_targets.sum() 129 | 130 | # 计算all path scores 131 | # scores_upto_t[i, j]表示第i个句子的第t个词被标注为j标记的所有t时刻事前的所有子路径的分数之和 132 | scores_upto_t = torch.zeros(batch_size, target_size).to(device) 133 | for t in range(max_len): 134 | # 当前时刻 有效的batch_size(因为有些序列比较短) 135 | batch_size_t = (lengths > t).sum().item() 136 | if t == 0: 137 | scores_upto_t[:batch_size_t] = crf_scores[:batch_size_t, 138 | t, start_id, :] 139 | else: 140 | # We add scores at current timestep to scores accumulated up to previous 141 | # timestep, and log-sum-exp 
Remember, the cur_tag of the previous 142 | # timestep is the prev_tag of this timestep 143 | # So, broadcast prev. timestep's cur_tag scores 144 | # along cur. timestep's cur_tag dimension 145 | scores_upto_t[:batch_size_t] = torch.logsumexp( 146 | crf_scores[:batch_size_t, t, :, :] + 147 | scores_upto_t[:batch_size_t].unsqueeze(2), 148 | dim=1 149 | ) 150 | all_path_scores = scores_upto_t[:, end_id].sum() 151 | 152 | # 训练大约两个epoch loss变成负数,从数学的角度上来说,loss = -logP 153 | loss = (all_path_scores - golden_scores) / batch_size 154 | return loss 155 | 156 | 157 | def tensorized(batch, maps): 158 | """ 159 | BiLSTM训练使用 160 | @param batch: 161 | @param maps: 162 | @return: 163 | """ 164 | PAD = maps.get('') 165 | UNK = maps.get('') 166 | 167 | max_len = len(batch[0]) 168 | batch_size = len(batch) 169 | 170 | batch_tensor = torch.ones(batch_size, max_len).long() * PAD 171 | for i, l in enumerate(batch): 172 | for j, e in enumerate(l): 173 | batch_tensor[i][j] = maps.get(e, UNK) 174 | # batch各个元素的长度 175 | lengths = [len(l) for l in batch] 176 | 177 | return batch_tensor, lengths 178 | 179 | 180 | def tensorized_bert(batch, maps): 181 | """ 182 | BiLSTM训练使用 183 | @param batch: 184 | @param maps: 185 | @return: 186 | """ 187 | PAD = maps.get('[PAD]') 188 | UNK = maps.get('[UNK]') 189 | 190 | max_len = len(batch[0]) 191 | batch_size = len(batch) 192 | 193 | batch_tensor = torch.ones(batch_size, max_len).long() * PAD 194 | for i, l in enumerate(batch): 195 | for j, e in enumerate(l): 196 | batch_tensor[i][j] = maps.get(e, UNK) 197 | # batch各个元素的长度 198 | lengths = [len(l) for l in batch] 199 | 200 | return batch_tensor, lengths 201 | 202 | 203 | def sort_by_lengths(word_lists, tag_lists): 204 | pairs = list(zip(word_lists, tag_lists)) 205 | indices = sorted(range(len(pairs)), 206 | key=lambda k: len(pairs[k][0]), 207 | reverse=True) 208 | pairs = [pairs[i] for i in indices] 209 | # pairs.sort(key=lambda pair: len(pair[0]), reverse=True) 210 | 211 | word_lists, tag_lists = list(zip(*pairs)) 212 | 213 | return word_lists, tag_lists, indices 214 | 215 | 216 | def indexed(targets, tagset_size, start_id): 217 | """将targets中的数转化为在[T*T]大小序列中的索引,T是标注的种类""" 218 | batch_size, max_len = targets.size() 219 | for col in range(max_len-1, 0, -1): 220 | targets[:, col] += (targets[:, col-1] * tagset_size) 221 | targets[:, 0] += (start_id * tagset_size) 222 | return targets 223 | 224 | 225 | def get_operator(question): 226 | """ 227 | 得到句子的约束算子,只针对min\max的算子,对算子为‘=’或者空的情况均返回None 228 | @param question: 用户问题(str) 229 | @return: 约束算子 和 约束属性 230 | """ 231 | operater, obj = None, None 232 | if '最多' in question: 233 | operater, obj = 'max', '流量' 234 | elif '最少' in question: 235 | operater, obj = 'min', '流量' 236 | elif '最便宜的' in question or '最实惠的' in question: 237 | operater, obj = 'min', '价格' 238 | elif '最贵' in question: 239 | operater, obj = 'max', '价格' 240 | return operater, obj 241 | 242 | 243 | def parse_triples_file(file_path): 244 | """ 245 | 解析triples.txt为triples.rdf 246 | """ 247 | triples = [] 248 | # read raw_triples 249 | with open(file_path, 'r', encoding='utf-8') as f: 250 | for line in f: 251 | line = line.strip() 252 | line = line.replace(' ', '') 253 | line_block = line.split(' ') 254 | t = [] 255 | flag = True 256 | for block in line_block: 257 | # 调整三元组格式与训练文件一致 258 | block = block.replace('档位介绍表', '档位介绍') 259 | # 三元组改成小写 260 | block = block.lower() 261 | if block.find('http://yunxiaomi.com/kbqa') != -1: 262 | _ = block.find('_') 263 | t.append('http://yunxiaomi.com/'+block[_+1: -1]) 264 | else: 265 | block 
= block[1:-1] 266 | if '|' in block: 267 | blocks = block.split('|') 268 | for b in blocks: 269 | flag = False 270 | copy_t = copy.deepcopy(t) 271 | copy_t.append('http://yunxiaomi.com/'+b) 272 | triples.append(copy_t) 273 | else: 274 | t.append('http://yunxiaomi.com/'+block) 275 | if flag: 276 | triples.append(t) 277 | # transfer to RDF file 278 | graph = Graph() 279 | for triple in triples: 280 | if triple is None: 281 | print('None') 282 | continue 283 | graph.add((rdflib.term.URIRef(u'{}'.format(triple[0])), 284 | rdflib.term.URIRef(u'{}'.format(triple[1])), 285 | rdflib.term.URIRef(u'{}'.format(triple[2])))) 286 | graph.serialize('../data/process_data/triples.rdf', format='n3') 287 | graph.close() 288 | 289 | 290 | def to_sparql(entity: str, main_property: str, sub_properties: List[Tuple] = None): 291 | """ 292 | 将实体信息等转化成sparql语句 293 | :param entity: str 294 | :param main_property: str 295 | :param sub_properties: List[Tuple[key, value], ...] 296 | :return: str 297 | """ 298 | if sub_properties is None: 299 | sub_properties = [] 300 | prefix = '<http://yunxiaomi.com/{}>' 301 | # relations第0个元素存放主要关系, 之后存放次要关系 302 | relations, conditions = [], '' 303 | # ”档位介绍-取消方式“ 类场景 304 | if main_property.find('-') != -1: 305 | relations.append(main_property.split('-')[0]) 306 | relations.append(main_property.split('-')[1]) 307 | # “生效规则” 类场景 308 | else: 309 | relations.append(main_property) 310 | # 填充关系到sparql 311 | for i, r in enumerate(relations): 312 | condition = '' 313 | if i == 0: 314 | condition = prefix.format(entity) + ' ' + prefix.format(relations[i]) 315 | if len(relations) > 1: 316 | condition += ' ?instance' 317 | else: 318 | condition += ' ?ans' 319 | elif i == 1: 320 | condition += '?instance ' + prefix.format(relations[i]) + ' ?ans' 321 | # else: 322 | # condition += '?instance ' + prefix.format(relations[i]) + ' ' + prefix.format(relations[i]) 323 | if len(relations)-1 != i or (len(relations) > 1 and len(sub_properties) > 0): 324 | condition += '. ' 325 | conditions += condition 326 | idx = 0 327 | if len(relations) > 1: 328 | for key, value in sub_properties: 329 | condition = '?instance ' + prefix.format(key) + ' ' + prefix.format(value) 330 | if len(sub_properties) - 1 > idx: 331 | condition += '. 
' 332 | conditions += condition 333 | idx += 1 334 | s = """select ?ans where {%s}""" % (conditions, ) 335 | return s 336 | 337 | 338 | def make_dataset(data_path, target_file, label_file, train=True): 339 | if isinstance(data_path, list): 340 | df = pd.DataFrame(pd.read_excel(data_path[0])) 341 | for i in range(1, len(data_path)): 342 | df = pd.concat([df, pd.DataFrame(pd.read_excel(data_path[i]))], ignore_index=True) 343 | else: 344 | df = pd.DataFrame(pd.read_excel(data_path)) 345 | # # entity_map 346 | # with open('../data/file/entity_map.json', 'r', encoding='utf-8') as f: 347 | # entity_mapping = json.load(f) 348 | # 标签和id的映射 349 | ans_label2id = {} 350 | prop_label2id = {} 351 | entity_label2id = {} 352 | # 开通方式、条件 353 | binary_label2id = {'档位介绍-开通方式': 0, '档位介绍-开通条件': 1} 354 | # 标注数据xlxs转化成txt 355 | if train: 356 | df = df.loc[:, ['用户问题', '答案类型', '属性名', '实体', '答案', '有效率']] 357 | # 未标注数据xlxs转化成txt 358 | else: 359 | length = [] 360 | columns = df.columns 361 | for column in columns: 362 | length.append(len(str(df.iloc[1].at[column]))) 363 | max_id = np.argmax(length) 364 | # df = df.loc[:, ['query']] 365 | df = df.loc[:, [columns[max_id]]] 366 | # 转换数据 367 | with open(target_file, 'w', encoding='utf-8') as f: 368 | for i in range(len(df)): 369 | line = df.loc[i] 370 | line = list(line) 371 | for idx in range(len(line)): 372 | # 大写字母改成小写 373 | line[idx] = str(line[idx]).strip().lower() 374 | if train: 375 | # 答案类型 376 | if line[1] not in ans_label2id: 377 | ans_label2id[line[1]] = len(ans_label2id) 378 | # 属性名 379 | sub_blocks = line[2].split('|') 380 | for j, sub_b in enumerate(sub_blocks): 381 | # 把这几类统一变成”其他“类 382 | # if sub_b in ('适用app', '生效规则', '叠加规则', '封顶规则'): 383 | # sub_b = '其他' 384 | # sub_blocks[j] = sub_b 385 | if sub_b not in prop_label2id: 386 | prop_label2id[sub_b] = len(prop_label2id) 387 | # line[2] = '|'.join(list(set(sub_blocks))) 388 | # # 实体 389 | # if line[3] in entity_mapping: 390 | # line[3] = entity_mapping[line[3]] 391 | if line[3] not in entity_label2id: 392 | entity_label2id[line[3]] = len(entity_label2id) 393 | f.write('\t'.join(line)+'\n') 394 | # 整理标签和id的映射 395 | if train: 396 | if label_file is not None: 397 | with open(label_file, 'w', encoding='utf-8') as f: 398 | json.dump({'ans_type': ans_label2id, 'main_property': prop_label2id, 'entity': entity_label2id, 399 | 'binary_type': binary_label2id}, 400 | fp=f, 401 | ensure_ascii=False, indent=2) 402 | 403 | 404 | def make_dataset_for_binary(data_path, target_file): 405 | """ 406 | 为二分类任务制作数据集 407 | :param data_path: 408 | :param target_file: 409 | :return: 410 | """ 411 | if isinstance(data_path, list): 412 | df = pd.DataFrame(pd.read_excel(data_path[0])) 413 | for i in range(1, len(data_path)): 414 | df = pd.concat([df, pd.DataFrame(pd.read_excel(data_path[i]))], ignore_index=True) 415 | else: 416 | df = pd.DataFrame(pd.read_excel(data_path)) 417 | # 标注数据xlxs转化成txt 418 | df = df.loc[:, ['用户问题', '答案类型', '属性名', '实体', '答案', '有效率']] 419 | method_count, condition_text = 0, [] 420 | # 转换数据 421 | with open(target_file, 'w', encoding='utf-8') as f: 422 | for i in range(len(df)): 423 | line = df.loc[i] 424 | if str(line['用户问题']) in ('20元20g还可以办理吗', ): 425 | continue 426 | props = str(line['属性名']) 427 | line = [str(line['用户问题']), '', str(line['有效率'])] 428 | if '开通方式' in props and '开通条件' not in props: 429 | line[1] = '档位介绍-开通方式' 430 | f.write('\t'.join(line)+'\n') 431 | method_count += 1 432 | elif '开通方式' not in props and '开通条件' in props: 433 | line[1] = '档位介绍-开通条件' 434 | f.write('\t'.join(line)+'\n') 435 | 
condition_text.append(line) 436 | # 中和结果,重采样 437 | if method_count > len(condition_text): 438 | for i in range(method_count - len(condition_text)): 439 | line = random.choice(condition_text) 440 | f.write('\t'.join(line)+'\n') 441 | 442 | 443 | def label_to_multi_hot(max_len, label_ids): 444 | """ 445 | 转化成multi-hot编码 446 | :param max_len: multi-hot编码的最大长度 447 | :param label_ids: 448 | :return: 449 | """ 450 | multi_hot = [0] * max_len 451 | for idx in label_ids: 452 | multi_hot[idx] = 1 453 | return multi_hot 454 | 455 | 456 | def logits_to_multi_hot(data: torch.Tensor, ans_pred: torch.Tensor, label_hub, threshold=0.5): 457 | """ 458 | logits转化成multi-hot编码 459 | 1 运用规则: 当并列句中,属性名肯定有两个,而其他句子中属性名只有一个 460 | :return: 461 | """ 462 | if data.is_cuda: 463 | data = data.detach().data.cpu().numpy() 464 | else: 465 | data = data.detach().numpy() 466 | result = [] 467 | for i in range(data.shape[0]): 468 | temp = [0] * data.shape[1] 469 | if label_hub.ans_id2label[ans_pred[i].item()] == '并列句': 470 | # 获取分数最高的两个属性 471 | max_idx = np.argmax(data[i], axis=-1) 472 | data[i][max_idx] = -1e5 473 | temp[max_idx] = 1 474 | max_idx = np.argmax(data[i], axis=-1) 475 | temp[max_idx] = 1 476 | else: 477 | # 获取分数最高的1个属性 478 | max_idx = np.argmax(data[i], axis=-1) 479 | temp[max_idx] = 1 480 | result.append(temp) 481 | return np.array(result) 482 | 483 | 484 | def logits_to_multi_hot_old_version(data: torch.Tensor, threshold=0.5): 485 | """ 486 | logits转化成multi-hot编码【没有考虑答案类型的旧版本】 487 | :return: 488 | """ 489 | if data.is_cuda: 490 | data = data.detach().data.cpu().numpy() 491 | else: 492 | data = data.numpy() 493 | result = [] 494 | for i in range(data.shape[0]): 495 | result.append([1 if v >= threshold else 0 for v in list(data[i])]) 496 | return np.array(result) 497 | 498 | 499 | def remove_stop_words(sentence: str, stopwords: List): 500 | """ 501 | 移除停用词 502 | :param sentence: 503 | :param stopwords: 504 | :return: 505 | """ 506 | for word in stopwords: 507 | sentence = sentence.replace(word, '') 508 | return sentence 509 | 510 | 511 | def load_model(model_path, map_location=None): 512 | print('模型加载中...') 513 | model = torch.load(model_path, map_location=map_location) 514 | return model 515 | 516 | 517 | def save_model(model, model_path, model_name, debug='1'): 518 | # debug模式下,需要删除之前同名的模型 519 | if debug == '1': 520 | for file in os.listdir(model_path): 521 | # 删除同一个模型 522 | if model_name[:-20] in file: 523 | os.remove(os.path.join(model_path, file)) 524 | print('模型保存中...') 525 | torch.save(model, os.path.join(model_path, model_name)) 526 | 527 | 528 | def setup_seed(seed): 529 | """ 530 | 确定随机数 531 | :param seed: 种子 532 | :return: 533 | """ 534 | torch.manual_seed(seed) 535 | torch.cuda.manual_seed_all(seed) 536 | random.seed(seed) 537 | np.random.seed(seed) 538 | os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化,使得实验可复现。 539 | 540 | 541 | def split_train_dev(data_size: int, radio=0.8): 542 | """ 543 | 从标注数据中随机划分训练与验证集,返回标注数据中的行号列表 544 | :param data_size: 545 | :param radio: 546 | :return: [[id1, id2, id3...], [idx,... 
]] 547 | """ 548 | line_ids = [i for i in range(data_size)] 549 | random.shuffle(line_ids) 550 | train_ids = line_ids[:int(data_size*radio)] 551 | dev_ids = line_ids[int(data_size*radio):] 552 | with open('../data/file/train_dev_ids.json', 'w', encoding='utf-8') as f: 553 | json.dump({'train_ids': train_ids, 'dev_ids': dev_ids}, f, ensure_ascii=False, indent=2) 554 | 555 | 556 | def split_labeled_data(source_file, train_file, dev_file): 557 | """ 558 | 根据训练与验证的ID,写入文件 559 | :param source_file: 560 | :param train_file: 561 | :param dev_file: 562 | :return: 563 | """ 564 | train_dev_ids = json.load(open('../data/file/train_dev_ids.json', 'r', encoding='utf-8')) 565 | train_dev_ids['dev_ids'] = set(train_dev_ids['dev_ids']) 566 | train_data, dev_data = [], [] 567 | with open(source_file, 'r', encoding='utf-8') as f: 568 | for i, line in enumerate(f): 569 | if i in train_dev_ids['dev_ids']: 570 | dev_data.append(line.strip()) 571 | else: 572 | train_data.append(line.strip()) 573 | with open(train_file, 'w', encoding='utf-8') as f: 574 | f.write('\n'.join(train_data)) 575 | with open(dev_file, 'w', encoding='utf-8') as f: 576 | f.write('\n'.join(dev_data)) 577 | 578 | 579 | def get_time_str(): 580 | """ 581 | 返回当前时间戳字符串 582 | :return: 583 | """ 584 | return time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) 585 | 586 | 587 | def get_time_dif(start_time): 588 | """ 589 | 获取已经使用的时间 590 | :param start_time: 591 | :return: 592 | """ 593 | end_time = time.time() 594 | time_dif = end_time - start_time 595 | return timedelta(seconds=int(round(time_dif))) 596 | 597 | 598 | def get_labels(ans_pred: torch.Tensor, prop_pred: np.ndarray, entity_pred: torch.Tensor, label_hub): 599 | """ 600 | 根据标签ID返回标签名称 601 | :param ans_pred: 602 | :param prop_pred: 603 | :param entity_pred: 604 | :param label_hub: 605 | :return: 606 | """ 607 | ans_labels, prop_labels, entity_labels = [], [], [] 608 | for ans_id in ans_pred: 609 | ans_labels.append(label_hub.ans_id2label[ans_id.item()]) 610 | for entity_id in entity_pred: 611 | entity_labels.append(label_hub.entity_id2label[entity_id.item()]) 612 | for line in prop_pred: 613 | temp = [] 614 | for i in range(len(line)): 615 | if line[i] == 1: 616 | temp.append(label_hub.prop_id2label[i]) 617 | prop_labels.append(temp) 618 | return ans_labels, prop_labels, entity_labels 619 | 620 | 621 | def fetch_error_cases(pred_data: Dict, gold_data: Dict, question: Dict): 622 | """ 623 | 获取预测错误的样本 624 | :param pred_data: 625 | :param gold_data: 626 | :return: 627 | """ 628 | time_str = get_time_str() 629 | time_str = time_str.replace(':', ':') 630 | count = 1 631 | with open('../data/results/error_case_{}.txt'.format(time_str), 'w', encoding='utf-8') as f: 632 | for idx, res in pred_data['result'].items(): 633 | ans = gold_data[idx] 634 | ans = set(ans.split('|')) 635 | if '' in ans: 636 | ans.remove('') 637 | res = set(res.split('|')) 638 | if ans != res: 639 | f.write('第{}个错误结果\n'.format(count)) 640 | f.write('问题编号为: {}, 问题为:{}\n'.format(idx, question[idx])) 641 | f.write('正确答案:{}\n'.format(ans)) 642 | f.write('预测答案:{}\n'.format(res)) 643 | f.write('预测中间结果:{}\n'.format(pred_data['model_result'][idx])) 644 | f.write('【问题分析】:【】\n') 645 | f.write('\n') 646 | count += 1 647 | 648 | 649 | def k_fold_data(data: List[Dict], k=5, batch_size=32, seed=1, collate_fn='1'): 650 | """ 651 | data即位DataSet类中的data。将data分成k份,然后组成训练验证集 652 | :param data: 653 | :param k: 654 | :param batch_size: 655 | :param seed: 随机数种子 656 | :collate_fn: 657 | :return: 658 | """ 659 | print('K折数据划分中...') 
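# NOTE: the k-fold construction below works as follows — a deep copy of `data` is shuffled once
# with the given seed, then cut into k contiguous blocks of len(data)//k samples each (the
# remainder of up to k-1 samples is dropped by the integer division); fold i uses block i as the
# dev split and the concatenation of the other k-1 blocks as the train split. Each split is
# wrapped in an empty BertDataset shell (init=False) and a DataLoader whose collate_fn is
# bert_collate_fn when collate_fn == '1', otherwise binary_collate_fn. For example, with 1000
# samples and k=5 this yields five (800-train / 200-dev) DataLoader pairs.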
660 | from data import BertDataset, bert_collate_fn, binary_collate_fn 661 | from torch.utils.data import DataLoader 662 | temp_data = copy.deepcopy(data) 663 | random.seed(seed) 664 | random.shuffle(temp_data) 665 | block_len = len(temp_data)//k 666 | data_blocks = [temp_data[i*block_len: (i+1)*block_len] for i in range(k)] 667 | train_dev_tuples = [] 668 | for i, block in enumerate(data_blocks): 669 | train, dev = [], block 670 | for _ in range(k): 671 | if _ != i: 672 | train += data_blocks[_] 673 | train_set = BertDataset(train_file=None, tokenizer=None, label_hub=None, init=False) 674 | dev_set = BertDataset(train_file=None, tokenizer=None, label_hub=None, init=False) 675 | train_set.data = train 676 | dev_set.data = dev 677 | if collate_fn == '1': 678 | train_dev_tuples.append((DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=bert_collate_fn), 679 | DataLoader(dev_set, batch_size=batch_size, shuffle=True, collate_fn=bert_collate_fn))) 680 | else: 681 | train_dev_tuples.append( 682 | (DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=binary_collate_fn), 683 | DataLoader(dev_set, batch_size=batch_size, shuffle=True, collate_fn=binary_collate_fn))) 684 | return train_dev_tuples 685 | 686 | 687 | def filter_sub_properties(sub_properties: List[Tuple], entity_label: str, ans_label: str): 688 | """ 689 | 过滤 sub_properties 690 | 去掉 实体中流量或价钱被预测成约束的情况 或 去掉重复的约束属性名 691 | :param sub_properties: 692 | :param entity_label: 693 | :param ans_label: 694 | :return: 695 | """ 696 | sub_properties_filter = [] 697 | sub_property_set = set() 698 | for sub_property in sub_properties: 699 | if sub_property[0] not in sub_property_set: 700 | sub_property_set.add(sub_property[0]) 701 | # 重复的约束属性名而且答案类型是属性值 702 | elif ans_label == '属性值': 703 | continue 704 | if sub_property[0] == '子业务': 705 | sub_properties_filter.append(sub_property) 706 | elif sub_property[1] in entity_label: 707 | if entity_label == '1元5gb流量券': # 该实体中 1元 是约束条件 708 | sub_properties_filter.append(sub_property) 709 | else: 710 | continue 711 | else: 712 | sub_properties_filter.append(sub_property) 713 | return sub_properties_filter 714 | 715 | 716 | def reverse_prop(prop_label: List[str]): 717 | """ 718 | 置换标签,在查找答案为空的时候使用 719 | :param prop_label: 720 | :return: 721 | """ 722 | changed = False 723 | for i, prop in enumerate(prop_label): 724 | # 假设以下条件独立 725 | if '方式' in prop: 726 | prop_label[i] = prop.replace('方式', '条件') 727 | changed = True 728 | elif '条件' in prop: 729 | prop_label[i] = prop.replace('条件', '方式') 730 | changed = True 731 | elif '档位介绍-有效期规则' == prop: 732 | prop_label[i] = '生效规则' 733 | changed = True 734 | elif '生效规则' == prop: 735 | prop_label[i] = '档位介绍-有效期规则' 736 | changed = True 737 | return changed 738 | 739 | 740 | def reverse_entity(entity_label: str): 741 | """ 742 | 置换属性名,在查找答案为空的时候使用 743 | :param entity_label: 744 | :return: 745 | """ 746 | if entity_label == '嗨购月包': 747 | return True, '嗨购产品' 748 | elif entity_label == '嗨购产品': 749 | return True, '嗨购月包' 750 | if entity_label == '北京移动plus会员权益卡': 751 | return True, '北京移动plus会员' 752 | elif entity_label == '北京移动plus会员': 753 | return True, '北京移动plus会员权益卡' 754 | else: 755 | return False, '' 756 | 757 | 758 | def rm_symbol(sentence): 759 | import re 760 | return re.sub(',|\.|,|。|?|\?|!|!|:', '', sentence) 761 | 762 | 763 | def judge_cancel(question): 764 | """ 765 | 判断“取消”属性名 766 | :param question: 767 | :return: 768 | """ 769 | import re 770 | neg_word, pos_word = ['不要', '不想', '取消', '不需要', '不用'], ['办理', '需要', '开通'] 771 | sentence_blocks 
= re.split(',|\.|:|,|。', question) 772 | for block in sentence_blocks: 773 | for n in neg_word: 774 | if n in block: 775 | for p in pos_word: 776 | if p in block: 777 | return True 778 | return False 779 | 780 | 781 | def rules_to_judge_prop(question): 782 | """ 783 | 规则判断属性名:适用app、叠加规则、封顶规则、档位介绍-带宽、档位介绍-有效期规则、生效规则 784 | :param question: 785 | :return: 786 | """ 787 | # 适用app:当前只需要判断是否包含“app”和“适用”关键字 788 | if 'app' in question and '适用' in question: 789 | return '适用app' 790 | # 生效规则:判断包含“生效”并带疑问关键字 791 | if '生效' in question and ('么' in question or '吗' in question or '啥' in question): 792 | return '生效规则' 793 | # 叠加规则:包含“可以叠加”关键字。注意与“叠加包”区别 794 | if '可以叠加' in question and '叠加包' not in question: 795 | return '叠加规则' 796 | # 封顶规则:包含“限速”或“上限”关键字,不能包含“解除”、”恢复“关键字 797 | if ('限速' in question or '上限' in question) and ('解除' not in question and '恢复' not in question): 798 | return '封顶规则' 799 | # 有效期规则 800 | if ('到期' in question or '有效期' in question) and ('办理' not in question and '取消' not in question and '关闭' not in question): 801 | return '档位介绍-有效期规则' 802 | # 带宽 803 | if '通带宽' in question or '网速' in question: 804 | return '档位介绍-带宽' 805 | return None 806 | 807 | 808 | def rules_to_judge_entity(question, predict_entity): 809 | """ 810 | 规则来判断实体 811 | :param question: 812 | :param predict_entity 813 | :return: 814 | """ 815 | if predict_entity == '畅享套餐': 816 | if '升档' in question: 817 | return '新畅享套餐升档优惠活动' 818 | elif '促销' in question or '78元无限流量套餐' in question: 819 | return '畅享套餐促销优惠活动' 820 | elif '新全球通' in question: 821 | return '新全球通畅享套餐' 822 | elif '首充活动' in question: 823 | return '畅享套餐首充活动' 824 | elif predict_entity == '移动王卡': 825 | if '惠享合约' in question: 826 | return '移动王卡惠享合约' 827 | elif predict_entity == '5g畅享套餐': 828 | if '合约版' in question: 829 | return '5g畅享套餐合约版' 830 | elif predict_entity == '30元5gb包': 831 | if '半价' in question: 832 | return '30元5gb半价体验版' 833 | elif predict_entity == '移动花卡': 834 | if '新春升级' in question: 835 | return '移动花卡新春升级版' 836 | elif predict_entity == '随心看会员': 837 | if '合约版' in question: 838 | return '随心看会员合约版' 839 | elif predict_entity == '北京移动plus会员': 840 | if '权益卡' in question: 841 | return '北京移动plus会员权益卡' 842 | elif predict_entity == '5g智享套餐': 843 | if '合约版' in question: 844 | return '5g智享套餐合约版' 845 | elif '家庭版' in question: 846 | return '5g智享套餐家庭版' 847 | elif predict_entity == '全国亲情网': 848 | if '亲情网免费' in question: 849 | return '全国亲情网功能费优惠活动' 850 | elif predict_entity == '精灵卡': 851 | if '首充' in question and '优惠' in question: 852 | return '精灵卡首充优惠活动' 853 | # 合并的样本外的样本 854 | else: 855 | if '无忧' in question: 856 | return '流量无忧包' 857 | return None 858 | 859 | 860 | class MyLogger(object): 861 | def __init__(self, log_file, debug='0'): 862 | self.ch = logging.StreamHandler() 863 | self.formatter = logging.Formatter("%(asctime)s - %(message)s") 864 | self.fh = logging.FileHandler(log_file, mode='w') 865 | self.logger = logging.getLogger() 866 | self.debug = debug 867 | self.init() 868 | 869 | def init(self): 870 | self.logger.setLevel(logging.INFO) 871 | # 输出到文件 872 | self.fh.setLevel(logging.INFO) 873 | self.fh.setFormatter(self.formatter) 874 | # 输出到控制台 875 | self.ch.setLevel(logging.INFO) 876 | self.ch.setFormatter(self.formatter) 877 | self.logger.addHandler(self.ch) 878 | self.logger.addHandler(self.fh) 879 | 880 | def log(self, message): 881 | if self.debug == '1': 882 | self.logger.info(message) 883 | 884 | 885 | def generate_false_label(predict_file): 886 | """ 887 | 根据预测的结果反向生成同名伪标签xlsx 888 | :param predict_file 889 | :return 890 | """ 891 | result_path = 
def generate_false_label(predict_file):
    """
    Generate a pseudo-label xlsx (same base name) back from the prediction results.
    :param predict_file:
    :return:
    """
    result_path = '../data/results/'
    with open(result_path+predict_file, 'r', encoding='utf-8') as f:
        predict_dict = json.load(f)
    predict_results = []
    for i, predict_entity in predict_dict['model_result'].items():
        instance = []
        for key, value in predict_entity.items():
            if key == 'main_property':
                instance.append('|'.join(value))
            elif key == 'sub_properties':
                cons_prop, cons_value = [], []
                for pair in value:
                    cons_prop.append(pair[0])
                    cons_value.append(str(pair[1]))
                instance.append('|'.join(cons_prop))
                instance.append('|'.join(cons_value))
            else:
                if key == 'operator' and value == 'other':
                    instance.append('')
                else:
                    instance.append(value)
        instance.append(predict_dict['result'][i])
        predict_results.append(instance)
    header = ['用户问题', '答案类型', '实体', '属性名', '约束算子', '约束属性名', '约束属性值', '答案']
    pd.DataFrame(predict_results, index=None, columns=header).to_excel(result_path+predict_file[:-5]+'.xlsx', index=False)
    print('预测结果写入xlsx成功')


def view_gpu_info(gpu_id: int):
    import pynvml
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
    mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    # pynvml reports bytes, so divide by 1024**2 to get MB
    print('当前为第{}号显卡,总显存数:{}MB;显存使用数:{}MB;显存剩余数:{}MB'.format(gpu_id, mem_info.total/1024**2, mem_info.used/1024**2, mem_info.free/1024**2))
    return mem_info.total/1024**2, mem_info.used/1024**2, mem_info.free/1024**2


def max_count(data: dict, k=1):
    """
    Given a "label" -> "count" dict, return the label with the highest count.
    :param data:
    :param k: how many of the top labels to return (1 or 2)
    :return:
    """
    max_count_, la = 0, None
    second_count_, la2 = 0, None
    for lab, c in data.items():
        if c > max_count_:
            max_count_ = c
            la = lab
    # a runner-up only exists when there are at least two labels; otherwise the empty result is returned
    if k == 2 and len(data) > 1:
        for lab, c in data.items():
            if c > second_count_ and lab != la:
                second_count_ = c
                la2 = lab
    if k == 1:
        return la, max_count_
    else:
        return la, max_count_, la2, second_count_

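
# --- Illustrative usage sketch (editorial addition; the label counts are hypothetical) ---
# max_count picks the most frequent label from a "label -> count" dict; with k=2 it also
# returns the runner-up, which vote_integration below uses to decide whether a question
# predicted as '并列句' really carries two main properties.
def _demo_max_count():
    counts = {'档位介绍-价格': 4, '生效规则': 2, '取消方式': 1}
    top_label, top_count = max_count(counts)            # ('档位介绍-价格', 4)
    la, c1, la2, c2 = max_count(counts, k=2)            # runner-up: la2 == '生效规则', c2 == 2
    return top_label, top_count, la, c1, la2, c2
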
def vote_integration(ans_logits: List[torch.Tensor],
                     prop_logits: List[torch.Tensor],
                     entity_logits: List[torch.Tensor],
                     label_hub,
                     batch_size: int):
    """
    Ensemble the model predictions by voting.
    :param ans_logits:
    :param prop_logits:
    :param entity_logits:
    :param label_hub:
    :param batch_size:
    :return: the final labels
    """
    vote_method = 2
    mid_results = {'ans': [], 'prop': [], 'entity': []}
    for i in range(batch_size):
        mid_results['ans'].append({})
        mid_results['prop'].append({})
        mid_results['entity'].append({})
    ans_results, prop_results, entity_results = [], [], []
    for i in range(len(ans_logits)):
        ans_pred = ans_logits[i].data.cpu().argmax(dim=-1)
        prop_pred = logits_to_multi_hot(prop_logits[i], ans_pred, label_hub)
        entity_pred = entity_logits[i].data.cpu().argmax(dim=-1)
        ans_labels, prop_labels, entity_labels = get_labels(ans_pred, prop_pred, entity_pred, label_hub)
        for j in range(batch_size):
            if ans_labels[j] not in mid_results['ans'][j]:
                mid_results['ans'][j][ans_labels[j]] = 0
            mid_results['ans'][j][ans_labels[j]] += 1
            # property vote 1: the key is the model's full predicted label set
            if vote_method == 1:
                prop_label = tuple(sorted(prop_labels[j]))
                if prop_label not in mid_results['prop'][j]:
                    mid_results['prop'][j][prop_label] = 0
                mid_results['prop'][j][prop_label] += 1
            # property vote 2: the key is each individual predicted label
            else:
                prop_label = prop_labels[j]
                for lab in prop_label:
                    if lab not in mid_results['prop'][j]:
                        mid_results['prop'][j][lab] = 0
                    mid_results['prop'][j][lab] += 1
            if entity_labels[j] not in mid_results['entity'][j]:
                mid_results['entity'][j][entity_labels[j]] = 0
            mid_results['entity'][j][entity_labels[j]] += 1
    for instance in mid_results['ans']:
        ans_results.append(max_count(instance)[0])
    if vote_method == 1:
        for i, instance in enumerate(mid_results['prop']):
            prop_results.append(max_count(instance)[0])
    else:
        for i, instance in enumerate(mid_results['prop']):
            ans = ans_results[i]
            la, max_count_, la2, second_count_ = max_count(instance, k=2)
            if ans == '并列句':
                if second_count_ <= len(ans_logits)//3:
                    print('-------------------出现了不应该发生的情况-------------------')
                    ans_results[i] = '属性值'
                    prop_results.append([la])
                else:
                    prop_results.append([la, la2])
            else:
                prop_results.append([la])
                if second_count_ > len(ans_logits)//2:
                    prop_results[-1].append(la2)
    for instance in mid_results['entity']:
        entity_results.append(max_count(instance)[0])
    return ans_results, prop_results, entity_results


def average_integration(ans_logits: List[torch.Tensor],
                        prop_logits: List[torch.Tensor],
                        entity_logits: List[torch.Tensor],
                        label_hub):
    """
    Ensemble the model predictions by averaging their logits.
    :param ans_logits:
    :param prop_logits:
    :param entity_logits:
    :param label_hub:
    :return: the final labels
    """
    local_ans_logits, local_prop_logits, local_entity_logits = torch.stack(ans_logits).mean(dim=0), torch.stack(prop_logits).mean(dim=0), torch.stack(entity_logits).mean(dim=0)
    ans_pred = local_ans_logits.data.cpu().argmax(dim=-1)
    prop_pred = logits_to_multi_hot(local_prop_logits, ans_pred, label_hub)
    entity_pred = local_entity_logits.data.cpu().argmax(dim=-1)
    ans_labels, prop_labels, entity_labels = get_labels(ans_pred, prop_pred, entity_pred, label_hub)
    return ans_labels, prop_labels, entity_labels

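
# Note on the two ensembling strategies above (editorial summary, not original code):
# - vote_integration counts every model's discrete predictions and keeps the majority label;
#   for the property head it may keep a runner-up label when the answer type is '并列句'
#   (a parallel question asking about two properties at once).
# - average_integration instead averages the raw logits of all models before a single
#   argmax / multi-hot decoding step, so every model must share the same label space (label_hub).
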
def binary_average_integration(logits: List[torch.Tensor], label_hub):
    """
    Logit averaging for the binary-classification models.
    :param logits:
    :param label_hub:
    :return:
    """
    local_logits = torch.stack(logits).mean(dim=0)
    pred = local_logits.data.cpu().argmax(dim=-1)
    labels = []
    for label_id in pred:
        labels.append(label_hub.binary_id2label[label_id.item()])
    return labels


def tqdm_with_debug(data, debug=None):
    """
    Only wrap the data with a tqdm progress bar in debug mode.
    :param data:
    :param debug:
    :return:
    """
    from tqdm import tqdm
    if debug == '1' or debug is True:
        if isinstance(data, enumerate):
            temp_data = list(data)
        else:
            temp_data = data
        len_data = len(temp_data)
        ans = tqdm(temp_data, total=len_data)
        return ans
    else:
        return data


# def entity_map():
#     a = {'新畅享套餐升档优惠活动': '畅享套餐', '畅享套餐促销优惠活动': '畅享套餐', '畅享套餐促销优惠': '畅享套餐', '新全球通畅享套餐': '畅享套餐', '畅享套餐首充活动': '畅享套餐',
#          '移动王卡惠享合约': '移动王卡', '5g畅享套餐合约版': '5g畅享套餐', '30元5gb半价体验版': '30元5gb包', '移动花卡新春升级版': '移动花卡', '随心看会员合约版': '随心看会员', '北京移动plus会员权益卡': '北京移动plus会员',
#          '5g智享套餐合约版': '5g智享套餐', '5g智享套餐家庭版': '5g智享套餐',
#          '全国亲情网功能费优惠活动': '全国亲情网', '精灵卡首充优惠活动': '精灵卡'}
#     with open('../data/file/entity_map.json', 'w', encoding='utf-8') as f:
#         json.dump(a, f, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    # parse_triples_file('../data/raw_data/triples.txt')
    # print(to_sparql(entity='视频会员通用流量月包', main_property='档位介绍-上线时间', sub_properties={'价格': '70', '子业务': '优酷会员'}))
    # make_dataset(['../data/raw_data/train_augment_few_nlpcda.xlsx',
    #               '../data/raw_data/train_augment_simbert.xlsx',
    #               '../data/raw_data/train_augment_synonyms.xlsx'],
    #              target_file='../data/dataset/augment3.txt',
    #              label_file=None,
    #              train=True)
    make_dataset_for_binary(['../data/raw_data/train_denoised.xlsx'],
                            target_file='../data/dataset/binary_labeled.txt')
    make_dataset_for_binary(['../data/raw_data/train_augment_few_nlpcda.xlsx',
                             '../data/raw_data/train_augment_simbert.xlsx',
                             '../data/raw_data/train_augment_synonyms.xlsx'],
                            target_file='../data/dataset/binary_augment3.txt')
    # make_dataset(['../data/raw_data/test2_denoised.xlsx'],
    #              target_file='../data/dataset/cls_unlabeled2.txt',
    #              label_file=None,
    #              train=False)
    # split_train_dev(5000)
    # split_labeled_data(source_file='../data/dataset/cls_labeled.txt',
    #                    train_file='../data/dataset/cls_train.txt',
    #                    dev_file='../data/dataset/cls_dev.txt')
    # ------------------ fetch details of the wrongly answered dev cases ---------------
    # with open('../data/results/ans_dev.json', 'r', encoding='utf-8') as f:
    #     pred = json.load(f)
    # with open('../data/dataset/cls_dev.txt', 'r', encoding='utf-8') as f:
    #     answer, question = {}, {}
    #     for i, line in enumerate(f):
    #         ans = line.strip().split('\t')[-1]
    #         answer[str(i)] = ans
    #         question[str(i)] = line.strip().split('\t')[0]
    # fetch_error_cases(pred, answer, question)
    # generate_false_label('ensemble_bert_aug2_use_efficiency_2021-07-20-08-29-22_seed1_fi0_gpu2080_be81_0.95411990.pth.json')
    # make_dataset(['../data/raw_data/test2_denoised.xlsx'],
    #              target_file='../data/dataset/cls_unlabeled.txt',
    #              label_file=None,
    #              train=False)
    # entity_map()
--------------------------------------------------------------------------------