├── docs ├── rule.xlsx └── Task1_Submission_Guide.pdf ├── dataSet ├── 往年数据.zip ├── dic │ ├── province.txt │ ├── city.txt │ └── railway_station.txt └── process.py ├── run.sh ├── sample ├── random_out.py ├── extract_features.py ├── sample.json └── evaluation.py ├── .gitignore ├── README.md ├── convert_tf_checkpoint_to_pytorch.py ├── domain_rule.py ├── optimization.py ├── tokenization.py ├── run_classifier_dataset_utils.py ├── rule.py ├── run_classifier.py └── modeling.py /docs/rule.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygwpz/SMP2019-ECDT-NLU/HEAD/docs/rule.xlsx -------------------------------------------------------------------------------- /dataSet/往年数据.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygwpz/SMP2019-ECDT-NLU/HEAD/dataSet/往年数据.zip -------------------------------------------------------------------------------- /docs/Task1_Submission_Guide.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygwpz/SMP2019-ECDT-NLU/HEAD/docs/Task1_Submission_Guide.pdf -------------------------------------------------------------------------------- /dataSet/dic/province.txt: -------------------------------------------------------------------------------- 1 | 河北省 2 | 山东省 3 | 辽宁省 4 | 黑龙江省 5 | 甘肃省 6 | 吉林省 7 | 青海省 8 | 河南省 9 | 江苏省 10 | 湖北省 11 | 湖南省 12 | 浙江省 13 | 江西省 14 | 广东省 15 | 云南省 16 | 福建省 17 | 台湾省 18 | 海南省 19 | 山西省 20 | 四川省 21 | 陕西省 22 | 贵州省 23 | 安徽省 24 | 河北 25 | 山东 26 | 辽宁 27 | 黑龙江 28 | 甘肃 29 | 吉林 30 | 青海 31 | 河南 32 | 江苏 33 | 湖北 34 | 湖南 35 | 浙江 36 | 江西 37 | 广东 38 | 云南 39 | 福建 40 | 台湾 41 | 海南 42 | 山西 43 | 四川 44 | 陕西 45 | 贵州 46 | 安徽 47 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python3 run_classifier.py \ 2 | --task_name NLU \ 3 | --train_data dataSet/train.json \ 4 | --bert_model baidu_ernie/pytorch_model.bin \ 5 | --config baidu_ernie/config.json \ 6 | --vocab baidu_ernie/vocab.txt \ 7 | --max_seq_length 32 \ 8 | --train_batch_size 8 \ 9 | --learning_rate 6e-5 \ 10 | --num_train_epochs 9.0 \ 11 | --output_dir result/ \ 12 | --result_file test_result.json \ 13 | --dic_dir dataSet/dic \ 14 | --overwrite_output_dir \ 15 | --do_train 16 | -------------------------------------------------------------------------------- /sample/random_out.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 8/5/19 4:12 PM 4 | # @Author : zchai 5 | 6 | # -*- coding: utf-8 -*- 7 | import sys 8 | 9 | import random 10 | 11 | from extract_features import get_features 12 | 13 | 14 | 15 | ''' 16 | Guess a label in random as a base line 17 | ''' 18 | def random_guess(): 19 | data_random = {} 20 | 21 | data_random['domain'] = random.sample(domain_value_list, 1)[0] 22 | data_random['intent'] = random.sample(intent_value_list, 1)[0] 23 | slot = {} 24 | slot[random.sample(slots_key_list, 1)[0]] = random.sample(slots_value_list, 1)[0] 25 | data_random['slots'] = slot 26 | 27 | return data_random 28 | 29 | 30 | if __name__ == '__main__': 31 | import json 32 | dev_dct = json.load(open(sys.argv[1]), encoding='utf8') 33 | 34 | domain_value_list, intent_value_list, slots_key_list, slots_value_list = get_features(sys.argv[1]) 35 | 36 | rguess_dct = [] 37 | for 
dev_data in dev_dct: 38 | text_dic = {"text": dev_data['text']} 39 | rguess_dct.append(dict(text_dic, **random_guess())) 40 | json.dump(rguess_dct, open(sys.argv[2], 'w', encoding='utf8'), ensure_ascii=False) -------------------------------------------------------------------------------- /sample/extract_features.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 8/5/19 4:13 PM 4 | # @Author : zchai 5 | import json 6 | 7 | 8 | def get_features(file_path): 9 | with open(file_path, 'r') as f: 10 | data_list = json.load(f) 11 | 12 | domain_value_list = [] 13 | intent_value_list = [] 14 | slots_key_list = [] 15 | slots_value_list = [] 16 | for data in data_list: 17 | domain_value = data['domain'] 18 | intent_value = data['intent'] 19 | if 'slots' not in data.keys(): 20 | print(data) 21 | continue 22 | else: 23 | slots = data['slots'] 24 | if type(slots) != dict: 25 | slots = {} 26 | print(slots) 27 | continue 28 | 29 | if domain_value not in domain_value_list: 30 | domain_value_list.append(domain_value) 31 | 32 | if intent_value not in intent_value_list: 33 | intent_value_list.append(intent_value) 34 | 35 | for key, value in slots.items(): 36 | if key not in slots_key_list: 37 | slots_key_list.append(key) 38 | if value not in slots_value_list: 39 | slots_value_list.append(value) 40 | 41 | return domain_value_list, intent_value_list, slots_key_list, slots_value_list -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 | baidu_ernie/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | SMP2019 ECDT Chinese Human-Computer Dialogue Technology Evaluation, Task 1: Natural Language Understanding
2 | ===================================
3 |
4 | Task description: http://conference.cipsc.org.cn/smp2019/evaluation.html
5 |
6 | LeaderBoard: https://adamszq.github.io/smp2019ecdt_task1/
7 |
8 |
9 | Process
10 | ---------
11 | The main work in this competition was choosing a pre-trained model, plus analysing the data and writing rules.
12 | Three open-source Chinese pre-trained models are currently available: Google's bert-chinese [1], Baidu's ERNIE [2], and HIT's Chinese BERT [3]; see [3] for a comparison of the three.
13 |
14 | We mainly fine-tune Baidu's open-source ERNIE. ERNIE has the same architecture as BERT and differs only in how it is trained: it uses whole-word masking and adds informal corpora such as Baidu Tieba, so it works better on informal, everyday dialogue, although it needs a larger initial learning rate; it also gives a sizeable improvement on slot filling. The model architecture is not complicated; if you are interested, see BertForTaskNLU in modeling.py. The loss is the sum of the domain, intent, and slot losses (experiments show the three tasks are correlated, and joint training makes the network converge better); an illustrative sketch of this joint-loss idea appears a few lines below.
15 |
16 | The competition was also held in 2017 and 2018. The difference is that those two years had only intent and domain classification, with no slot annotation, but the amount of data is considerable and it was randomly sampled from the same pool as this year's data, so it is worth putting to use.
17 |
18 |
19 | Code Framework
20 | ---------
21 | * baidu_ernie/: Baidu's open-source ERNIE
22 | * dataSet/:
23 |     * train.json: the official training data
24 |     * dic/: partial slot dictionaries collected by hand (used by the rules)
25 |     * process.py: throwaway code written while analysing the data; worth a look if you want to explore the data (a bit messy)
26 |     * 往年数据.zip: the official SMP-ECDT data from 2017 and 2018
27 | * docs/:
28 |     * rule.xlsx: some personal analysis of the data, laid out as Excel tables
29 | * sample/: reference code for submission and evaluation
30 | * convert_tf_checkpoint_to_pytorch.py: converts the TensorFlow checkpoint into parameters PyTorch can read
31 | * modeling.py: the model architecture used in this project (focus on BertForTaskNLU)
32 | * optimization.py: the optimizer (cyclical learning rate + warmup_step)
33 | * rule.py: rules written for this data set (admittedly a bit messy...)
34 | * run_classifier.py: effectively the main function
35 | * run_classifier_dataset_utils.py: the data-processing part
36 | * tokenization.py: tokenization and vocabulary building
37 |
38 |
39 | Environment
40 | ---------
41 | Python 3.5
42 |
43 | PyTorch 1.0.0
44 |
45 | GPU (about 3 GB of GPU memory is enough for this model)
46 |
47 |
48 | Usage
49 | ---------
50 | First download the pre-trained ERNIE model (we directly use the checkpoint already converted by [4]): https://pan.baidu.com/s/1I7kKVlZN6hl-sUbnvttJzA extraction code: iq74
51 |
52 | bash run.sh 0 (0 is the index of the GPU you want to use)
53 |
54 |
55 | References:
56 | ---------
57 | [1] https://github.com/google-research/bert
58 |
59 | [2] https://github.com/PaddlePaddle/LARK/tree/develop/ERNIE
60 |
61 | [3] https://github.com/ymcui/Chinese-BERT-wwm
62 |
63 | [4] https://github.com/ArthurRizar/tensorflow_ernie
64 |
--------------------------------------------------------------------------------
/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
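# ---------------------------------------------------------------------------
# [Illustrative sketch, referenced from the README above; NOT part of this repo.]
# The README says BertForTaskNLU in modeling.py (not shown in this section) puts three
# heads on a shared encoder and sums their losses. The toy module below only illustrates
# that joint-loss idea; the class name, layer names, and sizes here are made up.
import torch
import torch.nn as nn

class ToyJointNLUHead(nn.Module):
    """Joint domain/intent classification and per-token slot tagging on one encoder output."""

    def __init__(self, hidden_size, num_domains, num_intents, num_slot_tags):
        super().__init__()
        self.domain_clf = nn.Linear(hidden_size, num_domains)
        self.intent_clf = nn.Linear(hidden_size, num_intents)
        self.slot_clf = nn.Linear(hidden_size, num_slot_tags)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, sequence_output, domain_labels, intent_labels, slot_labels):
        # sequence_output: [batch, seq_len, hidden] produced by the BERT/ERNIE encoder.
        pooled = sequence_output[:, 0]                # [CLS] vector for the sentence-level heads
        domain_logits = self.domain_clf(pooled)
        intent_logits = self.intent_clf(pooled)
        slot_logits = self.slot_clf(sequence_output)  # [batch, seq_len, num_slot_tags]
        # Total loss = domain loss + intent loss + slot loss ("the sum of the three losses").
        loss = (self.ce(domain_logits, domain_labels)
                + self.ce(intent_logits, intent_labels)
                + self.ce(slot_logits.view(-1, slot_logits.size(-1)), slot_labels.view(-1)))
        return loss, domain_logits, intent_logits, slot_logits
# ---------------------------------------------------------------------------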
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = BertConfig.from_json_file(bert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = BertForPreTraining(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_bert(model, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | 46 | # cmd: python3 convert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path baidu_ernie --bert_config_file baidu_ernie/bert_config.json --pytorch_dump_path ernie_torch/bert_model.ckpt 47 | 48 | parser = argparse.ArgumentParser() 49 | ## Required parameters 50 | parser.add_argument("--tf_checkpoint_path", 51 | default = None, 52 | type = str, 53 | required = True, 54 | help = "Path the TensorFlow checkpoint path.") 55 | parser.add_argument("--bert_config_file", 56 | default = None, 57 | type = str, 58 | required = True, 59 | help = "The config json file corresponding to the pre-trained BERT model. 
\n" 60 | "This specifies the model architecture.") 61 | parser.add_argument("--pytorch_dump_path", 62 | default = None, 63 | type = str, 64 | required = True, 65 | help = "Path to the output PyTorch model.") 66 | args = parser.parse_args() 67 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 68 | args.bert_config_file, 69 | args.pytorch_dump_path) 70 | 71 | -------------------------------------------------------------------------------- /sample/sample.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "text": "请帮我打开uc", 4 | "domain": "app", 5 | "intent": "LAUNCH", 6 | "slots": { 7 | "name": "uc" 8 | } 9 | }, 10 | { 11 | "text": "链家地产打开QQ", 12 | "domain": "app", 13 | "intent": "LAUNCH", 14 | "slots": { 15 | "name": "qq" 16 | } 17 | }, 18 | { 19 | "text": "打开汽车之家", 20 | "domain": "app", 21 | "intent": "LAUNCH", 22 | "slots": { 23 | "name": "汽车之家" 24 | } 25 | }, 26 | { 27 | "text": "打开QQ通讯", 28 | "domain": "app", 29 | "intent": "LAUNCH", 30 | "slots": { 31 | "name": "qq" 32 | } 33 | }, 34 | { 35 | "text": "帮我打开人人", 36 | "domain": "app", 37 | "intent": "LAUNCH", 38 | "slots": { 39 | "name": "人人" 40 | } 41 | }, 42 | { 43 | "text": "开微信", 44 | "domain": "app", 45 | "intent": "LAUNCH", 46 | "slots": { 47 | "name": "微信" 48 | } 49 | }, 50 | { 51 | "text": "黎宇我要玩中国象棋", 52 | "domain": "app", 53 | "intent": "LAUNCH", 54 | "slots": { 55 | "name": "中国象棋" 56 | } 57 | }, 58 | { 59 | "text": "给我打开一下qq", 60 | "domain": "app", 61 | "intent": "LAUNCH", 62 | "slots": { 63 | "name": "qq" 64 | } 65 | }, 66 | { 67 | "text": "帮忙打开一下酷狗音乐播放音乐行不", 68 | "domain": "app", 69 | "intent": "LAUNCH", 70 | "slots": { 71 | "name": "酷狗音乐" 72 | } 73 | }, 74 | { 75 | "text": "百度浏览器打开", 76 | "domain": "app", 77 | "intent": "LAUNCH", 78 | "slots": { 79 | "name": "百度浏览器" 80 | } 81 | }, 82 | { 83 | "text": "搜索手机办公软件", 84 | "domain": "app", 85 | "intent": "QUERY", 86 | "slots": { 87 | "name": "办公软件" 88 | } 89 | }, 90 | { 91 | "text": "凯立德", 92 | "domain": "app", 93 | "intent": "LAUNCH", 94 | "slots": { 95 | "name": "凯立德" 96 | } 97 | }, 98 | { 99 | "text": "打开相机这", 100 | "domain": "app", 101 | "intent": "LAUNCH", 102 | "slots": { 103 | "name": "相机" 104 | } 105 | }, 106 | { 107 | "text": "打开qq同步助手", 108 | "domain": "app", 109 | "intent": "LAUNCH", 110 | "slots": { 111 | "name": "qq同步助手" 112 | } 113 | }, 114 | { 115 | "text": "打开淘宝购物", 116 | "domain": "app", 117 | "intent": "LAUNCH", 118 | "slots": { 119 | "name": "淘宝购物" 120 | } 121 | }, 122 | { 123 | "text": "帮我找到微信", 124 | "domain": "app", 125 | "intent": "QUERY", 126 | "slots": { 127 | "name": "微信" 128 | } 129 | }, 130 | { 131 | "text": "打开uc二哦", 132 | "domain": "app", 133 | "intent": "LAUNCH", 134 | "slots": { 135 | "name": "uc" 136 | } 137 | }, 138 | { 139 | "text": "开启qq", 140 | "domain": "app", 141 | "intent": "LAUNCH", 142 | "slots": { 143 | "name": "qq" 144 | } 145 | }, 146 | { 147 | "text": "查询许昌到中山的汽车。", 148 | "domain": "bus", 149 | "intent": "QUERY", 150 | "slots": { 151 | "Dest": "中山", 152 | "Src": "许昌" 153 | } 154 | }, 155 | { 156 | "text": "无锡到阜阳怎么坐汽车?", 157 | "domain": "bus", 158 | "intent": "QUERY", 159 | "slots": { 160 | "Dest": "阜阳", 161 | "Src": "无锡" 162 | } 163 | }, 164 | { 165 | "text": "去深圳怎么坐车?", 166 | "domain": "map", 167 | "intent": "ROUTE", 168 | "slots": { 169 | "endLoc_city": "深圳" 170 | } 171 | }, 172 | { 173 | "text": "南京到,黄山的汽车。", 174 | "domain": "bus", 175 | "intent": "QUERY", 176 | "slots": { 177 | "Dest": "黄山", 178 | "Src": "南京" 179 | } 180 | }, 181 | { 182 | "text": "从无锡市到西安市的汽车。", 183 | 
"domain": "bus", 184 | "intent": "QUERY", 185 | "slots": { 186 | "Dest": "西安市", 187 | "Src": "无锡市" 188 | } 189 | }, 190 | { 191 | "text": "在合肥怎么坐去南京的汽车", 192 | "domain": "bus", 193 | "intent": "QUERY", 194 | "slots": { 195 | "Dest": "南京", 196 | "Src": "合肥" 197 | } 198 | } 199 | ] -------------------------------------------------------------------------------- /sample/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 9/5/19 10:54 AM 4 | # @Author : zchai 5 | # -*- coding: utf-8 -*- 6 | import json 7 | import codecs 8 | import sys 9 | 10 | ''' 11 | Calculate the sentence accuracy 12 | Json file format: { 13 | "text": "", 14 | "domain": "", 15 | "intent": "", 16 | "slots": { 17 | "name": "" 18 | } 19 | } 20 | ''' 21 | def sentence_acc(truth_dict_list, pred_dict_list): 22 | assert len(truth_dict_list) == len(pred_dict_list) 23 | 24 | acc_num = 0 25 | total_num = len(truth_dict_list) 26 | for truth_dic, pred_dic in zip(truth_dict_list, pred_dict_list): 27 | 28 | # Determine if the domain and intent are correct 29 | if truth_dic['domain'] != pred_dic['domain'] \ 30 | or truth_dic['intent'] != pred_dic['intent'] \ 31 | or len(truth_dic['slots']) != len(pred_dic['slots']): 32 | continue 33 | else: 34 | # Determine if the slots_key and slots_value are correct 35 | flag = True 36 | for key, value in truth_dic['slots'].items(): 37 | if key not in pred_dic['slots']: 38 | flag = False 39 | break # if there is a key not in predict, flag set as false 40 | elif pred_dic['slots'][key] != truth_dic['slots'][key]: 41 | flag = False # if one not match, flag set as false 42 | break 43 | 44 | if flag: 45 | acc_num += 1 46 | 47 | return float(acc_num) / float(total_num) 48 | 49 | def domain_acc(truth_dict_list, pred_dict_list): 50 | assert len(truth_dict_list) == len(pred_dict_list) 51 | acc_num = 0 52 | total_num = len(truth_dict_list) 53 | for truth_dic, pred_dic in zip(truth_dict_list, pred_dict_list): 54 | if truth_dic['domain'] == pred_dic['domain']: 55 | acc_num += 1 56 | 57 | return float(acc_num) / float(total_num) 58 | 59 | 60 | def intent_acc(truth_dict_list, pred_dict_list): 61 | assert len(truth_dict_list) == len(pred_dict_list) 62 | acc_num = 0 63 | total_num = len(truth_dict_list) 64 | for truth_dic, pred_dic in zip(truth_dict_list, pred_dict_list): 65 | if truth_dic['intent'] == pred_dic['intent'] and truth_dic['domain'] == pred_dic['domain']: 66 | acc_num += 1 67 | 68 | return float(acc_num) / float(total_num) 69 | 70 | def slots_acc(truth_dict_list, pred_dict_list): 71 | assert len(truth_dict_list) == len(pred_dict_list) 72 | acc_num = 0 73 | total_num = 0 74 | for truth_dic, pred_dic in zip(truth_dict_list, pred_dict_list): 75 | total_num += len(truth_dic['slots']) 76 | for key, value in truth_dic['slots'].items(): 77 | if key not in pred_dic['slots']: 78 | continue 79 | elif pred_dic['slots'][key] == truth_dic['slots'][key]: 80 | acc_num+=1 81 | 82 | return float(acc_num) / float(total_num) 83 | 84 | def slots_f(truth_dict_list, pred_dict_list): 85 | assert len(truth_dict_list) == len(pred_dict_list) 86 | correct, p_denominator, r_denominator = 0, 0, 0 87 | for truth_dic, pred_dic in zip(truth_dict_list, pred_dict_list): 88 | r_denominator += len(truth_dic['slots']) 89 | p_denominator += len(pred_dic['slots']) 90 | for key, value in truth_dic['slots'].items(): 91 | if key not in pred_dic['slots']: 92 | continue 93 | elif pred_dic['slots'][key] == truth_dic['slots'][key] and \ 94 | 
truth_dic['domain'] == pred_dic['domain'] and \ 95 | truth_dic['intent'] == pred_dic['intent']: 96 | correct += 1 97 | precision = float(correct) / p_denominator 98 | recall = float(correct) / r_denominator 99 | f1 = 2 * precision * recall / (precision + recall) * 1.0 100 | 101 | return f1 102 | 103 | if __name__ == '__main__': 104 | if len(sys.argv) < 3: 105 | print('Too few args for this script') 106 | exit(1) 107 | 108 | with codecs.open(sys.argv[1], 'r', encoding='utf-8') as f: 109 | fp_truth = json.loads(f.read()) 110 | 111 | with codecs.open(sys.argv[2], 'r', encoding='utf-8') as f_pred: 112 | fp_pred = json.loads(f_pred.read()) 113 | 114 | domain_accuracy = domain_acc(fp_truth, fp_pred) 115 | intent_accuracy = intent_acc(fp_truth, fp_pred) 116 | slots_f = slots_f(fp_truth, fp_pred) 117 | 118 | sentence_accuracy = sentence_acc(fp_truth, fp_pred) 119 | 120 | print('Domain sentence accuracy : %f' % domain_accuracy) 121 | print('Intent sentence accuracy : %f' % intent_accuracy) 122 | print('Slots f score : %f' % slots_f) 123 | print('Avg sentence accuracy : %f' % sentence_accuracy) 124 | 125 | -------------------------------------------------------------------------------- /dataSet/process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import OrderedDict 4 | 5 | data_json = json.load(open('train.json', encoding='utf8'), object_pairs_hook=OrderedDict) 6 | data = {} 7 | for line in data_json: 8 | if line['domain'] not in data: 9 | data[line['domain']] = {} 10 | if line['intent'] not in data[line['domain']]: 11 | if(line['intent'] != line['intent']): 12 | line['intent'] = 'NaN' 13 | data[line['domain']][line['intent']] = [] 14 | data[line['domain']][line['intent']].append([line['text'], line['slots']]) 15 | 16 | data_pro = [] 17 | for k, v in data.items(): 18 | for kk, vv in v.items(): 19 | for vvv in vv: 20 | d = OrderedDict() 21 | d['text'] = vvv[0] 22 | d['domain'] = k 23 | d['intent'] = kk 24 | d['slots'] = vvv[1] 25 | for i in range(215//len(vv)): 26 | data_pro.append(d) 27 | 28 | os.makedirs('data/domain') 29 | for domain in data: 30 | os.makedirs('data/domain/' + domain) 31 | for intent in data[domain]: 32 | dic = [] 33 | for ele in data[domain][intent]: 34 | d = collections.OrderedDict() 35 | d["text"] = ele[0] 36 | d["domain"] = domain 37 | d["intent"] = intent 38 | d["slots"] = ele[1] 39 | dic.append(d) 40 | json.dump(dic, open('data/'+domain+'/'+intent+'.json', 'w'), ensure_ascii = False, indent = 2) 41 | 42 | os.makedirs('data/intent') 43 | for intent in data: 44 | os.makedirs('data/intent/' + intent) 45 | for domain in data[intent]: 46 | dic = [] 47 | for ele in data[intent][domain]: 48 | d = collections.OrderedDict() 49 | d["text"] = ele[0] 50 | d["domain"] = domain 51 | d["intent"] = intent 52 | d["slots"] = ele[1] 53 | dic.append(d) 54 | json.dump(dic, open('data/intent/'+intent+'/'+domain+'.json', 'w'), ensure_ascii = False, indent = 2) 55 | 56 | slots = {} 57 | for data in data_json: 58 | for slot, val in data['slots'].items(): 59 | if slot not in slots: 60 | slots[slot] = set() 61 | slots[slot].add(val) 62 | 63 | 64 | os.makedirs('data/slots') 65 | for slot in slots: 66 | with open("data/slots/" + slot + ".txt", 'w') as f: 67 | for s in slots[slot]: 68 | f.write(s + '\n') 69 | 70 | # *************************************************************************************** # 71 | train_json = json.load(open('train_local.json'), encoding='utf8') 72 | text = [] 73 | for line in 
train_json: 74 | text.append(line['text']) 75 | 76 | data = [] 77 | for line in data_json: 78 | if line['text'] in text: 79 | continue 80 | data.append(line) 81 | 82 | json.dump(data, open('test_local.json', 'w'), ensure_ascii = False, indent = 2) 83 | 84 | 85 | # *************************************************************************************** # 86 | import json 87 | import numpy as np 88 | import pandas as pd 89 | from sklearn.model_selection import StratifiedKFold 90 | 91 | data_json = json.load(open('train.json'), encoding='utf8') 92 | text = np.array([data['text'] for data in data_json]) 93 | domain = np.array([data['domain'] for data in data_json]) 94 | intent = np.array([data['intent'] for data in data_json]) 95 | label = [d+i for d, i in zip(domain, intent)] 96 | kfold = list(StratifiedKFold(n_splits = 2, random_state = 2019, shuffle = True).split(text, label)) 97 | train_index, val_index = kfold[0] 98 | text = text[train_index] 99 | 100 | train_json, test_json = [], [] 101 | for data in data_json: 102 | if data['text'] in text: 103 | train_json.append(data) 104 | else: 105 | test_json.append(data) 106 | 107 | json.dump(train_json, open('train_eval.json', 'w'), ensure_ascii = False, indent = 2) 108 | json.dump(test_json, open('test_eval.json', 'w'), ensure_ascii = False, indent = 2) 109 | 110 | # *************************************************************************************** # 111 | import json 112 | from collections import OrderedDict 113 | data_json1 = json.load(open('result1/test_result.json', encoding='utf8'), object_pairs_hook=OrderedDict) 114 | data_json2 = json.load(open('result2/test_result.json', encoding='utf8'), object_pairs_hook=OrderedDict) 115 | data_json3 = json.load(open('result3/test_result.json', encoding='utf8'), object_pairs_hook=OrderedDict) 116 | data_json = [] 117 | for data1, data2, data3 in zip(data_json1, data_json2, data_json3): 118 | d = OrderedDict() 119 | d['text'] = data1['text'] 120 | d['domain'] = data1['domain'] 121 | d['intent'] = data2['intent'] 122 | d['slots'] = data3['slots'] 123 | data_json.append(d) 124 | 125 | json.dump(data_json, open('result/test_result.json', 'w'), ensure_ascii = False, indent = 2) 126 | 127 | # *************************************************************************************** # 128 | import json 129 | from collections import OrderedDict 130 | 131 | data_json = json.load(open('train_eval.json', encoding='utf8'), object_pairs_hook=OrderedDict) 132 | dic = {} 133 | for data in data_json: 134 | if data['domain'] not in dic: 135 | dic[data['domain']] = [] 136 | dic[data['domain']].append(data) 137 | 138 | length = max([len(v) for k,v in dic.items()]) 139 | new_data = [] 140 | for k,v in dic.items(): 141 | times = length // len(v) 142 | for i in range(times): 143 | new_data.extend(v) 144 | 145 | json.dump(new_data, open('result/train1.json', 'w'), ensure_ascii = False, indent = 2) 146 | 147 | # *************************************************************************************** # 148 | dic = {} 149 | for data in data_json: 150 | if data['domain'] not in dic: 151 | dic[data['domain']] = [set(), set()] 152 | dic[data['domain']][0].add(data['intent']) 153 | for slot in data['slots']: 154 | dic[data['domain']][1].add(slot) 155 | 156 | for k, v in dic.items(): 157 | print('\''+k+'\':' ,v, ',') -------------------------------------------------------------------------------- /dataSet/dic/city.txt: -------------------------------------------------------------------------------- 1 | 北京市 2 | 东莞市 3 | 广州市 4 | 中山市 5 
| 深圳市 6 | 惠州市 7 | 江门市 8 | 珠海市 9 | 汕头市 10 | 佛山市 11 | 湛江市 12 | 河源市 13 | 肇庆市 14 | 潮州市 15 | 清远市 16 | 韶关市 17 | 揭阳市 18 | 阳江市 19 | 云浮市 20 | 茂名市 21 | 梅州市 22 | 汕尾市 23 | 济南市 24 | 青岛市 25 | 临沂市 26 | 济宁市 27 | 菏泽市 28 | 烟台市 29 | 泰安市 30 | 淄博市 31 | 潍坊市 32 | 日照市 33 | 威海市 34 | 滨州市 35 | 东营市 36 | 聊城市 37 | 德州市 38 | 莱芜市 39 | 枣庄市 40 | 苏州市 41 | 徐州市 42 | 盐城市 43 | 无锡市 44 | 南京市 45 | 南通市 46 | 连云港市 47 | 常州市 48 | 扬州市 49 | 镇江市 50 | 淮安市 51 | 泰州市 52 | 宿迁市 53 | 郑州市 54 | 南阳市 55 | 新乡市 56 | 安阳市 57 | 洛阳市 58 | 信阳市 59 | 平顶山市 60 | 周口市 61 | 商丘市 62 | 开封市 63 | 焦作市 64 | 驻马店市 65 | 濮阳市 66 | 三门峡市 67 | 漯河市 68 | 许昌市 69 | 鹤壁市 70 | 济源市 71 | 上海市 72 | 石家庄市 73 | 唐山市 74 | 保定市 75 | 邯郸市 76 | 邢台市 77 | 河北区 78 | 沧州市 79 | 秦皇岛市 80 | 张家口市 81 | 衡水市 82 | 廊坊市 83 | 承德市 84 | 温州市 85 | 宁波市 86 | 杭州市 87 | 台州市 88 | 嘉兴市 89 | 金华市 90 | 湖州市 91 | 绍兴市 92 | 舟山市 93 | 丽水市 94 | 衢州市 95 | 西安市 96 | 咸阳市 97 | 宝鸡市 98 | 汉中市 99 | 渭南市 100 | 安康市 101 | 榆林市 102 | 商洛市 103 | 延安市 104 | 铜川市 105 | 长沙市 106 | 邵阳市 107 | 常德市 108 | 衡阳市 109 | 株洲市 110 | 湘潭市 111 | 永州市 112 | 岳阳市 113 | 怀化市 114 | 郴州市 115 | 娄底市 116 | 益阳市 117 | 张家界市 118 | 湘西州 119 | 重庆市 120 | 漳州市 121 | 泉州市 122 | 厦门市 123 | 福州市 124 | 莆田市 125 | 宁德市 126 | 三明市 127 | 南平市 128 | 龙岩市 129 | 天津市 130 | 昆明市 131 | 红河州 132 | 大理州 133 | 文山州 134 | 德宏州 135 | 曲靖市 136 | 昭通市 137 | 楚雄州 138 | 保山市 139 | 玉溪市 140 | 丽江地区 141 | 临沧地区 142 | 思茅地区 143 | 西双版纳州 144 | 怒江州 145 | 迪庆州 146 | 成都市 147 | 绵阳市 148 | 广元市 149 | 达州市 150 | 南充市 151 | 德阳市 152 | 广安市 153 | 阿坝州 154 | 巴中市 155 | 遂宁市 156 | 内江市 157 | 凉山州 158 | 攀枝花市 159 | 乐山市 160 | 自贡市 161 | 泸州市 162 | 雅安市 163 | 宜宾市 164 | 资阳市 165 | 眉山市 166 | 甘孜州 167 | 贵港市 168 | 玉林市 169 | 北海市 170 | 南宁市 171 | 柳州市 172 | 桂林市 173 | 梧州市 174 | 钦州市 175 | 来宾市 176 | 河池市 177 | 百色市 178 | 贺州市 179 | 崇左市 180 | 防城港市 181 | 芜湖市 182 | 合肥市 183 | 六安市 184 | 宿州市 185 | 阜阳市 186 | 安庆市 187 | 马鞍山市 188 | 蚌埠市 189 | 淮北市 190 | 淮南市 191 | 宣城市 192 | 黄山市 193 | 铜陵市 194 | 亳州市 195 | 池州市 196 | 巢湖市 197 | 滁州市 198 | 三亚市 199 | 海口市 200 | 琼海市 201 | 文昌市 202 | 东方市 203 | 昌江县 204 | 陵水县 205 | 乐东县 206 | 五指山市 207 | 保亭县 208 | 澄迈县 209 | 万宁市 210 | 儋州市 211 | 临高县 212 | 白沙县 213 | 定安县 214 | 琼中县 215 | 屯昌县 216 | 南昌市 217 | 赣州市 218 | 上饶市 219 | 吉安市 220 | 九江市 221 | 新余市 222 | 抚州市 223 | 宜春市 224 | 景德镇市 225 | 萍乡市 226 | 鹰潭市 227 | 武汉市 228 | 宜昌市 229 | 襄樊市 230 | 荆州市 231 | 恩施州 232 | 孝感市 233 | 黄冈市 234 | 十堰市 235 | 咸宁市 236 | 黄石市 237 | 仙桃市 238 | 随州市 239 | 天门市 240 | 荆门市 241 | 潜江市 242 | 鄂州市 243 | 太原市 244 | 大同市 245 | 运城市 246 | 长治市 247 | 晋城市 248 | 忻州市 249 | 临汾市 250 | 吕梁市 251 | 晋中市 252 | 阳泉市 253 | 朔州市 254 | 大连市 255 | 沈阳市 256 | 丹东市 257 | 辽阳市 258 | 葫芦岛市 259 | 锦州市 260 | 朝阳市 261 | 营口市 262 | 鞍山市 263 | 抚顺市 264 | 阜新市 265 | 本溪市 266 | 盘锦市 267 | 铁岭市 268 | 台北市 269 | 高雄市 270 | 台中市 271 | 新竹市 272 | 基隆市 273 | 台南市 274 | 嘉义市 275 | 齐齐哈尔市 276 | 哈尔滨市 277 | 大庆市 278 | 佳木斯市 279 | 双鸭山市 280 | 牡丹江市 281 | 鸡西市 282 | 黑河市 283 | 绥化市 284 | 鹤岗市 285 | 伊春市 286 | 大兴安岭地区 287 | 七台河市 288 | 贵阳市 289 | 黔东南州 290 | 黔南州 291 | 遵义市 292 | 黔西南州 293 | 毕节地区 294 | 铜仁地区 295 | 安顺市 296 | 六盘水市 297 | 兰州市 298 | 天水市 299 | 庆阳市 300 | 武威市 301 | 酒泉市 302 | 张掖市 303 | 陇南地区 304 | 白银市 305 | 定西地区 306 | 平凉市 307 | 嘉峪关市 308 | 临夏回族自治州 309 | 金昌市 310 | 甘南州 311 | 西宁市 312 | 海西州 313 | 海东地区 314 | 海北州 315 | 果洛州 316 | 玉树州 317 | 黄南藏族自治州 318 | 吉林市 319 | 长春市 320 | 白山市 321 | 白城市 322 | 延边州 323 | 松原市 324 | 辽源市 325 | 通化市 326 | 四平市 327 | 银川市 328 | 吴忠市 329 | 中卫市 330 | 石嘴山市 331 | 固原市 332 | 拉萨市 333 | 乌鲁木齐市 334 | 伊犁州 335 | 昌吉州 336 | 石河子市 337 | 阿拉尔市 338 | 博尔塔拉州 339 | 五家渠市 340 | 克孜勒苏州 341 | 图木舒克市 342 | 香港 343 | 北京 344 | 东莞 345 | 广州 346 | 中山 347 | 深圳 348 | 惠州 349 | 江门 350 | 珠海 351 | 汕头 352 | 佛山 353 | 湛江 354 | 河源 355 | 肇庆 356 | 潮州 357 | 清远 358 | 韶关 359 | 揭阳 360 | 阳江 361 | 云浮 362 | 茂名 363 | 梅州 364 | 汕尾 365 | 济南 366 | 青岛 367 
| 临沂 368 | 济宁 369 | 菏泽 370 | 烟台 371 | 泰安 372 | 淄博 373 | 潍坊 374 | 日照 375 | 威海 376 | 滨州 377 | 东营 378 | 聊城 379 | 德州 380 | 莱芜 381 | 枣庄 382 | 苏州 383 | 徐州 384 | 盐城 385 | 无锡 386 | 南京 387 | 南通 388 | 连云港 389 | 常州 390 | 扬州 391 | 镇江 392 | 淮安 393 | 泰州 394 | 宿迁 395 | 郑州 396 | 南阳 397 | 新乡 398 | 安阳 399 | 洛阳 400 | 信阳 401 | 平顶山 402 | 周口 403 | 商丘 404 | 开封 405 | 焦作 406 | 驻马店 407 | 濮阳 408 | 三门峡 409 | 漯河 410 | 许昌 411 | 鹤壁 412 | 济源 413 | 上海 414 | 石家庄 415 | 唐山 416 | 保定 417 | 邯郸 418 | 邢台 419 | 沧州 420 | 秦皇岛 421 | 张家口 422 | 衡水 423 | 廊坊 424 | 承德 425 | 温州 426 | 宁波 427 | 杭州 428 | 台州 429 | 嘉兴 430 | 金华 431 | 湖州 432 | 绍兴 433 | 舟山 434 | 丽水 435 | 衢州 436 | 西安 437 | 咸阳 438 | 宝鸡 439 | 汉中 440 | 渭南 441 | 安康 442 | 榆林 443 | 商洛 444 | 延安 445 | 铜川 446 | 长沙 447 | 邵阳 448 | 常德 449 | 衡阳 450 | 株洲 451 | 湘潭 452 | 永州 453 | 岳阳 454 | 怀化 455 | 郴州 456 | 娄底 457 | 益阳 458 | 张家界 459 | 重庆 460 | 漳州 461 | 泉州 462 | 厦门 463 | 福州 464 | 莆田 465 | 宁德 466 | 三明 467 | 南平 468 | 龙岩 469 | 天津 470 | 昆明 471 | 曲靖 472 | 昭通 473 | 保山 474 | 玉溪 475 | 成都 476 | 绵阳 477 | 广元 478 | 达州 479 | 南充 480 | 德阳 481 | 广安 482 | 巴中 483 | 遂宁 484 | 内江 485 | 攀枝花 486 | 乐山 487 | 自贡 488 | 泸州 489 | 雅安 490 | 宜宾 491 | 资阳 492 | 眉山 493 | 贵港 494 | 玉林 495 | 北海 496 | 南宁 497 | 柳州 498 | 桂林 499 | 梧州 500 | 钦州 501 | 来宾 502 | 河池 503 | 百色 504 | 贺州 505 | 崇左 506 | 防城港 507 | 芜湖 508 | 合肥 509 | 六安 510 | 宿州 511 | 阜阳 512 | 安庆 513 | 马鞍山 514 | 蚌埠 515 | 淮北 516 | 淮南 517 | 宣城 518 | 黄山 519 | 铜陵 520 | 亳州 521 | 池州 522 | 巢湖 523 | 滁州 524 | 三亚 525 | 海口 526 | 琼海 527 | 文昌 528 | 东方 529 | 五指山 530 | 万宁 531 | 儋州 532 | 南昌 533 | 赣州 534 | 上饶 535 | 吉安 536 | 九江 537 | 新余 538 | 抚州 539 | 宜春 540 | 景德镇 541 | 萍乡 542 | 鹰潭 543 | 武汉 544 | 宜昌 545 | 襄樊 546 | 荆州 547 | 孝感 548 | 黄冈 549 | 十堰 550 | 咸宁 551 | 黄石 552 | 仙桃 553 | 随州 554 | 天门 555 | 荆门 556 | 潜江 557 | 鄂州 558 | 太原 559 | 大同 560 | 运城 561 | 长治 562 | 晋城 563 | 忻州 564 | 临汾 565 | 吕梁 566 | 晋中 567 | 阳泉 568 | 朔州 569 | 大连 570 | 沈阳 571 | 丹东 572 | 辽阳 573 | 葫芦岛 574 | 锦州 575 | 朝阳 576 | 营口 577 | 鞍山 578 | 抚顺 579 | 阜新 580 | 本溪 581 | 盘锦 582 | 铁岭 583 | 台北 584 | 高雄 585 | 台中 586 | 新竹 587 | 基隆 588 | 台南 589 | 嘉义 590 | 齐齐哈尔 591 | 哈尔滨 592 | 大庆 593 | 佳木斯 594 | 双鸭山 595 | 牡丹江 596 | 鸡西 597 | 黑河 598 | 绥化 599 | 鹤岗 600 | 伊春 601 | 七台河 602 | 贵阳 603 | 遵义 604 | 安顺 605 | 六盘水 606 | 兰州 607 | 天水 608 | 庆阳 609 | 武威 610 | 酒泉 611 | 张掖 612 | 白银 613 | 平凉 614 | 嘉峪关 615 | 金昌 616 | 西宁 617 | 吉林 618 | 长春 619 | 白山 620 | 白城 621 | 松原 622 | 辽源 623 | 通化 624 | 四平 625 | 银川 626 | 吴忠 627 | 中卫 628 | 石嘴山 629 | 固原 630 | 拉萨 631 | 乌鲁木齐 632 | 石河子 633 | 阿拉尔 634 | 五家渠 635 | 图木舒克 636 | -------------------------------------------------------------------------------- /domain_rule.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | 5 | dic_path = 'dataSet/op_dic/dic' 6 | op_path = 'dataSet/op_dic' 7 | 8 | 9 | def app(text, domain, all_app, app_op): 10 | 11 | # print(all_app) 12 | for app in all_app: 13 | new_text = ['',''] 14 | if text == app: 15 | return ['app',"LAUNCH"] 16 | # result[i]['intent'] = "LAUNCH" 17 | if domain != 'app': 18 | for op in app_op: 19 | op_search = re.search(op.strip(),text) 20 | if op_search: 21 | new_text = [text[:op_search.span()[0]],text[op_search.span()[1]:]] 22 | break 23 | search_words = re.sub("[\s+\.\!\/_,$%^*()+\"\']+|[+——!,。?、~@#¥%……&*()]+", "",app.strip()) 24 | if op_search and (re.search(search_words,new_text[0]) or re.search(search_words,new_text[-1])): 25 | return ['app'] 26 | # result[i]['domain'] = 'app' 27 | return [] 28 | 29 | 30 | def music(text, domain, music_op): 31 | 32 | if domain != 'music': 33 | for op in music_op: 34 | op_search = 
re.search(op.strip(),text) 35 | if op_search: 36 | text = text[:op_search.span()[0]] + text[op_search.span()[1]:] 37 | return ['music'] 38 | return [] 39 | 40 | 41 | def poetry(text, domain, all_poetry, poetry_op): 42 | 43 | for poetry in all_poetry: 44 | new_text = ['',''] 45 | if text == poetry: 46 | return ['poetry',"QUERY"] 47 | # result[i]['intent'] = "LAUNCH" 48 | if domain != 'poetry': 49 | for op in poetry_op: 50 | op_search = re.search(op.strip(),text) 51 | if op_search: 52 | new_text = [text[:op_search.span()[0]], text[op_search.span()[1]:]] 53 | break 54 | search_words = re.sub("[\s+\.\!\/_,$%^*()+\"\']+|[+——!,。?、~@#¥%……&*()]+", "",poetry.strip()) 55 | if op_search and (re.search(search_words,new_text[0]) or re.search(search_words,new_text[-1])): 56 | return ['poetry'] 57 | return [] 58 | 59 | 60 | def bus(text, domain, all_city): 61 | for city in all_city: 62 | new_text = ['',''] 63 | if domain != 'bus': 64 | for tool in ['车汽','车']: 65 | text = text[::-1] 66 | tool_search = re.search(tool,text) 67 | if tool_search: 68 | text = text[:tool_search.span()[0]] + text[tool_search.span()[1]:] 69 | text = text[::-1] 70 | break 71 | for op in ['到','去','回']: 72 | op_search = re.search(op, text) 73 | if op_search: 74 | new_text = [text[:op_search.span()[0]], text[op_search.span()[1]:]] 75 | break 76 | if tool_search and op_search and (re.search(city.strip(), new_text[0]) or re.search(city.strip(), new_text[1])): 77 | return ['bus'] 78 | return [] 79 | 80 | 81 | def video(text, domain, videos): 82 | subtexts = [] 83 | for k in range(len(text)): 84 | for i in range(k,len(text)): 85 | subtexts.append(text[k:i+1]) 86 | 87 | if videos.get(text,None): 88 | return ['video', text] 89 | for subtext in subtexts: 90 | if domain != 'video': 91 | search_result = videos.get(subtext,None) 92 | if search_result: 93 | th = float(len(subtext))/float(len(text)) 94 | if th >= 0.6: 95 | return 'video' 96 | return [] 97 | 98 | 99 | 100 | def tvchannel(text, domain, all_tvchannel): 101 | for tvchannel in all_tvchannel: 102 | final_result = ['tvchannel',None,None] 103 | if text == tvchannel: 104 | return ['tvchannel', text] 105 | if domain != 'tvchannel': 106 | search_words = re.sub("[\s+\.\!\/_,$%^*()+\"\']+|[+——!,。?、~@#¥%……&*()]+", "", tvchannel.strip()) 107 | search_result = re.search(search_words,text) 108 | if search_result: 109 | resolution_search = re.search('高清',text) 110 | if resolution_search: 111 | text = text[:resolution_search.span()[0]]+text[resolution_search.span()[1]:] 112 | final_result = ['tvchannel','高清',search_words] 113 | th = float(search_result.span()[1] - search_result.span()[0])/float(len(text)) 114 | if th >= 0.7: 115 | return final_result 116 | return [] 117 | 118 | 119 | def website(text,domain,websites,website_op): 120 | subtexts = [] 121 | for k in range(len(text)): 122 | for i in range(k,len(text)): 123 | subtexts.append(text[k:i+1]) 124 | 125 | if websites.get(text,None): 126 | return ['website', text] 127 | 128 | for subtext in subtexts: 129 | new_text = ['',''] 130 | if domain != 'website': 131 | for op in website_op: 132 | op_search = subtext.find(op) 133 | if op_search != -1: 134 | s = subtext.find(op) 135 | new_text = [subtext[:s],subtext[s+len(op)-1:]] 136 | break 137 | if op_search and (websites.get(new_text[0]) or websites.get(new_text[-1])): 138 | return ['website',subtext] 139 | # result[i]['domain'] = 'app' 140 | return [] 141 | 142 | 143 | 144 | ################################################################################################################## 145 | 146 
| def domain_rule(result): 147 | 148 | app_count = 0 149 | for i,pred in enumerate(result): 150 | flag = False 151 | text = pred['text'] 152 | domain = pred['domain'] 153 | 154 | # app up 155 | with open(os.path.join(dic_path,'app.txt'),encoding='UTF-8') as f_app: 156 | all_app = f_app.read().strip().split('\n') 157 | 158 | with open(os.path.join(op_path,'app_op.txt'),encoding='UTF-8') as f_app_op: 159 | app_op = f_app_op.read().strip().split('\n') 160 | 161 | app_result = app(text, domain, all_app, app_op) 162 | if len(app_result) == 2: 163 | result[i]['domain'] = app_result[0] 164 | result[i]['intent'] = app_result[1] 165 | continue 166 | elif len(app_result) == 1: 167 | result[i]['domain'] = app_result[0] 168 | continue 169 | else: 170 | pass 171 | 172 | 173 | 174 | # music down 175 | # with open(os.path.join(op_path,'music_op.txt'),encoding='UTF-8') as f_music_op: 176 | # music_op = f_music_op.read().strip().split('\n') 177 | # 178 | # music_result = music(text,domain,music_op) 179 | # if len(music_result) == 1: 180 | # result[i]['domain'] = music_result[0] 181 | # continue 182 | # else: 183 | # pass 184 | 185 | 186 | # poetry no change 187 | # with open(os.path.join(dic_path,'poetry.txt'),encoding='UTF-8') as f_poetry: 188 | # all_poetry = f_poetry.read().strip().split('\n') 189 | # 190 | # with open(os.path.join(op_path,'poetry_op.txt'),encoding='UTF-8') as f_poetry_op: 191 | # poetry_op = f_poetry_op.read().strip().split('\n') 192 | # 193 | # poetry_result = poetry(text, domain,all_poetry, poetry_op) 194 | # if len(poetry_result) == 2: 195 | # result[i]['domain'] = poetry_result[0] 196 | # result[i]['intent'] = poetry_result[1] 197 | # continue 198 | # elif len(poetry_result) == 1: 199 | # result[i]['domain'] = poetry_result[0] 200 | # continue 201 | # else: 202 | # pass 203 | 204 | 205 | # bus no change 206 | with open(os.path.join(dic_path,'city.txt'),encoding='UTF-8') as f_city: 207 | all_city = f_city.read().strip().split('\n') 208 | 209 | city_result = bus(text,domain,all_city) 210 | if len(city_result) == 1: 211 | result[i]['domain'] = city_result[0] 212 | continue 213 | else: 214 | pass 215 | 216 | 217 | # video up 218 | with open(os.path.join(dic_path,'film.txt'),encoding='UTF-8') as f_video: 219 | all_video = f_video.read().strip().split('\n') 220 | videos = {} 221 | for v in all_video: 222 | videos[re.sub("[\s+\.\!\/_,$%^*()+\"\']+|[+——!,。?、~@#¥%……&*()]+", "",v.strip())] = 1 223 | 224 | with open(os.path.join(dic_path,'film_category.txt'),encoding='UTF-8') as f_video: 225 | video_category = f_video.read().strip().split('\n') 226 | 227 | video_result = video(text,domain,videos) 228 | if len(video_result) == 2: 229 | result[i]['domain'] = video_result[0] 230 | if video_result[1] in video_category: 231 | result[i]['slots']['category'] = video_result[1] 232 | else: 233 | result[i]['slots']['name'] = video_result[1] 234 | continue 235 | if len(video_result) == 1: 236 | result[i]['domain'] = video_result[0] 237 | continue 238 | else: 239 | pass 240 | 241 | 242 | # tvchannel up 243 | with open(os.path.join(dic_path,'tvchannel.txt'),encoding='UTF-8') as f_tvchannel: 244 | all_tvchannel = f_tvchannel.read().strip().split('\n') 245 | 246 | tvchannel_result = tvchannel(text,domain,all_tvchannel) 247 | if len(tvchannel_result) == 3: 248 | result[i]['domain'] = tvchannel_result[0] 249 | result[i]['slots']['resolution'] = tvchannel_result[1] 250 | result[i]['slots']['name'] = tvchannel_result[2] 251 | continue 252 | if len(tvchannel_result) == 2: 253 | result[i]['domain'] = 
tvchannel_result[0] 254 | result[i]['slots']['name'] = tvchannel_result[1] 255 | continue 256 | else: 257 | pass 258 | 259 | 260 | # website down 261 | with open(os.path.join(dic_path,'website.txt'),encoding='UTF-8') as f_website: 262 | all_website = f_website.read().strip().split('\n') 263 | websites = {} 264 | for w in websites: 265 | websites[re.sub("[\s+\.\!\/_,$%^*()+\"\']+|[+——!,。?、~@#¥%……&*()]+", "",w.strip())] = 1 266 | 267 | with open(os.path.join(op_path,'website_op.txt'),encoding='UTF-8') as f_website_op: 268 | website_op = f_website_op.read().strip().split('\n') 269 | 270 | website_result = website(text, domain, websites, website_op) 271 | if len(website_result) == 2: 272 | result[i]['domain'] = website_result[0] 273 | result[i]['slots']['name'] = website_result[1] 274 | continue 275 | # elif len(website_result) == 1: 276 | # result[i]['domain'] = website_result[0] 277 | # continue 278 | else: 279 | pass 280 | 281 | 282 | 283 | 284 | 285 | return result 286 | 287 | 288 | # if __name__ == '__main__': 289 | # result = json.load(open('result/test_result2.json', encoding = 'utf8'), object_pairs_hook = OrderedDict) 290 | # domain_rule(result) 291 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import abc 24 | import sys 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | if sys.version_info >= (3, 4): 30 | ABC = abc.ABC 31 | else: 32 | ABC = abc.ABCMeta('ABC', (), {}) 33 | 34 | 35 | class _LRSchedule(ABC): 36 | """ Parent of all LRSchedules here. """ 37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 38 | def __init__(self, warmup=0.002, t_total=-1, **kw): 39 | """ 40 | :param warmup: what fraction of t_total steps will be used for linear warmup 41 | :param t_total: how many training steps (updates) are planned 42 | :param kw: 43 | """ 44 | super(_LRSchedule, self).__init__(**kw) 45 | if t_total < 0: 46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 47 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 49 | warmup = max(warmup, 0.) 
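# ---------------------------------------------------------------------------
# [Illustrative usage sketch, NOT part of the original file.] How a schedule and BertAdam
# (both defined in this file) would be wired up from a separate training script; the model,
# step count, and warmup fraction below are made-up example values.
import torch
from optimization import BertAdam, WarmupLinearSchedule

model = torch.nn.Linear(10, 2)                     # stand-in for the real model
num_train_steps = 1000
optimizer = BertAdam(model.parameters(), lr=6e-5, warmup=0.1,
                     t_total=num_train_steps, schedule='warmup_linear')

# A schedule only returns a multiplier applied on top of the base lr:
sched = WarmupLinearSchedule(warmup=0.1, t_total=num_train_steps)
for step in (0, 50, 100, 500, 1000):
    print(step, round(sched.get_lr(step), 3))      # 0.0, 0.5, 1.0, 0.556, 0.0
# ---------------------------------------------------------------------------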
50 | self.warmup, self.t_total = float(warmup), float(t_total) 51 | self.warned_for_t_total_at_progress = -1 52 | 53 | def get_lr(self, step, nowarn=False): 54 | """ 55 | :param step: which of t_total steps we're on 56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 57 | :return: learning rate multiplier for current update 58 | """ 59 | if self.t_total < 0: 60 | return 1. 61 | progress = float(step) / self.t_total 62 | ret = self.get_lr_(progress) 63 | # warning for exceeding t_total (only active with warmup_linear 64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 65 | logger.warning( 66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 67 | .format(ret, self.__class__.__name__)) 68 | self.warned_for_t_total_at_progress = progress 69 | # end warning 70 | return ret 71 | 72 | @abc.abstractmethod 73 | def get_lr_(self, progress): 74 | """ 75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 76 | :return: learning rate multiplier for current update 77 | """ 78 | return 1. 79 | 80 | 81 | class ConstantLR(_LRSchedule): 82 | def get_lr_(self, progress): 83 | return 1. 84 | 85 | 86 | class WarmupCosineSchedule(_LRSchedule): 87 | """ 88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 91 | """ 92 | warn_t_total = True 93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 94 | """ 95 | :param warmup: see LRSchedule 96 | :param t_total: see LRSchedule 97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. 98 | :param kw: 99 | """ 100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 101 | self.cycles = cycles 102 | 103 | def get_lr_(self, progress): 104 | if progress < self.warmup: 105 | return progress / self.warmup 106 | else: 107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 109 | 110 | 111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 112 | """ 113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 115 | learning rate (with hard restarts). 116 | """ 117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 119 | assert(cycles >= 1.) 120 | 121 | def get_lr_(self, progress): 122 | if progress < self.warmup: 123 | return progress / self.warmup 124 | else: 125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) 127 | return ret 128 | 129 | 130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 131 | """ 132 | All training progress is divided in `cycles` (default=1.) parts of equal length. 
133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 135 | """ 136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 137 | assert(warmup * cycles < 1.) 138 | warmup = warmup * cycles if warmup >= 0 else warmup 139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 140 | 141 | def get_lr_(self, progress): 142 | progress = progress * self.cycles % 1. 143 | if progress < self.warmup: 144 | return progress / self.warmup 145 | else: 146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 147 | ret = 0.5 * (1. + math.cos(math.pi * progress)) 148 | return ret 149 | 150 | 151 | class WarmupConstantSchedule(_LRSchedule): 152 | """ 153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 154 | Keeps learning rate equal to 1. after warmup. 155 | """ 156 | def get_lr_(self, progress): 157 | if progress < self.warmup: 158 | return progress / self.warmup 159 | return 1. 160 | 161 | 162 | class WarmupLinearSchedule(_LRSchedule): 163 | """ 164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 166 | """ 167 | warn_t_total = True 168 | def get_lr_(self, progress): 169 | if progress < self.warmup: 170 | return progress / self.warmup 171 | return max((progress - 1.) / (self.warmup - 1.), 0.) 172 | 173 | 174 | SCHEDULES = { 175 | None: ConstantLR, 176 | "none": ConstantLR, 177 | "warmup_cosine": WarmupCosineSchedule, 178 | "warmup_constant": WarmupConstantSchedule, 179 | "warmup_linear": WarmupLinearSchedule 180 | } 181 | 182 | 183 | class BertAdam(Optimizer): 184 | """Implements BERT version of Adam algorithm with weight decay fix. 185 | Params: 186 | lr: learning rate 187 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 188 | t_total: total number of training steps for the learning 189 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 190 | schedule: schedule to use for the warmup (see above). 191 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). 192 | If `None` or `'none'`, learning rate is always kept constant. 193 | Default : `'warmup_linear'` 194 | b1: Adams b1. Default: 0.9 195 | b2: Adams b2. Default: 0.999 196 | e: Adams epsilon. Default: 1e-6 197 | weight_decay: Weight decay. Default: 0.01 198 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 199 | """ 200 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 201 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 202 | if lr is not required and lr < 0.0: 203 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 204 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 205 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 206 | if not 0.0 <= b1 < 1.0: 207 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 208 | if not 0.0 <= b2 < 1.0: 209 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 210 | if not e >= 0.0: 211 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 212 | # initialize schedule object 213 | if not isinstance(schedule, _LRSchedule): 214 | schedule_type = SCHEDULES[schedule] 215 | schedule = schedule_type(warmup=warmup, t_total=t_total) 216 | else: 217 | if warmup != -1 or t_total != -1: 218 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 219 | "Please specify custom warmup and t_total in _LRSchedule object.") 220 | defaults = dict(lr=lr, schedule=schedule, 221 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 222 | max_grad_norm=max_grad_norm) 223 | super(BertAdam, self).__init__(params, defaults) 224 | 225 | def get_lr(self): 226 | lr = [] 227 | for group in self.param_groups: 228 | for p in group['params']: 229 | state = self.state[p] 230 | if len(state) == 0: 231 | return [0] 232 | lr_scheduled = group['lr'] 233 | lr_scheduled *= group['schedule'].get_lr(state['step']) 234 | lr.append(lr_scheduled) 235 | return lr 236 | 237 | def step(self, closure=None): 238 | """Performs a single optimization step. 239 | 240 | Arguments: 241 | closure (callable, optional): A closure that reevaluates the model 242 | and returns the loss. 243 | """ 244 | loss = None 245 | if closure is not None: 246 | loss = closure() 247 | 248 | for group in self.param_groups: 249 | for p in group['params']: 250 | if p.grad is None: 251 | continue 252 | grad = p.grad.data 253 | if grad.is_sparse: 254 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 255 | 256 | state = self.state[p] 257 | 258 | # State initialization 259 | if len(state) == 0: 260 | state['step'] = 0 261 | # Exponential moving average of gradient values 262 | state['next_m'] = torch.zeros_like(p.data) 263 | # Exponential moving average of squared gradient values 264 | state['next_v'] = torch.zeros_like(p.data) 265 | 266 | next_m, next_v = state['next_m'], state['next_v'] 267 | beta1, beta2 = group['b1'], group['b2'] 268 | 269 | # Add grad clipping 270 | if group['max_grad_norm'] > 0: 271 | clip_grad_norm_(p, group['max_grad_norm']) 272 | 273 | # Decay the first and second moment running average coefficient 274 | # In-place operations to update the averages at the same time 275 | next_m.mul_(beta1).add_(1 - beta1, grad) 276 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 277 | update = next_m / (next_v.sqrt() + group['e']) 278 | 279 | # Just adding the square of the weights to the loss function is *not* 280 | # the correct way of using L2 regularization/weight decay with Adam, 281 | # since that will interact with the m and v parameters in strange ways. 282 | # 283 | # Instead we want to decay the weights in a manner that doesn't interact 284 | # with the m/v parameters. 
This is equivalent to adding the square 285 | # of the weights to the loss with plain (non-momentum) SGD. 286 | if group['weight_decay'] > 0.0: 287 | update += group['weight_decay'] * p.data 288 | 289 | lr_scheduled = group['lr'] 290 | lr_scheduled *= group['schedule'].get_lr(state['step']) 291 | 292 | update_with_lr = lr_scheduled * update 293 | p.data.add_(-update_with_lr) 294 | 295 | state['step'] += 1 296 | 297 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 298 | # No bias correction 299 | # bias_correction1 = 1 - beta1 ** state['step'] 300 | # bias_correction2 = 1 - beta2 ** state['step'] 301 | 302 | return loss 303 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def load_vocab(vocab_file): 29 | """Loads a vocabulary file into a dictionary.""" 30 | vocab = collections.OrderedDict() 31 | index = 0 32 | with open(vocab_file, "r", encoding="utf-8") as reader: 33 | while True: 34 | token = reader.readline() 35 | if not token: 36 | break 37 | token = token.strip() 38 | vocab[token] = index 39 | index += 1 40 | return vocab 41 | 42 | 43 | def whitespace_tokenize(text): 44 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 45 | text = text.strip() 46 | if not text: 47 | return [] 48 | tokens = text.split() 49 | return tokens 50 | 51 | 52 | class BertTokenizer(object): 53 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 54 | 55 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 56 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 57 | """Constructs a BertTokenizer. 58 | 59 | Args: 60 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 61 | do_lower_case: Whether to lower case the input 62 | Only has an effect when do_wordpiece_only=False 63 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 64 | max_len: An artificial maximum length to truncate tokenized sequences to; 65 | Effective maximum length is always the minimum of this 66 | value (if specified) and the underlying BERT model's 67 | sequence length. 68 | never_split: List of tokens which will never be split during tokenization. 69 | Only has an effect when do_wordpiece_only=False 70 | """ 71 | if not os.path.isfile(vocab_file): 72 | raise ValueError( 73 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 74 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 75 | self.vocab = load_vocab(vocab_file) 76 | self.ids_to_tokens = collections.OrderedDict( 77 | [(ids, tok) for tok, ids in self.vocab.items()]) 78 | self.do_basic_tokenize = do_basic_tokenize 79 | if do_basic_tokenize: 80 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 81 | never_split=never_split) 82 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 83 | self.max_len = max_len if max_len is not None else int(1e12) 84 | 85 | def tokenize(self, text): 86 | split_tokens = [] 87 | if self.do_basic_tokenize: 88 | for token in self.basic_tokenizer.tokenize(text): 89 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 90 | split_tokens.append(sub_token) 91 | else: 92 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 93 | return split_tokens 94 | 95 | def convert_tokens_to_ids(self, tokens): 96 | """Converts a sequence of tokens into ids using the vocab.""" 97 | ids = [] 98 | for token in tokens: 99 | ids.append(self.vocab[token]) 100 | if len(ids) > self.max_len: 101 | logger.warning( 102 | "Token indices sequence length is longer than the specified maximum " 103 | " sequence length for this BERT model ({} > {}). Running this" 104 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 105 | ) 106 | return ids 107 | 108 | def convert_ids_to_tokens(self, ids): 109 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 110 | tokens = [] 111 | for i in ids: 112 | tokens.append(self.ids_to_tokens[i]) 113 | return tokens 114 | 115 | def save_vocabulary(self, vocab_path): 116 | """Save the tokenizer vocabulary to a directory or file.""" 117 | index = 0 118 | with open(vocab_path, "w", encoding="utf-8") as writer: 119 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 120 | if index != token_index: 121 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 122 | " Please check that the vocabulary is not corrupted!".format(vocab_path)) 123 | index = token_index 124 | writer.write(token + u'\n') 125 | index += 1 126 | return vocab_path 127 | 128 | @classmethod 129 | def from_pretrained(cls, vocab_path, *inputs, **kwargs): 130 | """ 131 | Instantiate a PreTrainedBertModel from a pre-trained model file. 132 | Download and cache the pre-trained model file if needed. 133 | """ 134 | # redirect to the cache, if necessary 135 | logger.info("loading vocabulary file {}".format(vocab_path)) 136 | 137 | # Instantiate tokenizer. 138 | tokenizer = cls(vocab_path, *inputs, **kwargs) 139 | return tokenizer 140 | 141 | 142 | class BasicTokenizer(object): 143 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 144 | 145 | def __init__(self, 146 | do_lower_case=True, 147 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 148 | """Constructs a BasicTokenizer. 149 | 150 | Args: 151 | do_lower_case: Whether to lower case the input. 152 | """ 153 | self.do_lower_case = do_lower_case 154 | self.never_split = never_split 155 | 156 | def tokenize(self, text): 157 | """Tokenizes a piece of text.""" 158 | text = self._clean_text(text) 159 | # This was added on November 1st, 2018 for the multilingual and Chinese 160 | # models. 
This is also applied to the English models now, but it doesn't 161 | # matter since the English models were not trained on any Chinese data 162 | # and generally don't have any Chinese data in them (there are Chinese 163 | # characters in the vocabulary because Wikipedia does have some Chinese 164 | # words in the English Wikipedia.). 165 | text = self._tokenize_chinese_chars(text) 166 | orig_tokens = whitespace_tokenize(text) 167 | split_tokens = [] 168 | for token in orig_tokens: 169 | if self.do_lower_case and token not in self.never_split: 170 | token = token.lower() 171 | token = self._run_strip_accents(token) 172 | split_tokens.extend(self._run_split_on_punc(token)) 173 | 174 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 175 | return output_tokens 176 | 177 | def _run_strip_accents(self, text): 178 | """Strips accents from a piece of text.""" 179 | text = unicodedata.normalize("NFD", text) 180 | output = [] 181 | for char in text: 182 | cat = unicodedata.category(char) 183 | if cat == "Mn": 184 | continue 185 | output.append(char) 186 | return "".join(output) 187 | 188 | def _run_split_on_punc(self, text): 189 | """Splits punctuation on a piece of text.""" 190 | if text in self.never_split: 191 | return [text] 192 | chars = list(text) 193 | i = 0 194 | start_new_word = True 195 | output = [] 196 | while i < len(chars): 197 | char = chars[i] 198 | if _is_punctuation(char): 199 | output.append([char]) 200 | start_new_word = True 201 | else: 202 | if start_new_word: 203 | output.append([]) 204 | start_new_word = False 205 | output[-1].append(char) 206 | i += 1 207 | 208 | return ["".join(x) for x in output] 209 | 210 | def _tokenize_chinese_chars(self, text): 211 | """Adds whitespace around any CJK character.""" 212 | output = [] 213 | for char in text: 214 | cp = ord(char) 215 | if self._is_chinese_char(cp): 216 | output.append(" ") 217 | output.append(char) 218 | output.append(" ") 219 | else: 220 | output.append(char) 221 | return "".join(output) 222 | 223 | def _is_chinese_char(self, cp): 224 | """Checks whether CP is the codepoint of a CJK character.""" 225 | # This defines a "chinese character" as anything in the CJK Unicode block: 226 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 227 | # 228 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 229 | # despite its name. The modern Korean Hangul alphabet is a different block, 230 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 231 | # space-separated words, so they are not treated specially and handled 232 | # like the all of the other languages. 
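# Illustrative example: ord('中') == 0x4E2D, which falls inside the
# 0x4E00-0x9FFF block checked below, so '中' gets wrapped in spaces by
# _tokenize_chinese_chars, while Katakana 'カ' (U+30AB) and Hangul '한'
# (U+D55C) fall outside every listed range and are left untouched.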
233 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 234 | (cp >= 0x3400 and cp <= 0x4DBF) or # 235 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 236 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 237 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 238 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 239 | (cp >= 0xF900 and cp <= 0xFAFF) or # 240 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 241 | return True 242 | 243 | return False 244 | 245 | def _clean_text(self, text): 246 | """Performs invalid character removal and whitespace cleanup on text.""" 247 | output = [] 248 | for char in text: 249 | cp = ord(char) 250 | if cp == 0 or cp == 0xfffd or _is_control(char): 251 | continue 252 | if _is_whitespace(char): 253 | output.append(" ") 254 | else: 255 | output.append(char) 256 | return "".join(output) 257 | 258 | 259 | class WordpieceTokenizer(object): 260 | """Runs WordPiece tokenization.""" 261 | 262 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 263 | self.vocab = vocab 264 | self.unk_token = unk_token 265 | self.max_input_chars_per_word = max_input_chars_per_word 266 | 267 | def tokenize(self, text): 268 | """Tokenizes a piece of text into its word pieces. 269 | 270 | This uses a greedy longest-match-first algorithm to perform tokenization 271 | using the given vocabulary. 272 | 273 | For example: 274 | input = "unaffable" 275 | output = ["un", "##aff", "##able"] 276 | 277 | Args: 278 | text: A single token or whitespace separated tokens. This should have 279 | already been passed through `BasicTokenizer`. 280 | 281 | Returns: 282 | A list of wordpiece tokens. 283 | """ 284 | 285 | output_tokens = [] 286 | for token in whitespace_tokenize(text): 287 | chars = list(token) 288 | if len(chars) > self.max_input_chars_per_word: 289 | output_tokens.append(self.unk_token) 290 | continue 291 | 292 | is_bad = False 293 | start = 0 294 | sub_tokens = [] 295 | while start < len(chars): 296 | end = len(chars) 297 | cur_substr = None 298 | while start < end: 299 | substr = "".join(chars[start:end]) 300 | if start > 0: 301 | substr = "##" + substr 302 | if substr in self.vocab: 303 | cur_substr = substr 304 | break 305 | end -= 1 306 | if cur_substr is None: 307 | is_bad = True 308 | break 309 | sub_tokens.append(cur_substr) 310 | start = end 311 | 312 | if is_bad: 313 | output_tokens.append(self.unk_token) 314 | else: 315 | output_tokens.extend(sub_tokens) 316 | return output_tokens 317 | 318 | 319 | def _is_whitespace(char): 320 | """Checks whether `chars` is a whitespace character.""" 321 | # \t, \n, and \r are technically contorl characters but we treat them 322 | # as whitespace since they are generally considered as such. 323 | if char == " " or char == "\t" or char == "\n" or char == "\r": 324 | return True 325 | cat = unicodedata.category(char) 326 | if cat == "Zs": 327 | return True 328 | return False 329 | 330 | 331 | def _is_control(char): 332 | """Checks whether `chars` is a control character.""" 333 | # These are technically control characters but we count them as whitespace 334 | # characters. 335 | if char == "\t" or char == "\n" or char == "\r": 336 | return False 337 | cat = unicodedata.category(char) 338 | if cat.startswith("C"): 339 | return True 340 | return False 341 | 342 | 343 | def _is_punctuation(char): 344 | """Checks whether `chars` is a punctuation character.""" 345 | cp = ord(char) 346 | # We treat all non-letter/number ASCII as punctuation. 
347 | # Characters such as "^", "$", and "`" are not in the Unicode 348 | # Punctuation class but we treat them as punctuation anyways, for 349 | # consistency. 350 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 351 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 352 | return True 353 | cat = unicodedata.category(char) 354 | if cat.startswith("P"): 355 | return True 356 | return False 357 | -------------------------------------------------------------------------------- /run_classifier_dataset_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ BERT classification fine-tuning: utilities to work with GLUE tasks """ 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import csv 21 | import json 22 | import collections 23 | import logging 24 | import os 25 | import sys 26 | import numpy as np 27 | 28 | from rule import process 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | #run 'sample.json:0x9ba8a2' 'config.json:0x190174' 'vocab.txt:0xfec2ea' 'optimization.py:0xddfbb4' 'tokenization.py:0xd6a857' 'modeling.py:0x513a98' 'pytorch_model.bin:0x796f6e' 'pytorch_model2.bin:0xb7667b' 'pytorch_model3.bin:0x99124c' 'run_classifier_dataset_utils.py:0x09918f' 'run_classifier.py:0x45023b' 'rule.py:0x4e1774' 'city.txt:0x50dc61' 'railway_station.txt:0x2acf89' 'dishName.txt:0xa346b6' 'province.txt:0x2cbaf9' 'python3 run_classifier.py --test_data sample.json --result_file pred.json' --request-docker-image pytorch/pytorch:latest --request-memory 6g 33 | 34 | 35 | class InputExample(object): 36 | """A single training/test example for simple sequence classification.""" 37 | 38 | def __init__(self, guid, text_a, text_b=None, domain=None, intent=None, slots=None): 39 | """Constructs a InputExample. 40 | 41 | Args: 42 | guid: Unique id for the example. 43 | text_a: string. The untokenized text of the first sequence. For single 44 | sequence tasks, only this sequence must be specified. 45 | text_b: (Optional) string. The untokenized text of the second sequence. 46 | Only must be specified for sequence pair tasks. 47 | label: (Optional) string. The label of the example. This should be 48 | specified for train and dev examples, but not for test examples. 
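            domain: (Optional) string. Domain label of the example, e.g. "weather".
            intent: (Optional) string. Intent label of the example, e.g. "QUERY".
            slots: (Optional) dict. Mapping from slot name (e.g. "location_city")
                to the slot value substring that appears in `text_a`.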
49 | """ 50 | self.guid = guid 51 | self.text_a = text_a 52 | self.text_b = text_b 53 | self.domain = domain 54 | self.intent = intent 55 | self.slots = slots 56 | 57 | 58 | class InputFeatures(object): 59 | """A single set of features of data.""" 60 | 61 | def __init__(self, 62 | input_ids, 63 | input_mask, 64 | segment_ids, 65 | domain_id, 66 | intent_id, 67 | slots_id, 68 | is_real_example=True): 69 | self.input_ids = input_ids 70 | self.input_mask = input_mask 71 | self.segment_ids = segment_ids 72 | self.domain_id = domain_id 73 | self.intent_id = intent_id 74 | self.slots_id = slots_id 75 | self.is_real_example = is_real_example 76 | 77 | 78 | class DataProcessor(object): 79 | """Base class for data converters for sequence classification data sets.""" 80 | 81 | def get_train_examples(self, data_dir): 82 | """Gets a collection of `InputExample`s for the train set.""" 83 | raise NotImplementedError() 84 | 85 | def get_dev_examples(self, data_dir): 86 | """Gets a collection of `InputExample`s for the dev set.""" 87 | raise NotImplementedError() 88 | 89 | def get_test_examples(self, data_dir): 90 | """Gets a collection of `InputExample`s for prediction.""" 91 | raise NotImplementedError() 92 | 93 | def get_labels(self): 94 | """Gets the list of labels for this data set.""" 95 | raise NotImplementedError() 96 | 97 | @classmethod 98 | def _read_tsv(cls, input_file, quotechar=None): 99 | """Reads a tab separated value file.""" 100 | with tf.gfile.Open(input_file, "r") as f: 101 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 102 | lines = [] 103 | for line in reader: 104 | lines.append(line) 105 | return lines 106 | 107 | @classmethod 108 | def _read_json(cls, input_file, quotechar=None): 109 | """Reads pandas csv file.""" 110 | data_json = json.load(open(input_file, encoding='utf8'), object_pairs_hook=collections.OrderedDict) 111 | 112 | return data_json 113 | 114 | 115 | class NLUProcessor(DataProcessor): 116 | """Processor for nlu data set.""" 117 | def __init__(self): 118 | pass 119 | 120 | def get_train_examples(self, train_data_path): 121 | """See base class.""" 122 | set_type = "train" 123 | data = self._read_json(train_data_path) 124 | examples = [] 125 | for (i, line) in enumerate(data): 126 | guid = "%s-%s" % (set_type, i) 127 | text_a, domain, intent, slots = line['text'], line['domain'], line['intent'], line['slots'] 128 | examples.append(InputExample(guid = guid, text_a = text_a, domain = domain, intent = intent, slots = slots)) 129 | 130 | return examples 131 | 132 | def get_dev_examples(self, eval_data_path): 133 | """See base class.""" 134 | pass 135 | 136 | def get_test_examples(self, test_data_path): 137 | """See base class.""" 138 | set_type = "test" 139 | data = self._read_json(test_data_path) 140 | examples = [] 141 | for (i, line) in enumerate(data): 142 | guid = "%s-%s" % (set_type, i) 143 | text_a = line['text'] 144 | examples.append(InputExample(guid = guid, text_a = text_a)) 145 | 146 | return examples 147 | 148 | def get_labels(self): 149 | """See base class.""" 150 | domain = ['app', 'bus', 'cinemas', 'contacts', 'cookbook', 'email', 'epg', 'flight', 'health', 'joke', 'lottery', 'map', 'match', 'message', 'music', 'news', 'novel', 'poetry', 'radio', 'riddle', 'stock', 'story', 'telephone', 'train', 'translation', 'tvchannel', 'video', 'weather', 'website'] 151 | intent = ['CLOSEPRICE_QUERY', 'CREATE', 'DATE_QUERY', 'DEFAULT', 'DIAL', 'DOWNLOAD', 'FORWARD', 'LAUNCH', 'LOOK_BACK', 'NUMBER_QUERY', 'OPEN', 'PLAY', 'POSITION', 'QUERY', 'REPLAY_ALL', 
'REPLY', 'RISERATE_QUERY', 'ROUTE', 'SEARCH', 'SEND', 'SENDCONTACTS', 'TRANSLATION', 'VIEW'] 152 | slotsOri = ['Dest', 'Src', 'absIssue', 'area', 'artist', 'artistRole', 'author', 'awayName', 'category', 'code', 'content', 'datetime_date', 'datetime_time', 'decade', 'dishName', 'dynasty', 'endLoc_area', 'endLoc_city', 'endLoc_poi', 'endLoc_province', 'episode', 'film', 'headNum', 'homeName', 'ingredient', 'keyword', 'location_area', 'location_city', 'location_country', 'location_poi', 'location_province', 'media', 'name', 'payment', 'popularity', 'queryField', 'questionWord', 'receiver', 'relIssue', 'resolution', 'scoreDescr', 'season', 'song', 'startDate_date', 'startDate_time', 'startLoc_area', 'startLoc_city', 'startLoc_poi', 'startLoc_province', 'subfocus', 'tag', 'target', 'teleOperator', 'theatre', 'timeDescr', 'tvchannel', 'type', 'utensil', 'yesterday'] 153 | 154 | slots = [] 155 | slots.append("O") 156 | for slot in slotsOri: 157 | slots.append("B-"+slot) 158 | slots.append("I-"+slot) 159 | 160 | return {'domain':list(domain), 'intent':list(intent), 'slots':slots} 161 | 162 | 163 | def slots_convert(text, slots): 164 | """Convert slots to B-I-O form""" 165 | tokens = ["O"] * len(text) 166 | if slots: 167 | for slot, value in slots.items(): 168 | index = text.find(value) 169 | for i in range(len(value)): 170 | if i == 0: 171 | slot_token = "B-"+slot 172 | else: 173 | slot_token = "I-"+slot 174 | tokens[index+i] = slot_token 175 | 176 | return tokens 177 | 178 | 179 | def convert_examples_to_features(examples, domain_map, intent_map, slots_map, 180 | max_seq_length, tokenizer): 181 | """Converts a single `InputExample` into a single `InputFeatures`.""" 182 | features = [] 183 | for (ex_index, example) in enumerate(examples): 184 | ori_slots = slots_convert(example.text_a, example.slots) 185 | # tokens_a = tokenizer.tokenize(example.text_a) 186 | tokens_a = [] 187 | tokens_slots = [] 188 | for i, word in enumerate(example.text_a): 189 | token = tokenizer.tokenize(word) 190 | tokens_a.extend(token) 191 | if len(token) > 0: 192 | tokens_slots.append(ori_slots[i]) 193 | if not len(tokens_a) == len(tokens_slots): 194 | logger.info("********** Take Care! ***********") 195 | print(tokens_a) 196 | print(tokens_slots) 197 | assert len(tokens_a) == len(tokens_slots) 198 | 199 | # tokens_b = None 200 | # if example.text_b: 201 | # tokens_b = tokenizer.tokenize(example.text_b) 202 | 203 | # if tokens_b: 204 | # # Modifies `tokens_a` and `tokens_b` in place so that the total 205 | # # length is less than the specified length. 206 | # # Account for [CLS], [SEP], [SEP] with "- 3" 207 | # _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 208 | # else: 209 | # Account for [CLS] and [SEP] with "- 2" 210 | if len(tokens_a) > max_seq_length - 2: 211 | tokens_a = tokens_a[0:(max_seq_length - 2)] 212 | tokens_slots = tokens_slots[0:(max_seq_length - 2)] 213 | 214 | # The convention in BERT is: 215 | # (a) For sequence pairs: 216 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 217 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 218 | # (b) For single sequences: 219 | # tokens: [CLS] the dog is hairy . [SEP] 220 | # type_ids: 0 0 0 0 0 0 0 221 | # 222 | # Where "type_ids" are used to indicate whether this is the first 223 | # sequence or the second sequence. The embedding vectors for `type=0` and 224 | # `type=1` were learned during pre-training and are added to the wordpiece 225 | # embedding vector (and position vector). 
This is not *strictly* necessary 226 | # since the [SEP] token unambiguously separates the sequences, but it makes 227 | # it easier for the model to learn the concept of sequences. 228 | # 229 | # For classification tasks, the first vector (corresponding to [CLS]) is 230 | # used as the "sentence vector". Note that this only makes sense because 231 | # the entire model is fine-tuned. 232 | tokens = [] 233 | slots_id = [] 234 | segment_ids = [] 235 | tokens.append("[CLS]") 236 | slots_id.append(slots_map["O"]) 237 | segment_ids.append(0) 238 | for token, slots in zip(tokens_a, tokens_slots): 239 | tokens.append(token) 240 | slots_id.append(slots_map[slots]) 241 | segment_ids.append(0) 242 | tokens.append("[SEP]") 243 | slots_id.append(slots_map["O"]) 244 | segment_ids.append(0) 245 | 246 | # if tokens_b: 247 | # for token in tokens_b: 248 | # tokens.append(token) 249 | # segment_ids.append(1) 250 | # tokens.append("[SEP]") 251 | # segment_ids.append(1) 252 | 253 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 254 | 255 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 256 | # tokens are attended to. 257 | input_mask = [1] * len(input_ids) 258 | 259 | # Zero-pad up to the sequence length. 260 | while len(input_ids) < max_seq_length: 261 | input_ids.append(0) 262 | input_mask.append(0) 263 | segment_ids.append(0) 264 | slots_id.append(slots_map["O"]) 265 | 266 | assert len(input_ids) == max_seq_length 267 | assert len(input_mask) == max_seq_length 268 | assert len(segment_ids) == max_seq_length 269 | assert len(slots_id) == max_seq_length 270 | 271 | domain_id, intent_id = 0, 0 272 | if example.domain: 273 | domain_id = domain_map[example.domain] 274 | if example.intent: 275 | intent_id = intent_map[example.intent] 276 | 277 | # if ex_index < 1: 278 | # tf.logging.info("*** Example ***") 279 | # tf.logging.info("guid: %s" % (example.guid)) 280 | # tf.logging.info("tokens: %s" % " ".join( 281 | # [tokenization.printable_text(x) for x in tokens])) 282 | # tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 283 | # tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 284 | # tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 285 | # tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 286 | 287 | features.append(InputFeatures( 288 | input_ids=input_ids, 289 | input_mask=input_mask, 290 | segment_ids=segment_ids, 291 | domain_id=domain_id, 292 | intent_id=intent_id, 293 | slots_id=slots_id, 294 | is_real_example=True)) 295 | 296 | return features 297 | 298 | 299 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 300 | """Truncates a sequence pair in place to the maximum length.""" 301 | 302 | # This is a simple heuristic which will always truncate the longer sequence 303 | # one token at a time. This makes more sense than truncating an equal percent 304 | # of tokens from each, since if one sequence is very short then each token 305 | # that's truncated likely contains more information than a longer sequence. 
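    # For example, with max_length=8, len(tokens_a)=6 and len(tokens_b)=5, the
    # loop pops from tokens_a (6 > 5), then tokens_b (5 vs. 5), then tokens_a
    # again, stopping once both have length 4 (total 8 <= 8).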
306 | while True: 307 | total_length = len(tokens_a) + len(tokens_b) 308 | if total_length <= max_length: 309 | break 310 | if len(tokens_a) > len(tokens_b): 311 | tokens_a.pop() 312 | else: 313 | tokens_b.pop() 314 | 315 | 316 | def get_slots(slots_id, slots_map, text): 317 | slots = {} 318 | tokens_slots = [] 319 | for i in range(1, min(len(text)+1, len(slots_id))): 320 | tokens_slots.append(slots_map[slots_id[i]]) 321 | 322 | i = 0 323 | while i < len(text): 324 | if not tokens_slots[i] == "O" and tokens_slots[i][:2] == "B-": 325 | slot = tokens_slots[i][2:] 326 | value = [text[i]] 327 | i += 1 328 | while i < len(text) and not tokens_slots[i] == "O" and not tokens_slots[i][:2] == "B-": 329 | value.append(text[i]) 330 | i += 1 331 | slots[slot] = "".join(value) 332 | i -= 1 333 | i += 1 334 | 335 | return slots 336 | 337 | 338 | def write_result(output_predict_file, dic_dir, result, predict_examples, domain_map, intent_map, slots_map): 339 | """ Write result to json file""" 340 | result_json = [] 341 | 342 | domain_map = {v:k for k, v in domain_map.items()} 343 | intent_map = {v:k for k, v in intent_map.items()} 344 | slots_map = {v:k for k, v in slots_map.items()} 345 | 346 | for i, (pred, example) in enumerate(zip(result, predict_examples)): 347 | text = example.text_a 348 | 349 | domain_id = np.argmax(pred["domain"]) 350 | intent_id = np.argmax(pred["intent"]) 351 | slots_id = np.argmax(pred["slots"], axis = -1) 352 | slots = get_slots(slots_id, slots_map, text) 353 | 354 | d = collections.OrderedDict() 355 | d["text"] = text 356 | d["domain"] = domain_map[domain_id] 357 | d["intent"] = intent_map[intent_id] 358 | d["slots"] = slots 359 | result_json.append(d) 360 | 361 | # json.dump(result_json, open(output_predict_file, 'w', encoding = 'utf-8'), ensure_ascii = False, indent = 2) 362 | json.dump(process(result_json, dic_dir), open(output_predict_file, 'w', encoding = 'utf-8'), ensure_ascii = False, indent = 2) 363 | 364 | 365 | processors = { 366 | "nlu": NLUProcessor 367 | } 368 | -------------------------------------------------------------------------------- /rule.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from collections import OrderedDict 5 | from domain_rule import domain_rule 6 | 7 | def cookbook(text, pred, dishName): 8 | slots = {} 9 | for k, v in pred['slots'].items(): 10 | if k == 'dishName' and v not in dishName: 11 | slots['ingredient'] = v 12 | elif k == 'ingredient' and v in dishName and k not in slots: 13 | slots['dishName'] = v 14 | else: 15 | slots[k] = v 16 | 17 | return slots 18 | 19 | 20 | def bus(text, pred): 21 | station = [] 22 | for k, v in pred['slots'].items(): 23 | station.append((v, re.search(v, text).span()[0])) 24 | slots = {} 25 | if len(station) == 1: 26 | slots = {'Dest': station[0][0]} 27 | else: 28 | station.sort(key = lambda x:x[1]) 29 | slots = {'Src':station[0][0], 'Dest':station[1][0]} 30 | 31 | return slots 32 | 33 | 34 | def train(text, pred, province, city, railway_station): 35 | location = set({'startLoc_area', 'endLoc_area', 'startLoc_city', 'endLoc_city', 'startLoc_province', 'endLoc_province', 'startLoc_poi', 'endLoc_poi'}) 36 | slots = {} 37 | for k, v in pred['slots'].items(): 38 | name = k 39 | if k in location: 40 | prefix = k.split('_')[0] 41 | if v in province: 42 | name = prefix + '_province' 43 | elif v in city: 44 | name = prefix + '_city' 45 | elif v in railway_station: 46 | name = prefix + '_area' 47 | else: 48 | name = prefix + '_poi' 
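                # Any location value not found in the province/city/railway-station
                # dictionaries falls back to the generic '<prefix>_poi' slot name.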
49 | slots[name] = v 50 | 51 | return slots 52 | 53 | 54 | def map_pos(text, pred, province, city, railway_station): 55 | location = set({'location_province', 'location_city', 'location_poi', 'location_area'}) 56 | slots = {} 57 | for k, v in pred['slots'].items(): 58 | name = k 59 | if k in location: 60 | prefix = 'location' 61 | if v in province: 62 | name = prefix + '_province' 63 | elif v in city: 64 | name = prefix + '_city' 65 | elif v in railway_station: 66 | name = prefix + '_area' 67 | else: 68 | name = prefix + '_poi' 69 | slots[name] = v 70 | 71 | return slots 72 | 73 | 74 | def process(result, dic_dir): 75 | result = domain_rule(result,dic_dir) 76 | table = { 77 | 'music': [{'PLAY', 'SEARCH'}, {'category', 'song', 'artist'}] , 78 | 'match': [{'QUERY'}, {'datetime_date', 'type', 'homeName', 'name', 'category', 'awayName'}] , 79 | 'joke': [{'QUERY'}, set()] , 80 | 'weather': [{'QUERY'}, {'subfocus', 'datetime_date', 'location_city', 'questionWord'}] , 81 | 'novel': [{'QUERY'}, {'category', 'name', 'popularity', 'author'}] , 82 | 'flight': [{'QUERY'}, {'startLoc_area', 'startLoc_city', 'startDate_date', 'startDate_date', 'startDate_time', 'endLoc_area', 'endLoc_poi', 'startLoc_poi', 'endLoc_city'}] , 83 | 'health': [{'QUERY'}, {'keyword'}] , 84 | 'poetry': [{'QUERY', 'DEFAULT'}, {'keyword', 'queryField', 'dynasty', 'name', 'author'}] , 85 | 'video': [{'QUERY'}, {'season', 'scoreDescr', 'datetime_date', 'tag', 'popularity', 'name', 'artist', 'category', 'timeDescr', 'date', 'decade', 'resolution', 'artistRole', 'payment', 'area', 'episode'}] , 86 | 'epg': [{'QUERY', 'LOOK_BACK'}, {'datetime_date', 'tvchannel', 'name', 'category', 'code', 'datetime_time', 'area'}] , 87 | 'message': [{'SENDCONTACTS', 'VIEW', 'SEND'}, {'content', 'receiver', 'name', 'category', 'teleOperator', 'headNum'}] , 88 | 'contacts': [{'QUERY', 'CREATE'}, {'name', 'code'}] , 89 | 'stock': [{'CLOSEPRICE_QUERY', 'QUERY', 'RISERATE_QUERY'}, {'yesterday', 'name', 'code'}] , 90 | 'radio': [{'LAUNCH'}, {'category', 'location_province', 'name', 'code'}] , 91 | 'telephone': [{'QUERY', 'DIAL'}, {'teleOperator', 'category', 'name'}] , 92 | 'map': [{'POSITION', 'ROUTE'}, {'startLoc_area', 'startLoc_city', 'endLoc_province', 'location_city', 'type', 'endLoc_area', 'location_poi', 'endLoc_poi', 'startLoc_poi', 'location_area', 'location_province', 'endLoc_city'}] , 93 | 'cinemas': [{'QUERY', 'DATE_QUERY'}, {'datetime_date', 'location_city', 'film', 'name', 'theatre', 'timeDescr', 'category', 'datetime_time'}] , 94 | 'lottery': [{'QUERY', 'NUMBER_QUERY'}, {'category', 'relIssue', 'datetime_date', 'absIssue', 'name'}] , 95 | 'story': [{'QUERY'}, {'category'}] , 96 | 'translation': [{'TRANSLATION'}, {'target', 'content'}] , 97 | 'tvchannel': [{'PLAY'}, {'category', 'resolution', 'name', 'code'}] , 98 | 'cookbook': [{'QUERY'}, {'keyword', 'dishName', 'ingredient', 'utensil'}] , 99 | 'bus': [{'QUERY'}, {'Dest', 'Src'}] , 100 | 'email': [{'REPLY', 'LAUNCH', 'REPLAY_ALL', 'FORWARD', 'CREATE', 'SEND'}, {'content', 'name'}] , 101 | 'app': [{'QUERY', 'DOWNLOAD', 'LAUNCH'}, {'name'}] , 102 | 'train': [{'QUERY'}, {'startLoc_area', 'startLoc_city', 'startDate_date', 'endLoc_province', 'startDate_time', 'endLoc_area', 'category', 'startLoc_province', 'startLoc_poi', 'endLoc_city'}] , 103 | 'website': [{'OPEN'}, {'name'}] , 104 | 'news': [{'PLAY'}, {'datetime_date', 'location_city', 'location_country', 'keyword', 'category', 'datetime_time', 'media', 'location_province'}] , 105 | 'riddle': [{'QUERY'}, {'category'}] 106 | } 107 | 108 | 109 
| # app = set([a.strip() for a in open(os.path.join(dic_dir, "app.txt"), 'r').readlines()]) 110 | # website = set([a.strip() for a in open(os.path.join(dic_dir, "website.txt"), 'r').readlines()]) 111 | dishName = set([a.strip() for a in open(os.path.join(dic_dir, "dishName.txt"), 'r', encoding='UTF-8').readlines()]) 112 | province = set([a.strip() for a in open(os.path.join(dic_dir, "province.txt"), 'r', encoding='UTF-8').readlines()]) 113 | city = set([a.strip() for a in open(os.path.join(dic_dir, "city.txt"), 'r', encoding='UTF-8').readlines()]) 114 | railway_station = set([a.strip() for a in open(os.path.join(dic_dir, "railway_station.txt"), 'r', encoding='UTF-8').readlines()]) 115 | 116 | 117 | for i, pred in enumerate(result): 118 | text = pred['text'] 119 | if pred['domain'] == 'app': 120 | name = pred['slots'].get('name', None) 121 | if re.search('下载', text): 122 | result[i]['intent'] = 'DOWNLOAD' 123 | elif re.search('搜索', text) or re.search('找到', text): 124 | result[i]['intent'] = 'QUERY' 125 | elif re.search('打开', text) or re.search('开启', text) or re.search('启动', text) or re.search('进入', text): 126 | result[i]['intent'] = 'LAUNCH' 127 | elif name and (re.match('.*搜.*' + name + '.*', text) or re.match('.*找.*' + name + '.*', text)): 128 | result[i]['intent'] = 'QUERY' 129 | else: 130 | result[i]['intent'] = 'LAUNCH' 131 | if name: 132 | pred['slots']['name'] = name.lower() 133 | 134 | elif pred['domain'] == 'health': 135 | result[i]['intent'] = 'QUERY' 136 | 137 | elif pred['domain'] == 'joke': 138 | result[i]['intent'] = 'QUERY' 139 | 140 | elif pred['domain'] == 'cookbook': 141 | result[i]['slots'] = cookbook(text, pred, dishName) 142 | 143 | elif pred['domain'] == 'email': 144 | if re.search('转发', text): 145 | result[i]['intent'] = 'FORWARD' 146 | elif re.search('全部', text): 147 | result[i]['intent'] = 'REPLAY_ALL' 148 | elif re.search('回复', text) or re.search('答复', text): 149 | result[i]['intent'] = 'REPLY' 150 | elif re.search('发邮件', text): 151 | result[i]['intent'] = 'SEND' 152 | cnt = re.search('说他', text) 153 | if not cnt: 154 | cnt = re.search('说她', text) 155 | if not cnt: 156 | cnt = re.search('叫他', text) 157 | if not cnt: 158 | cnt = re.search('叫她', text) 159 | if not cnt: 160 | cnt = re.search('说', text) 161 | if not cnt: 162 | cnt = re.search('叫', text) 163 | if cnt: 164 | cnt = cnt.span() 165 | if len(text) > cnt[0]: 166 | result[i]['slots']['content'] = text[cnt[1]:] 167 | elif re.search('打开', text) or re.search('开启', text) or re.search('查看', text) or re.search('给我看', text): 168 | result[i]['intent'] = 'LAUNCH' 169 | elif re.search('写', text): 170 | result[i]['intent'] = 'CREATE' 171 | 172 | elif pred['domain'] == 'novel': 173 | result[i]['intent'] = 'QUERY' 174 | 175 | elif pred['domain'] == 'poetry': 176 | # if not pred['slots'] == {}: 177 | # result[i]['intent'] = 'QUERY' 178 | # else: 179 | # result[i]['intent'] = 'DEFAULT' 180 | pass 181 | 182 | elif pred['domain'] == 'radio': 183 | result[i]['intent'] = 'LAUNCH' 184 | result[i]['slots'] = map_pos(text, pred, province, city, railway_station) 185 | 186 | elif pred['domain'] == 'riddle': 187 | result[i]['intent'] = 'QUERY' 188 | 189 | elif pred['domain'] == 'story': 190 | result[i]['intent'] = 'QUERY' 191 | 192 | elif pred['domain'] == 'website': 193 | result[i]['intent'] = 'OPEN' 194 | 195 | elif pred['domain'] == 'weather': 196 | result[i]['intent'] = 'QUERY' 197 | result[i]['slots'] = map_pos(text, pred, province, city, railway_station) 198 | time = set({'datetime_date', 'datetime_time', 'startDate_date', 
'startDate_time'}) 199 | slots = {} 200 | for k, v in pred['slots'].items(): 201 | name = k 202 | if k in time: 203 | name = 'datetime_' + k.split('_')[-1] 204 | slots[name] = v 205 | result[i]['slots'] = slots 206 | #************************ bus, train, flight, map, news ************************# 207 | elif pred['domain'] == 'bus': 208 | if re.search('动车', text) or re.search('高铁', text): 209 | result[i]['domain'] = 'train' 210 | result[i]['intent'] = 'QUERY' 211 | result[i]['slots'] = train(text, pred, province, city, railway_station) 212 | time = set({'datetime_date', 'datetime_time', 'startDate_date', 'startDate_time'}) 213 | slots = {} 214 | for k, v in pred['slots'].items(): 215 | name = k 216 | if k in time: 217 | name = 'startDate_' + k.split('_')[-1] 218 | slots[name] = v 219 | result[i]['slots'] = slots 220 | elif re.search('飞机', text) or re.search('航班', text) or re.search('机票', text): 221 | result[i]['domain'] = 'flight' 222 | result[i]['intent'] = 'QUERY' 223 | result[i]['slots'] = train(text, pred, province, city, railway_station) 224 | else: 225 | result[i]['intent'] = 'QUERY' 226 | result[i]['slots'] = bus(text, pred) 227 | 228 | elif pred['domain'] == 'train': 229 | if re.search('汽车', text): 230 | result[i]['domain'] = 'bus' 231 | result[i]['intent'] = 'QUERY' 232 | result[i]['slots'] = bus(text, pred) 233 | elif re.search('飞机', text) or re.search('航班', text) or re.search('机票', text): 234 | result[i]['domain'] = 'flight' 235 | result[i]['intent'] = 'QUERY' 236 | result[i]['slots'] = train(text, pred, province, city, railway_station) 237 | else: 238 | result[i]['intent'] = 'QUERY' 239 | result[i]['slots'] = train(text, pred, province, city, railway_station) 240 | time = set({'datetime_date', 'datetime_time', 'startDate_date', 'startDate_time'}) 241 | slots = {} 242 | for k, v in pred['slots'].items(): 243 | name = k 244 | if k in time: 245 | name = 'startDate_' + k.split('_')[-1] 246 | slots[name] = v 247 | result[i]['slots'] = slots 248 | 249 | elif pred['domain'] == 'flight': 250 | if re.search('动车', text) or re.search('高铁', text): 251 | result[i]['domain'] = 'train' 252 | result[i]['intent'] = 'QUERY' 253 | result[i]['slots'] = train(text, pred, province, city, railway_station) 254 | time = set({'datetime_date', 'datetime_time', 'startDate_date', 'startDate_time'}) 255 | slots = {} 256 | for k, v in pred['slots'].items(): 257 | name = k 258 | if k in time: 259 | name = 'startDate_' + k.split('_')[-1] 260 | slots[name] = v 261 | result[i]['slots'] = slots 262 | elif re.search('汽车', text): 263 | result[i]['domain'] = 'bus' 264 | result[i]['intent'] = 'QUERY' 265 | result[i]['slots'] = bus(text, pred) 266 | else: 267 | result[i]['intent'] = 'QUERY' 268 | result[i]['slots'] = train(text, pred, province, city, railway_station) 269 | 270 | elif pred['domain'] == 'map': 271 | if pred['intent'] == 'POSITION': 272 | result[i]['slots'] = map_pos(text, pred, province, city, railway_station) 273 | elif pred['intent'] == 'ROUTE': 274 | result[i]['slots'] = train(text, pred, province, city, railway_station) 275 | 276 | elif pred['domain'] == 'news': 277 | result[i]['intent'] = 'PLAY' 278 | result[i]['slots'] = map_pos(text, pred, province, city, railway_station) 279 | time = set({'datetime_date', 'datetime_time', 'startDate_date', 'startDate_time'}) 280 | slots = {} 281 | for k, v in pred['slots'].items(): 282 | name = k 283 | if k in time: 284 | name = 'datetime_' + k.split('_')[-1] 285 | slots[name] = v 286 | result[i]['slots'] = slots 287 | 288 | 
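        # Note: the bus/train/flight/map/news branches above share two rule helpers:
        # train() and map_pos() re-type location slots against the province/city/
        # railway-station dictionaries, while the inline `time` remapping normalizes
        # datetime_* vs. startDate_* slot names for the target domain.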
#***********************************************************************# 289 | 290 | elif pred['domain'] == 'translation': 291 | result[i]['intent'] = 'TRANSLATION' 292 | 293 | elif pred['domain'] == 'cinemas': 294 | if re.search('什么时候', text) or re.search('何时', text): 295 | result[i]['intent'] = 'DATE_QUERY' 296 | else: 297 | result[i]['intent'] = 'QUERY' 298 | time = set({'datetime_date', 'datetime_time', 'startDate_date', 'startDate_time'}) 299 | slots = {} 300 | for k, v in pred['slots'].items(): 301 | name = k 302 | if k in time: 303 | name = 'datetime_' + k.split('_')[-1] 304 | slots[name] = v 305 | result[i]['slots'] = slots 306 | 307 | elif pred['domain'] == 'video': 308 | result[i]['intent'] = 'QUERY' 309 | 310 | elif pred['domain'] == 'contacts': 311 | if re.search('新建', text) or re.search('添加', text): 312 | result[i]['intent'] = 'CREATE' 313 | else: 314 | result[i]['intent'] = 'QUERY' 315 | 316 | elif pred['domain'] == 'telephone': 317 | telep = ['移动', '联通', '电信'] 318 | # if re.search('呼叫', text) or re.search('打', text) or re.search('拨', text): 319 | # result[i]['intent'] = 'DIAL' 320 | # else: 321 | # result[i]['intent'] = 'QUERY' 322 | if re.search('hello word', text.lower()): 323 | span = re.search('hello word', text.lower()).span() 324 | result[i]['slots']['name'] = text[span[0]:span[1]] 325 | for t in telep: 326 | if 'teleOperator' in result[i]['slots']: 327 | break 328 | if re.search(t, text): 329 | span = re.search(t, text).span() 330 | result[i]['slots']['teleOperator'] = text[span[0]:span[1]] 331 | 332 | 333 | elif pred['domain'] == 'message': 334 | # if re.search('查看', text): 335 | # result[i]['intent'] = 'VIEW' 336 | # elif re.search('电话', text) or re.search('号码', text): 337 | # result[i]['intent'] = 'SENDCONTACTS' 338 | # elif re.search('短信', text) or re.search('消息', text) or re.search('简讯', text) or re.search('信息', text) or re.search('短讯', text): 339 | # result[i]['intent'] = 'SEND' 340 | # cnt = re.search('说他', text) 341 | # if not cnt: 342 | # cnt = re.search('说她', text) 343 | # if not cnt: 344 | # cnt = re.search('叫他', text) 345 | # if not cnt: 346 | # cnt = re.search('叫她', text) 347 | # if not cnt: 348 | # cnt = re.search('说', text) 349 | # if not cnt: 350 | # cnt = re.search('叫', text) 351 | # if cnt and 'content' not in result[i]['slots']: 352 | # cnt = cnt.span() 353 | # if len(text) > cnt[0]: 354 | # result[i]['slots']['content'] = text[cnt[1]:] 355 | pass 356 | 357 | elif pred['domain'] == 'tvchannel': 358 | result[i]['intent'] = 'PLAY' 359 | 360 | elif pred['domain'] == 'epg': 361 | if re.search('回放', text) or re.search('回看', text): 362 | result[i]['intent'] = 'LOOK_BACK' 363 | else: 364 | result[i]['intent'] = 'QUERY' 365 | 366 | elif pred['domain'] == 'lottery': 367 | if pred['slots'] == {}: 368 | result[i]['intent'] = 'QUERY' 369 | else: 370 | result[i]['intent'] = 'NUMBER_QUERY' 371 | 372 | elif pred['domain'] == 'music': 373 | # if re.search('有什么', text) or re.search('搜索', text) or re.search('找一首', text) or re.search('查一下', text): 374 | # result[i]['intent'] = 'SEARCH' 375 | # else: 376 | # result[i]['intent'] = 'PLAY' 377 | pass 378 | 379 | elif pred['domain'] == 'stock': 380 | # yesterday = ['昨天', '昨日'] 381 | # if re.search('收盘价', text): 382 | # result[i]['intent'] = 'CLOSEPRICE_QUERY' 383 | # elif re.search('涨', text) and re.search('跌', text): 384 | # result[i]['intent'] = 'RISERATE_QUERY' 385 | # elif pred['intent'] not in table[pred['domain']][0]: 386 | # result[i]['intent'] = 'QUERY' 387 | 388 | # for t in yesterday: 389 | # if 'yesterday' in 
result[i]['slots']: 390 | # break 391 | # if re.search(t, text): 392 | # span = re.search(t, text).span() 393 | # result[i]['slots']['yesterday'] = text[span[0]:span[1]] 394 | pass 395 | 396 | elif pred['domain'] == 'match': 397 | result[i]['intent'] = 'QUERY' 398 | 399 | slots = OrderedDict() 400 | for k, v in pred['slots'].items(): 401 | if k in table[pred['domain']][1]: 402 | slots[k] = v 403 | 404 | result[i]['slots'] = slots 405 | 406 | return result 407 | 408 | if __name__ == '__main__': 409 | result = json.load(open('result/test_result.json', encoding = 'utf8'), object_pairs_hook = OrderedDict) 410 | result = process(result, 'dataSet/dic') 411 | json.dump(result, open('result/test_result.json', 'w'), ensure_ascii = False, indent = 2) 412 | -------------------------------------------------------------------------------- /dataSet/dic/railway_station.txt: -------------------------------------------------------------------------------- 1 | 阿城 2 | 阿尔山 3 | 阿金 4 | 阿克苏 5 | 阿克陶 6 | 阿拉山口 7 | 阿里河 8 | 阿龙山 9 | 阿木尔 10 | 阿图什 11 | 安达 12 | 安德 13 | 安广 14 | 安化 15 | 安家 16 | 安康 17 | 安口窑 18 | 安龙 19 | 安陆 20 | 安平 21 | 安庆 22 | 安庆西 23 | 安仁 24 | 鞍山 25 | 鞍山西 26 | 潜山 27 | 安顺 28 | 安亭北 29 | 安图 30 | 安溪 31 | 安阳 32 | 安阳东 33 | 昂昂溪 34 | 鳌江 35 | 敖力布告 36 | 巴楚 37 | 八达岭 38 | 巴东 39 | 八角台 40 | 巴林 41 | 八面城 42 | 八面通 43 | 巴山 44 | 八仙筒 45 | 鲅鱼圈 46 | 巴中 47 | 霸州 48 | 白城 49 | 白河 50 | 白河东 51 | 白河县 52 | 白涧 53 | 白奎堡 54 | 白狼 55 | 白泉 56 | 百色 57 | 白沙 58 | 白山市 59 | 白石山 60 | 白水江 61 | 白音胡硕 62 | 白银市 63 | 白音他拉 64 | 白银西 65 | 宝坻 66 | 保定 67 | 保定东 68 | 宝华山 69 | 宝鸡 70 | 宝鸡南 71 | 保康 72 | 宝林 73 | 宝龙山 74 | 宝清 75 | 宝泉岭 76 | 包头 77 | 包头东 78 | 北安 79 | 北碚 80 | 北戴河 81 | 北海 82 | 北滘 83 | 北京 84 | 北京北 85 | 北京东 86 | 北京南 87 | 北京西 88 | 北流 89 | 北马圈子 90 | 北票南 91 | 北沈阳北 92 | 北台 93 | 北屯市 94 | 背荫河 95 | 栟茶 96 | 本溪 97 | 本溪湖 98 | 蚌埠 99 | 蚌埠南 100 | 笔架山 101 | 碧江 102 | 滨海 103 | 滨海北 104 | 滨江 105 | 博鳌 106 | 博克图 107 | 博乐 108 | 勃利 109 | 博山 110 | 泊头 111 | 博兴 112 | 亳州 113 | 布海 114 | 布列开 115 | 蔡家沟 116 | 蔡家坡 117 | 苍南 118 | 苍石 119 | 沧州 120 | 沧州西 121 | 草海 122 | 草河口 123 | 草市 124 | 曹县 125 | 册亨 126 | 岑溪 127 | 茶陵 128 | 茶陵南 129 | 柴岗 130 | 柴河 131 | 长春 132 | 长春南 133 | 长春西 134 | 常德 135 | 长甸 136 | 长葛 137 | 昌乐 138 | 昌黎 139 | 长岭子 140 | 常平 141 | 昌平 142 | 昌平北 143 | 长沙 144 | 长沙南 145 | 长山屯 146 | 长寿 147 | 长寿北 148 | 长汀 149 | 长汀镇 150 | 昌图 151 | 昌图西 152 | 长兴 153 | 长兴南 154 | 长阳 155 | 长垣 156 | 长征 157 | 长治 158 | 长治北 159 | 常州 160 | 常州北 161 | 巢湖 162 | 潮汕 163 | 潮阳 164 | 朝阳镇 165 | 潮州 166 | 晨明 167 | 辰清 168 | 辰溪 169 | 陈相屯 170 | 郴州 171 | 郴州西 172 | 承德 173 | 承德东 174 | 成都 175 | 成都东 176 | 成都南 177 | 成高子 178 | 城固 179 | 成吉思汗 180 | 城阳 181 | 赤壁 182 | 赤壁北 183 | 赤峰 184 | 赤峰西 185 | 池州 186 | 重庆 187 | 重庆北 188 | 重庆南 189 | 崇左 190 | 楚山 191 | 滁州 192 | 滁州北 193 | 春湾 194 | 春阳 195 | 慈利 196 | 磁山 197 | 磁县 198 | 磁窑 199 | 嵯岗 200 | 大安 201 | 大安北 202 | 大巴 203 | 大堡 204 | 大埔 205 | 大成 206 | 大关 207 | 大官屯 208 | 大红旗 209 | 大虎山 210 | 达家沟 211 | 大理 212 | 大荔 213 | 大连 214 | 大连北 215 | 大林 216 | 大陆号 217 | 大平房 218 | 大庆 219 | 大青沟 220 | 大庆西 221 | 大石桥 222 | 大石头 223 | 大石寨 224 | 大同 225 | 大屯 226 | 大武口 227 | 大兴 228 | 大兴沟 229 | 大雁 230 | 大杨树 231 | 大冶北 232 | 大营 233 | 大英东 234 | 大营镇 235 | 大营子 236 | 达州 237 | 大竹园 238 | 带岭 239 | 代县 240 | 岱岳 241 | 丹东 242 | 丹凤 243 | 丹徒 244 | 丹霞山 245 | 丹阳 246 | 丹阳北 247 | 砀山 248 | 当阳 249 | 到保 250 | 刀尔登 251 | 道清 252 | 道州 253 | 德安 254 | 德伯斯 255 | 德昌 256 | 得耳布尔 257 | 德惠 258 | 德惠西 259 | 德令哈 260 | 德清 261 | 德清西 262 | 德阳 263 | 德州 264 | 德州东 265 | 登沙河 266 | 灯塔 267 | 邓州 268 | 滴道 269 | 低窝铺 270 | 低庄 271 | 甸心 272 | 定边 273 | 定南 274 | 定陶 275 | 定西 276 | 定襄 277 | 定远 278 | 定州 279 | 定州东 280 | 东安东 281 | 东边井 282 | 东戴河 283 | 东二道河 284 | 东方 285 | 东方红 286 | 东丰 287 | 
东莞 288 | 东莞东 289 | 东光 290 | 东海 291 | 东海县 292 | 东津 293 | 东京城 294 | 东来 295 | 东明村 296 | 东明县 297 | 东升 298 | 东胜 299 | 东胜西 300 | 东台 301 | 东通化 302 | 东乡 303 | 东辛庄 304 | 东营 305 | 东淤地 306 | 东镇 307 | 东至 308 | 东庄 309 | 豆罗 310 | 杜家 311 | 都江堰 312 | 独山 313 | 都匀 314 | 对青山 315 | 兑镇 316 | 敦化 317 | 敦煌 318 | 峨边 319 | 额济纳 320 | 峨眉 321 | 鄂州 322 | 鄂州东 323 | 恩施 324 | 二道湾 325 | 二连 326 | 二龙 327 | 二龙山屯 328 | 二密河 329 | 发耳 330 | 范家屯 331 | 繁峙 332 | 防城港北 333 | 肥东 334 | 费县 335 | 汾阳 336 | 分宜 337 | 丰城 338 | 丰城南 339 | 丰都 340 | 奉化 341 | 凤凰城 342 | 丰乐镇 343 | 风陵渡 344 | 丰顺 345 | 冯屯 346 | 凤县 347 | 凤阳 348 | 丰镇 349 | 凤州 350 | 佛山 351 | 福安 352 | 富川 353 | 福鼎 354 | 富海 355 | 福海 356 | 富锦 357 | 富拉尔基 358 | 福利屯 359 | 涪陵 360 | 涪陵北 361 | 阜南 362 | 抚宁 363 | 阜宁 364 | 福清 365 | 福泉 366 | 抚顺 367 | 抚顺北 368 | 扶绥 369 | 富县 370 | 富县东 371 | 阜新 372 | 阜阳 373 | 富裕 374 | 扶余 375 | 扶余北 376 | 富源 377 | 抚远 378 | 抚州 379 | 福州 380 | 抚州东 381 | 福州南 382 | 嘎什甸子 383 | 盖州 384 | 盖州西 385 | 干沟 386 | 甘谷 387 | 甘河 388 | 甘洛 389 | 甘旗卡 390 | 甘泉 391 | 甘泉北 392 | 赶水 393 | 赣州 394 | 高安 395 | 高碑店 396 | 高碑店东 397 | 藁城 398 | 高村 399 | 皋兰 400 | 高密 401 | 高平 402 | 高桥镇 403 | 高山子 404 | 高台 405 | 高滩 406 | 高兴 407 | 高邑 408 | 高邑西 409 | 高州 410 | 葛店南 411 | 格尔木 412 | 葛根庙 413 | 革镇堡 414 | 根河 415 | 公庙子 416 | 工农湖 417 | 弓棚子 418 | 共青城 419 | 巩义 420 | 巩义南 421 | 公营子 422 | 公主岭 423 | 公主岭南 424 | 沟帮子 425 | 固安 426 | 谷城 427 | 古城镇 428 | 古东 429 | 孤家子 430 | 古交 431 | 古浪 432 | 古莲 433 | 固始 434 | 古田 435 | 固原 436 | 菇园 437 | 古镇 438 | 固镇 439 | 瓜州 440 | 关林 441 | 灌水 442 | 官厅 443 | 官厅西 444 | 冠豸山 445 | 官字井 446 | 广安 447 | 广安南 448 | 广德 449 | 广汉 450 | 光明城 451 | 广宁寺 452 | 广水 453 | 广通北 454 | 广元 455 | 光泽 456 | 广州 457 | 广州北 458 | 广州东 459 | 广州南 460 | 贵定 461 | 贵定南 462 | 贵港 463 | 桂林 464 | 桂林北 465 | 归流河 466 | 桂平 467 | 贵溪 468 | 贵阳 469 | 郭家店 470 | 郭磊庄 471 | 果松 472 | 涡阳 473 | 哈尔滨 474 | 哈尔滨东 475 | 哈尔滨西 476 | 哈拉海 477 | 哈拉苏 478 | 蛤蟆塘 479 | 哈密 480 | 海安县 481 | 海北 482 | 海城 483 | 海城西 484 | 海口 485 | 海口东 486 | 海拉尔 487 | 海林 488 | 海龙 489 | 海伦 490 | 海宁 491 | 海宁西 492 | 海石湾 493 | 海坨子 494 | 海湾 495 | 海阳 496 | 韩城 497 | 汉川 498 | 寒葱沟 499 | 邯郸 500 | 邯郸东 501 | 汉沽 502 | 涵江 503 | 汉口 504 | 寒岭 505 | 汉寿 506 | 汉阴 507 | 汉源 508 | 汉中 509 | 杭州 510 | 杭州东 511 | 浩良河 512 | 鹤北 513 | 鹤壁 514 | 鹤壁东 515 | 河边 516 | 合川 517 | 河唇 518 | 合肥 519 | 合肥北城 520 | 合肥南 521 | 鹤岗 522 | 河津 523 | 和静 524 | 河口北 525 | 河口南 526 | 河口前 527 | 鹤立 528 | 和龙 529 | 和平 530 | 合浦 531 | 鹤庆 532 | 和什托洛盖 533 | 贺胜桥东 534 | 和田 535 | 合阳 536 | 河源 537 | 菏泽 538 | 贺州 539 | 黑河 540 | 黑井 541 | 黑水 542 | 黑台 543 | 横道河子 544 | 横峰 545 | 横沟桥东 546 | 衡南 547 | 衡山 548 | 衡山西 549 | 衡水 550 | 衡阳 551 | 衡阳东 552 | 红安 553 | 红安西 554 | 红光镇 555 | 红果 556 | 洪河 557 | 红花沟 558 | 宏庆 559 | 红山 560 | 红寺堡 561 | 洪洞 562 | 洪洞西 563 | 红星 564 | 红兴隆 565 | 红彦 566 | 侯马 567 | 侯马西 568 | 鲘门 569 | 呼和浩特 570 | 呼和浩特东 571 | 湖口 572 | 呼兰 573 | 虎林 574 | 葫芦岛 575 | 葫芦岛北 576 | 虎门 577 | 虎什哈 578 | 虎石台 579 | 呼源 580 | 呼中 581 | 湖州 582 | 华城 583 | 化德 584 | 花湖 585 | 华家 586 | 桦林 587 | 桦南 588 | 花桥 589 | 华容 590 | 华容东 591 | 华容南 592 | 华山 593 | 华山北 594 | 花山南 595 | 华蓥 596 | 花园 597 | 化州 598 | 淮安 599 | 淮北 600 | 淮滨 601 | 怀化 602 | 淮南 603 | 淮南东 604 | 怀仁 605 | 怀仁东 606 | 怀柔 607 | 怀柔北 608 | 换新天 609 | 黄柏 610 | 潢川 611 | 黄村 612 | 黄冈 613 | 黄冈东 614 | 黄冈西 615 | 皇姑屯 616 | 黄花筒 617 | 黄口 618 | 黄陵 619 | 黄梅 620 | 黄泥河 621 | 黄山 622 | 黄石 623 | 黄石北 624 | 黄石东 625 | 黄松甸 626 | 湟源 627 | 黄州 628 | 惠东 629 | 惠农 630 | 惠山 631 | 会同 632 | 徽县 633 | 惠州 634 | 惠州南 635 | 惠州西 636 | 浑河 637 | 霍尔果斯 638 | 获嘉 639 | 火连寨 640 | 霍林郭勒 641 | 霍邱 642 | 霍州 643 | 霍州东 644 | 吉安 645 | 集安 646 | 鸡东 647 | 鸡冠山 648 | 纪家沟 649 | 吉林 650 | 济南 651 | 济南东 652 | 济南西 653 | 济宁 654 | 集宁南 655 | 稷山 656 | 吉首 657 | 吉舒 658 | 吉文 659 | 鸡西 660 | 绩溪县 661 | 蓟县 662 | 济源 663 | 嘉峰 664 | 
加格达奇 665 | 佳木斯 666 | 嘉善 667 | 嘉善南 668 | 嘉祥 669 | 夹心子 670 | 嘉兴 671 | 嘉兴南 672 | 嘉峪关 673 | 建昌 674 | 建湖 675 | 建宁县北 676 | 建瓯 677 | 建三江 678 | 建设 679 | 建始 680 | 建水 681 | 建阳 682 | 简阳 683 | 江东 684 | 江都 685 | 江华 686 | 姜家 687 | 江津 688 | 将乐 689 | 江门 690 | 江宁 691 | 江桥 692 | 江山 693 | 江所田 694 | 姜堰 695 | 江永 696 | 江油 697 | 江源 698 | 交城 699 | 蛟河 700 | 角美 701 | 胶州 702 | 胶州北 703 | 焦作 704 | 焦作东 705 | 界首市 706 | 介休 707 | 介休东 708 | 揭阳 709 | 金宝屯 710 | 金昌 711 | 晋城 712 | 晋城北 713 | 金城江 714 | 锦河 715 | 金华 716 | 晋江 717 | 金坑 718 | 金山北 719 | 金山屯 720 | 进贤 721 | 进贤南 722 | 缙云 723 | 金寨 724 | 金杖子 725 | 晋中 726 | 晋州 727 | 金州 728 | 锦州 729 | 锦州南 730 | 靖边 731 | 景德镇 732 | 井店 733 | 井冈山 734 | 静海 735 | 精河 736 | 精河南 737 | 荆门 738 | 井南 739 | 经棚 740 | 京山 741 | 景泰 742 | 镜铁山 743 | 井陉 744 | 靖远 745 | 荆州 746 | 靖州 747 | 九江 748 | 酒泉 749 | 九三 750 | 九台 751 | 九台南 752 | 巨宝 753 | 莒南 754 | 句容西 755 | 莒县 756 | 巨野 757 | 鄄城 758 | 峻德 759 | 军粮城北 760 | 喀什 761 | 开安 762 | 开封 763 | 开江 764 | 凯里 765 | 开鲁 766 | 开通 767 | 开原 768 | 开原西 769 | 康城 770 | 康金井 771 | 康庄 772 | 克东 773 | 克拉玛依 774 | 岢岚 775 | 克山 776 | 克一 777 | 库车 778 | 库都尔 779 | 库尔勒 780 | 库伦 781 | 宽甸 782 | 奎山 783 | 葵潭 784 | 奎屯 785 | 昆明 786 | 昆山 787 | 昆山南 788 | 昆阳 789 | 拉古 790 | 拉哈 791 | 拉林 792 | 拉萨 793 | 来宾 794 | 来宾北 795 | 濑湍 796 | 莱芜东 797 | 莱西 798 | 涞源 799 | 来舟 800 | 蓝村 801 | 兰岗 802 | 兰考 803 | 兰棱 804 | 兰溪 805 | 兰州 806 | 兰州西 807 | 廊坊 808 | 廊坊北 809 | 朗乡 810 | 老边 811 | 老府 812 | 老莱 813 | 老羊壕 814 | 老营 815 | 乐昌 816 | 乐都 817 | 乐平市 818 | 乐山北 819 | 耒阳 820 | 耒阳西 821 | 雷州 822 | 冷水 823 | 黎城 824 | 利川 825 | 离堆公园 826 | 李家 827 | 丽江 828 | 醴陵 829 | 醴陵东 830 | 里木店 831 | 李石寨 832 | 梨树镇 833 | 溧水 834 | 丽水柳园 835 | 黎塘 836 | 黎塘西 837 | 澧县 838 | 溧阳 839 | 廉江 840 | 连江 841 | 莲江口 842 | 连山关 843 | 涟源 844 | 连云港 845 | 连云港东 846 | 两家 847 | 亮甲店 848 | 梁平 849 | 梁山 850 | 聊城 851 | 辽阳 852 | 辽源 853 | 辽中 854 | 林东 855 | 临汾 856 | 临汾西 857 | 临海 858 | 林海 859 | 临河 860 | 临江 861 | 林口 862 | 临澧 863 | 临清 864 | 林西 865 | 临湘 866 | 临沂 867 | 临沂北 868 | 临颍 869 | 林源 870 | 临泽 871 | 临淄 872 | 灵宝 873 | 灵宝西 874 | 灵璧 875 | 凌海 876 | 零陵 877 | 灵丘 878 | 灵石 879 | 灵石东 880 | 陵水 881 | 灵武 882 | 凌源 883 | 凌源东 884 | 柳河 885 | 刘家店 886 | 刘家河 887 | 柳林南 888 | 六盘水 889 | 柳树屯 890 | 柳园南 891 | 六枝 892 | 柳州 893 | 隆昌 894 | 龙川 895 | 隆化 896 | 龙华 897 | 龙嘉 898 | 龙江 899 | 龙井 900 | 龙里 901 | 龙南 902 | 龙山镇 903 | 龙市 904 | 陇西 905 | 陇县 906 | 龙岩 907 | 龙游 908 | 龙镇 909 | 龙爪沟 910 | 娄底 911 | 六安 912 | 芦潮港 913 | 潞城 914 | 陆川 915 | 鹿道 916 | 鲁番 917 | 陆丰 918 | 禄丰南 919 | 六合镇 920 | 庐江 921 | 路口铺 922 | 陆良 923 | 卢龙 924 | 庐山 925 | 鲁山 926 | 露水河 927 | 芦台 928 | 芦溪 929 | 鹿寨 930 | 滦河 931 | 滦平 932 | 滦县 933 | 略阳 934 | 轮台 935 | 罗城 936 | 漯河 937 | 漯河西 938 | 罗江 939 | 罗平 940 | 罗山 941 | 洛阳 942 | 洛阳东 943 | 洛阳龙门 944 | 罗源 945 | 绿化 946 | 吕梁 947 | 旅顺 948 | 马鞍山 949 | 麻城 950 | 麻城北 951 | 马莲河 952 | 马林 953 | 玛纳斯 954 | 玛纳斯湖 955 | 马桥河 956 | 马三家 957 | 麻山 958 | 麻尾 959 | 麻阳 960 | 满归 961 | 满洲里 962 | 毛坝 963 | 毛坝关 964 | 茅草坪 965 | 帽儿山 966 | 茂林 967 | 茂名 968 | 茂名东 969 | 梅河口 970 | 美兰 971 | 眉山 972 | 美溪 973 | 梅州 974 | 猛洞河 975 | 孟家岗 976 | 蒙自 977 | 蒙自北 978 | 汨罗 979 | 汨罗东 980 | 米沙子 981 | 密山 982 | 米易 983 | 密云北 984 | 米脂 985 | 渑池 986 | 渑池南 987 | 免渡河 988 | 冕宁 989 | 勉县 990 | 绵阳 991 | 庙岭 992 | 庙山 993 | 闽清 994 | 民权 995 | 明安 996 | 明城 997 | 明港 998 | 明港东 999 | 明光 1000 | 明水河 1001 | 明珠 1002 | 磨刀石 1003 | 莫尔道嘎 1004 | 漠河 1005 | 墨玉 1006 | 牡丹江 1007 | 木里图 1008 | 穆棱 1009 | 那曲 1010 | 乃林 1011 | 奈曼 1012 | 南岔 1013 | 南昌 1014 | 南昌西 1015 | 南城 1016 | 南充 1017 | 南仇 1018 | 南丹 1019 | 南芬 1020 | 南丰 1021 | 南宫东 1022 | 南关岭 1023 | 南湖东 1024 | 南京 1025 | 南靖 1026 | 南京南 1027 | 南口 1028 | 南口前 1029 | 南朗 1030 | 南木 1031 | 南宁 1032 | 南平 1033 | 南平南 1034 | 南桥 1035 | 南台 1036 | 南通 1037 | 南头 1038 | 南翔北 1039 | 南雄 1040 | 南阳 1041 | 
南峪 1042 | 南杂木 1043 | 南召 1044 | 讷河 1045 | 内江 1046 | 内乡 1047 | 嫩江 1048 | 能家 1049 | 泥河子 1050 | 尼勒克 1051 | 尼木 1052 | 碾子山 1053 | 娘子关 1054 | 宁安 1055 | 宁波 1056 | 宁德 1057 | 宁国 1058 | 宁海 1059 | 宁家 1060 | 宁陵县 1061 | 宁明 1062 | 宁武 1063 | 宁乡 1064 | 牛家 1065 | 牛心台 1066 | 农安 1067 | 盘关 1068 | 潘家店 1069 | 盘锦 1070 | 盘锦北 1071 | 磐石 1072 | 攀枝花 1073 | 泡子 1074 | 裴德 1075 | 蓬安 1076 | 彭山 1077 | 彭水 1078 | 彭泽 1079 | 皮口 1080 | 皮山 1081 | 郫县 1082 | 郫县西 1083 | 邳州 1084 | 偏岭 1085 | 瓢儿屯 1086 | 平安 1087 | 平安驿 1088 | 平安镇 1089 | 屏边 1090 | 平顶山 1091 | 平顶山西 1092 | 平房 1093 | 平岗 1094 | 平关 1095 | 平果 1096 | 平凉 1097 | 平南南 1098 | 平泉 1099 | 平山 1100 | 平社 1101 | 坪石 1102 | 平台 1103 | 平田 1104 | 平旺 1105 | 凭祥 1106 | 萍乡 1107 | 萍乡北 1108 | 平型关 1109 | 平洋 1110 | 平遥 1111 | 平遥古城 1112 | 平邑 1113 | 平原 1114 | 平庄 1115 | 平庄南 1116 | 普安 1117 | 蒲城 1118 | 蒲城东 1119 | 普兰店 1120 | 普宁 1121 | 莆田 1122 | 普湾 1123 | 普雄 1124 | 蕲春 1125 | 祁东 1126 | 祁家堡 1127 | 綦江 1128 | 七里河 1129 | 祁门 1130 | 齐齐哈尔 1131 | 岐山 1132 | 戚墅堰 1133 | 七台河 1134 | 旗下营 1135 | 祁县 1136 | 祁县东 1137 | 祁阳 1138 | 乾安 1139 | 迁安 1140 | 前锋 1141 | 千河 1142 | 潜江 1143 | 黔江 1144 | 前进镇 1145 | 前磨头 1146 | 前山 1147 | 前卫 1148 | 千阳 1149 | 桥头 1150 | 秦都 1151 | 秦皇岛 1152 | 秦家 1153 | 秦家庄 1154 | 秦岭 1155 | 沁县 1156 | 沁阳 1157 | 钦州 1158 | 钦州东 1159 | 庆安 1160 | 青城山 1161 | 青岛 1162 | 青岛北 1163 | 庆丰 1164 | 清河 1165 | 清河城 1166 | 清河门 1167 | 清华园 1168 | 清涧县 1169 | 青龙山 1170 | 青山 1171 | 庆盛 1172 | 清水 1173 | 青田 1174 | 青铜峡 1175 | 青县 1176 | 清徐 1177 | 清原 1178 | 清远 1179 | 青州市 1180 | 琼海 1181 | 曲阜 1182 | 曲阜东 1183 | 曲靖 1184 | 渠旧 1185 | 渠黎 1186 | 曲水县 1187 | 渠县 1188 | 衢州 1189 | 全椒 1190 | 泉阳 1191 | 泉州 1192 | 泉州东 1193 | 全州南 1194 | 确山 1195 | 饶平 1196 | 饶阳 1197 | 绕阳河 1198 | 热水 1199 | 仁布 1200 | 任丘 1201 | 日喀则 1202 | 日照 1203 | 融安 1204 | 荣昌 1205 | 容桂 1206 | 融水 1207 | 容县 1208 | 如东 1209 | 如皋 1210 | 乳山 1211 | 汝阳 1212 | 汝州 1213 | 瑞安 1214 | 瑞昌 1215 | 瑞金 1216 | 萨拉齐 1217 | 赛汗塔拉 1218 | 三河县 1219 | 三汇镇 1220 | 三家店 1221 | 三家寨 1222 | 三间房 1223 | 三江口 1224 | 三江县 1225 | 三井子 1226 | 三门峡 1227 | 三门峡南 1228 | 三门峡西 1229 | 三门县 1230 | 三明 1231 | 三明北 1232 | 三十家 1233 | 三十里堡 1234 | 三水 1235 | 三堂集 1236 | 三亚 1237 | 三义井 1238 | 三原 1239 | 三源浦 1240 | 桑根达来 1241 | 莎车 1242 | 沙城 1243 | 沙海 1244 | 沙河 1245 | 沙河口 1246 | 沙河市 1247 | 沙后所 1248 | 沙岭子 1249 | 沙湾县 1250 | 山城镇 1251 | 山丹 1252 | 山海关 1253 | 山河屯 1254 | 山坡东 1255 | 鄯善 1256 | 鄯善北 1257 | 山市 1258 | 汕头 1259 | 汕尾 1260 | 山阴 1261 | 上板城 1262 | 上板城南 1263 | 商城 1264 | 商都 1265 | 上海 1266 | 上海虹桥 1267 | 上海南 1268 | 上海西 1269 | 上杭 1270 | 尚家 1271 | 商洛 1272 | 商南 1273 | 商丘 1274 | 商丘南 1275 | 上饶 1276 | 上虞 1277 | 上虞北 1278 | 上园 1279 | 尚志 1280 | 邵东 1281 | 韶关 1282 | 韶关东 1283 | 韶山 1284 | 邵武 1285 | 绍兴 1286 | 绍兴北 1287 | 邵阳 1288 | 舍力虎 1289 | 歙县 1290 | 涉县 1291 | 神池 1292 | 绅坊 1293 | 沈家 1294 | 深井子 1295 | 神木 1296 | 沈丘 1297 | 神树 1298 | 神头 1299 | 沈阳 1300 | 沈阳东 1301 | 深圳 1302 | 深圳东 1303 | 深圳坪山 1304 | 深圳西 1305 | 神州 1306 | 深州十渡 1307 | 施秉 1308 | 世博园 1309 | 石城 1310 | 石河子 1311 | 石家庄 1312 | 石家庄北 1313 | 石景山南 1314 | 石林 1315 | 石磷 1316 | 石岭 1317 | 石门县北 1318 | 石桥子 1319 | 石泉县 1320 | 石人 1321 | 石人城 1322 | 石山 1323 | 石头 1324 | 石岘 1325 | 始兴 1326 | 十堰 1327 | 石柱县 1328 | 师宗 1329 | 石嘴山 1330 | 首山 1331 | 寿阳 1332 | 舒城 1333 | 舒兰 1334 | 疏勒 1335 | 疏勒河 1336 | 沭阳 1337 | 双城堡 1338 | 双城北 1339 | 双丰 1340 | 双河镇 1341 | 双辽 1342 | 双牌 1343 | 双鸭山 1344 | 水洞 1345 | 水富 1346 | 水家湖 1347 | 水泉 1348 | 顺昌 1349 | 顺德 1350 | 顺德学院 1351 | 顺义 1352 | 朔州 1353 | 四道湾 1354 | 四方台 1355 | 四合永 1356 | 泗洪 1357 | 四平 1358 | 四平东 1359 | 泗水 1360 | 泗县 1361 | 泗阳 1362 | 松河 1363 | 松江 1364 | 松江河 1365 | 松江南 1366 | 松江镇 1367 | 松树 1368 | 松树镇 1369 | 松桃 1370 | 松原 1371 | 松原北 1372 | 松滋 1373 | 苏家屯 1374 | 肃宁 1375 | 宿松 1376 | 宿州 1377 | 苏州 1378 | 苏州北 1379 | 宿州东 1380 | 苏州新区 1381 | 苏州园区 1382 
| 绥德 1383 | 绥芬河 1384 | 绥化 1385 | 绥棱 1386 | 遂宁 1387 | 遂平 1388 | 遂溪 1389 | 绥阳 1390 | 绥中 1391 | 绥中北 1392 | 随州 1393 | 孙家 1394 | 孙吴 1395 | 孙镇 1396 | 索伦 1397 | 塔尔气 1398 | 塔哈 1399 | 塔河 1400 | 台安 1401 | 泰安 1402 | 太谷 1403 | 太谷西 1404 | 泰和 1405 | 太湖 1406 | 泰康 1407 | 泰来 1408 | 太姥山 1409 | 泰宁 1410 | 太平川 1411 | 太平镇 1412 | 台前 1413 | 泰山 1414 | 太阳山 1415 | 太阳升 1416 | 太原 1417 | 太原北 1418 | 太原东 1419 | 太原南 1420 | 台州 1421 | 泰州 1422 | 郯城 1423 | 谭家井 1424 | 汤池 1425 | 塘沽 1426 | 唐河 1427 | 唐家湾 1428 | 唐山 1429 | 唐山北 1430 | 汤山城 1431 | 汤旺河 1432 | 汤逊湖 1433 | 汤阴 1434 | 汤原 1435 | 桃村 1436 | 陶家屯 1437 | 陶赖昭 1438 | 洮南 1439 | 桃山 1440 | 藤县 1441 | 滕州 1442 | 滕州东 1443 | 田东 1444 | 天岗 1445 | 天津 1446 | 天津北 1447 | 天津南 1448 | 天津西 1449 | 田林 1450 | 天门 1451 | 天门南 1452 | 天桥岭 1453 | 田师府 1454 | 天水 1455 | 田阳 1456 | 天义 1457 | 天镇 1458 | 天祝 1459 | 天柱山 1460 | 铁厂 1461 | 铁力 1462 | 铁岭 1463 | 铁岭西 1464 | 亭亮 1465 | 桐柏 1466 | 通北 1467 | 桐城 1468 | 通道 1469 | 通沟 1470 | 潼关 1471 | 通海 1472 | 通化 1473 | 通化县 1474 | 通辽 1475 | 铜陵 1476 | 潼南 1477 | 铜仁 1478 | 通途 1479 | 桐乡 1480 | 同心 1481 | 通远堡 1482 | 通州西 1483 | 桐梓 1484 | 桐子林 1485 | 土地堂东 1486 | 土贵乌拉 1487 | 吐哈 1488 | 图里河 1489 | 吐列毛杜 1490 | 吐鲁番 1491 | 吐鲁番北 1492 | 图们 1493 | 土牧尔台 1494 | 图强 1495 | 土溪 1496 | 团结 1497 | 驼腰岭 1498 | 瓦房店 1499 | 瓦房店西 1500 | 瓦屋山 1501 | 歪头山 1502 | 万发屯 1503 | 湾沟 1504 | 万乐 1505 | 万年 1506 | 万宁 1507 | 万源 1508 | 万州 1509 | 旺苍 1510 | 望都 1511 | 王府 1512 | 王岗 1513 | 汪清 1514 | 王瞳 1515 | 王兆屯 1516 | 卫东 1517 | 潍坊 1518 | 威海 1519 | 苇河 1520 | 卫辉 1521 | 渭津 1522 | 威箐 1523 | 渭南 1524 | 渭南北 1525 | 渭南南 1526 | 渭南镇 1527 | 威舍 1528 | 卫星 1529 | 魏杖子 1530 | 韦庄 1531 | 苇子沟 1532 | 文安 1533 | 文昌 1534 | 温春 1535 | 文登 1536 | 文地 1537 | 温岭 1538 | 文水 1539 | 闻喜 1540 | 闻喜西 1541 | 温州 1542 | 温州南 1543 | 倭肯 1544 | 卧里屯 1545 | 沃皮 1546 | 武安 1547 | 吴堡 1548 | 五叉沟 1549 | 五常 1550 | 武昌 1551 | 五大连池 1552 | 武当山 1553 | 五道沟 1554 | 乌尔旗汗 1555 | 武功 1556 | 乌海 1557 | 乌海西 1558 | 武汉 1559 | 芜湖 1560 | 五家 1561 | 吴家屯 1562 | 五棵树 1563 | 乌拉山 1564 | 乌拉特前旗 1565 | 乌兰哈达 1566 | 乌兰浩特 1567 | 五莲 1568 | 武隆 1569 | 五龙背 1570 | 乌龙泉南 1571 | 乌鲁木齐南 1572 | 乌奴耳 1573 | 五女山 1574 | 吴桥 1575 | 武清 1576 | 武山 1577 | 五台山 1578 | 武威 1579 | 武威南 1580 | 五五 1581 | 乌西 1582 | 无锡 1583 | 无锡东 1584 | 无锡新区 1585 | 武乡 1586 | 武穴 1587 | 武义 1588 | 乌伊岭 1589 | 武夷山 1590 | 武进 1591 | 五营 1592 | 五原 1593 | 五寨 1594 | 梧州 1595 | 梧州南 1596 | 西安 1597 | 西安北 1598 | 西安南 1599 | 西昌 1600 | 西昌南 1601 | 喜德 1602 | 西斗铺 1603 | 息烽 1604 | 西丰 1605 | 西岗子 1606 | 西林 1607 | 锡林浩特 1608 | 西柳 1609 | 西麻山 1610 | 西宁西 1611 | 西平 1612 | 犀浦 1613 | 犀浦东 1614 | 浠水 1615 | 西峡 1616 | 西乡 1617 | 西小召 1618 | 西哲里木 1619 | 汐子 1620 | 下板城 1621 | 下城子 1622 | 夏官营 1623 | 下花园 1624 | 峡江 1625 | 下马塘 1626 | 厦门 1627 | 厦门北 1628 | 厦门高崎 1629 | 霞浦 1630 | 下社 1631 | 夏石 1632 | 下台子 1633 | 夏邑县 1634 | 仙林 1635 | 咸宁 1636 | 咸宁北 1637 | 咸宁东 1638 | 咸宁南 1639 | 仙人桥 1640 | 仙桃西 1641 | 咸阳 1642 | 项城 1643 | 香坊 1644 | 襄汾 1645 | 襄汾西 1646 | 襄河 1647 | 香兰 1648 | 湘潭 1649 | 向塘 1650 | 湘乡 1651 | 向阳 1652 | 襄阳 1653 | 襄阳东 1654 | 襄垣 1655 | 祥云 1656 | 孝感 1657 | 孝感北 1658 | 小河沿 1659 | 小河镇 1660 | 小榄 1661 | 小岭 1662 | 孝南 1663 | 小市 1664 | 小寺沟 1665 | 孝西 1666 | 小扬气 1667 | 小雨谷 1668 | 谢家镇 1669 | 协荣 1670 | 新安县 1671 | 新城子 1672 | 新绰源 1673 | 信丰 1674 | 新干 1675 | 新和 1676 | 新化 1677 | 新华 1678 | 新华屯 1679 | 新晃 1680 | 新会 1681 | 辛集 1682 | 新绛 1683 | 新乐 1684 | 新立屯 1685 | 新立镇 1686 | 新林 1687 | 新民 1688 | 新青 1689 | 新邱 1690 | 新松浦 1691 | 新窝铺 1692 | 新县 1693 | 新乡 1694 | 新乡东 1695 | 新兴县 1696 | 信阳 1697 | 信阳东 1698 | 新阳镇 1699 | 信宜 1700 | 新沂 1701 | 新友谊 1702 | 新余 1703 | 新余北 1704 | 新杖子 1705 | 新肇 1706 | 忻州 1707 | 兴安北 1708 | 兴城 1709 | 兴国 1710 | 兴和西 1711 | 兴凯 1712 | 兴隆店 1713 | 兴隆县 1714 | 兴隆镇 1715 | 兴宁 1716 | 兴平 1717 | 杏树 1718 | 杏树屯 1719 | 邢台 1720 | 邢台东 1721 | 兴业 1722 
| 兴义 1723 | 熊岳城 1724 | 秀山 1725 | 修武 1726 | 许昌 1727 | 许昌东 1728 | 徐家 1729 | 许家屯 1730 | 溆浦 1731 | 徐水 1732 | 徐闻 1733 | 徐州 1734 | 徐州东 1735 | 宣城 1736 | 轩岗 1737 | 宣汉 1738 | 宣化 1739 | 宣威 1740 | 旬阳 1741 | 旬阳北 1742 | 亚布力 1743 | 亚布力南 1744 | 牙克石 1745 | 亚龙湾 1746 | 牙屯堡 1747 | 鸭园 1748 | 延安 1749 | 盐城 1750 | 盐池 1751 | 砚川 1752 | 雁荡山 1753 | 燕岗 1754 | 岩会 1755 | 延吉 1756 | 燕郊 1757 | 盐津 1758 | 阎良 1759 | 炎陵 1760 | 焉耆 1761 | 延庆 1762 | 燕山 1763 | 偃师 1764 | 烟台 1765 | 烟筒山 1766 | 烟筒屯 1767 | 兖州 1768 | 燕子砭 1769 | 羊草 1770 | 阳岔 1771 | 羊场 1772 | 阳城 1773 | 阳澄湖 1774 | 阳春 1775 | 杨村 1776 | 杨岗 1777 | 阳高 1778 | 阳谷 1779 | 洋河 1780 | 杨陵 1781 | 杨陵南 1782 | 杨柳青 1783 | 阳明堡 1784 | 阳平关 1785 | 阳曲 1786 | 阳泉 1787 | 阳泉北 1788 | 阳泉曲 1789 | 杨树岭 1790 | 阳新 1791 | 阳邑 1792 | 杨杖子 1793 | 扬州 1794 | 姚家 1795 | 姚千户屯 1796 | 叶柏寿 1797 | 叶城 1798 | 叶集 1799 | 野三坡 1800 | 依安 1801 | 宜宾 1802 | 宜昌 1803 | 宜昌东 1804 | 宜城 1805 | 宜春 1806 | 宜春西 1807 | 伊尔施 1808 | 一间堡 1809 | 伊拉哈 1810 | 彝良 1811 | 宜良北 1812 | 伊林 1813 | 义马 1814 | 一面坡 1815 | 一面山伊春 1816 | 沂南 1817 | 伊宁 1818 | 伊宁东 1819 | 沂水 1820 | 伊图里河 1821 | 义乌 1822 | 义县 1823 | 宜兴 1824 | 弋阳 1825 | 益阳 1826 | 宜州 1827 | 银川 1828 | 银浪 1829 | 迎宾路 1830 | 应城 1831 | 营城子 1832 | 迎春 1833 | 英德 1834 | 英德西 1835 | 英吉沙 1836 | 营口 1837 | 营口东 1838 | 营盘湾 1839 | 营山 1840 | 鹰手营子 1841 | 鹰潭 1842 | 鹰潭北 1843 | 应县 1844 | 永安 1845 | 永安乡 1846 | 永川 1847 | 永登 1848 | 永定 1849 | 永福南 1850 | 永济 1851 | 永济北 1852 | 永嘉 1853 | 永康泽普 1854 | 永康 1855 | 永郎 1856 | 永乐店 1857 | 永泰 1858 | 永修 1859 | 永州 1860 | 友好 1861 | 尤溪 1862 | 攸县 1863 | 攸县南 1864 | 酉阳 1865 | 禹城 1866 | 虞城县 1867 | 榆次 1868 | 于都 1869 | 雨格 1870 | 余杭 1871 | 余江 1872 | 余粮堡 1873 | 榆林 1874 | 玉林 1875 | 玉门 1876 | 玉屏 1877 | 玉泉 1878 | 玉山 1879 | 玉山南 1880 | 榆社 1881 | 榆树 1882 | 榆树台 1883 | 榆树屯 1884 | 玉田县 1885 | 玉溪 1886 | 余姚 1887 | 余姚北 1888 | 元宝山 1889 | 元谋 1890 | 原平 1891 | 元氏 1892 | 源潭 1893 | 岳池 1894 | 月亮田 1895 | 乐清 1896 | 月山 1897 | 越西 1898 | 岳阳 1899 | 岳阳东 1900 | 运城 1901 | 郓城 1902 | 运城北 1903 | 云梦 1904 | 云霄 1905 | 咋子 1906 | 枣林 1907 | 枣强 1908 | 枣阳 1909 | 枣庄 1910 | 枣庄西 1911 | 扎赉诺尔西 1912 | 扎兰屯 1913 | 扎鲁特 1914 | 柞水 1915 | 湛江 1916 | 湛江西 1917 | 章党 1918 | 章古台 1919 | 张家界 1920 | 张家口 1921 | 张家口南 1922 | 张兰 1923 | 樟木头 1924 | 漳平 1925 | 漳浦 1926 | 张桥 1927 | 章丘 1928 | 樟树 1929 | 樟树东 1930 | 张维屯 1931 | 彰武 1932 | 张掖 1933 | 漳州 1934 | 漳州东 1935 | 诏安 1936 | 赵城 1937 | 肇东 1938 | 赵光 1939 | 肇庆 1940 | 昭通 1941 | 朝阳 1942 | 朝阳川 1943 | 朝阳地 1944 | 哲里木 1945 | 镇安 1946 | 镇城底 1947 | 镇江 1948 | 镇江南 1949 | 镇赉 1950 | 镇平 1951 | 镇西 1952 | 镇远 1953 | 正定机场 1954 | 正镶白旗 1955 | 郑州 1956 | 郑州东 1957 | 治安 1958 | 枝城 1959 | 纸坊东 1960 | 枝江北 1961 | 织金 1962 | 钟家村 1963 | 中宁 1964 | 中宁东 1965 | 中山 1966 | 中山北 1967 | 中卫 1968 | 钟祥 1969 | 周家 1970 | 周家屯 1971 | 周口 1972 | 周水子 1973 | 诸城 1974 | 珠海 1975 | 珠海北 1976 | 诸暨 1977 | 朱家沟 1978 | 驻马店 1979 | 驻马店西 1980 | 朱日和 1981 | 朱杨溪 1982 | 竹园坝 1983 | 株洲 1984 | 株洲西 1985 | 庄桥 1986 | 涿州 1987 | 涿州东 1988 | 卓资东 1989 | 卓资山 1990 | 淄博 1991 | 子长 1992 | 自贡 1993 | 资溪 1994 | 紫阳 1995 | 资阳 1996 | 资中 1997 | 子洲 1998 | 棕溪 1999 | 邹城 2000 | 遵义 2001 | 左岭 2002 | 灌阳县 2003 | 儋州 2004 | 东至县 2005 | 王浩屯镇 2006 | 薛城区 2007 | 平阳 2008 | 鼓楼 2009 | 奉贤 2010 | -------------------------------------------------------------------------------- /run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """BERT finetuning runner.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import logging 22 | import os 23 | import sys 24 | import random 25 | 26 | import numpy as np 27 | 28 | import torch 29 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 30 | TensorDataset) 31 | from torch.utils.data.distributed import DistributedSampler 32 | from torch.nn import CrossEntropyLoss 33 | 34 | from modeling import BertForTaskNLU 35 | from tokenization import BertTokenizer 36 | from optimization import BertAdam, WarmupLinearSchedule 37 | 38 | from run_classifier_dataset_utils import processors, convert_examples_to_features, write_result 39 | 40 | if sys.version_info[0] == 2: 41 | import cPickle as pickle 42 | else: 43 | import pickle 44 | 45 | WEIGHTS_NAME = 'pytorch_model.bin' 46 | 47 | logger = logging.getLogger(__name__) 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser() 52 | 53 | ## Required parameters 54 | parser.add_argument("--train_data", 55 | default=None, 56 | type=str, 57 | help="Train data path.") 58 | parser.add_argument("--test_data", 59 | default=None, 60 | type=str, 61 | help="Test data path.") 62 | parser.add_argument("--eval_data", 63 | default=None, 64 | type=str, 65 | help="Eval data path.") 66 | parser.add_argument("--bert_model", 67 | default=None, 68 | type=str, 69 | required=True, 70 | help="PreTrained model path.") 71 | parser.add_argument("--config", 72 | default=None, 73 | type=str, 74 | required=True, 75 | help="Model config path.") 76 | parser.add_argument("--vocab", 77 | default=None, 78 | type=str, 79 | required=True, 80 | help="Vocabulary path.") 81 | parser.add_argument("--task_name", 82 | default=None, 83 | type=str, 84 | required=True, 85 | help="The name of the task to train.") 86 | parser.add_argument("--output_dir", 87 | default=None, 88 | type=str, 89 | required=True, 90 | help="The output directory where the model predictions and checkpoints will be written.") 91 | parser.add_argument("--result_file", 92 | default=None, 93 | type=str, 94 | help="The output directory where the model predictions and checkpoints will be written.") 95 | parser.add_argument("--dic_dir", 96 | default=None, 97 | type=str, 98 | required=True, 99 | help="The dic directory which used by rule.") 100 | 101 | ## Other parameters 102 | parser.add_argument("--max_seq_length", 103 | default=128, 104 | type=int, 105 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 106 | "Sequences longer than this will be truncated, and sequences shorter \n" 107 | "than this will be padded.") 108 | parser.add_argument("--do_train", 109 | action='store_true', 110 | help="Whether to run training.") 111 | parser.add_argument("--do_predict", 112 | action='store_true', 113 | help="Whether to run eval on the dev set.") 114 | parser.add_argument("--do_lower_case", 115 | action='store_true', 116 | help="Set this flag if you are using an uncased model.") 117 | parser.add_argument("--train_batch_size", 118 | default=32, 119 | type=int, 120 | help="Total batch size for training.") 121 | parser.add_argument("--pred_batch_size", 122 | default=32, 123 | type=int, 124 | help="Total batch size for eval.") 125 | parser.add_argument("--learning_rate", 126 | default=5e-5, 127 | type=float, 128 | help="The initial learning rate for Adam.") 129 | parser.add_argument("--num_train_epochs", 130 | default=3.0, 131 | type=float, 132 | help="Total number of training epochs to perform.") 133 | parser.add_argument("--warmup_proportion", 134 | default=0.1, 135 | type=float, 136 | help="Proportion of training to perform linear learning rate warmup for. " 137 | "E.g., 0.1 = 10%% of training.") 138 | parser.add_argument("--no_cuda", 139 | action='store_true', 140 | help="Whether not to use CUDA when available") 141 | parser.add_argument('--overwrite_output_dir', 142 | action='store_true', 143 | help="Overwrite the content of the output directory") 144 | parser.add_argument("--local_rank", 145 | type=int, 146 | default=-1, 147 | help="local_rank for distributed training on gpus") 148 | parser.add_argument('--seed', 149 | type=int, 150 | default=2019, 151 | help="random seed for initialization") 152 | parser.add_argument('--gradient_accumulation_steps', 153 | type=int, 154 | default=1, 155 | help="Number of updates steps to accumulate before performing a backward/update pass.") 156 | parser.add_argument('--fp16', 157 | action='store_true', 158 | help="Whether to use 16-bit float precision instead of 32-bit") 159 | parser.add_argument('--loss_scale', 160 | type=float, default=0, 161 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 162 | "0 (default value): dynamic loss scaling.\n" 163 | "Positive power of 2: static loss scaling value.\n") 164 | args = parser.parse_args() 165 | 166 | if args.local_rank == -1 or args.no_cuda: 167 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 168 | n_gpu = torch.cuda.device_count() 169 | else: 170 | torch.cuda.set_device(args.local_rank) 171 | device = torch.device("cuda", args.local_rank) 172 | n_gpu = 1 173 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 174 | torch.distributed.init_process_group(backend='nccl') 175 | args.device = device 176 | 177 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 178 | datefmt = '%m/%d/%Y %H:%M:%S', 179 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 180 | 181 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 182 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 183 | 184 | if args.gradient_accumulation_steps < 1: 185 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 186 | args.gradient_accumulation_steps)) 187 | 188 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 189 | 190 | random.seed(args.seed) 191 | np.random.seed(args.seed) 192 | torch.manual_seed(args.seed) 193 | if n_gpu > 0: 194 | torch.cuda.manual_seed_all(args.seed) 195 | 196 | if not args.do_train and not args.do_predict: 197 | raise ValueError("At least one of `do_train` or `do_predict` must be True.") 198 | 199 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: 200 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 201 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 202 | os.makedirs(args.output_dir) 203 | 204 | task_name = args.task_name.lower() 205 | 206 | if task_name not in processors: 207 | raise ValueError("Task not found: %s" % (task_name)) 208 | 209 | processor = processors[task_name]() 210 | 211 | label_list = processor.get_labels() 212 | domain_map = {} 213 | for (i, label) in enumerate(label_list['domain']): 214 | domain_map[label] = i 215 | 216 | intent_map = {} 217 | for (i, label) in enumerate(label_list['intent']): 218 | intent_map[label] = i 219 | 220 | slots_map = {} 221 | for (i, label) in enumerate(label_list['slots']): 222 | slots_map[label] = i 223 | 224 | logger.info("***** label list *****") 225 | for key, value in label_list.items(): 226 | logger.info("%s(%d): %s" %(key, len(value), ", ".join(value))) 227 | 228 | if args.local_rank not in [-1, 0]: 229 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 230 | tokenizer = BertTokenizer.from_pretrained(args.vocab, do_lower_case=args.do_lower_case) 231 | model = BertForTaskNLU.from_pretrained(args.bert_model, args.config, label_list=label_list, max_seq_len=args.max_seq_length) 232 | if args.local_rank == 0: 233 | torch.distributed.barrier() 234 | 235 | if args.fp16: 236 | model.half() 237 | model.to(device) 238 | if args.local_rank != -1: 239 | model = torch.nn.parallel.DistributedDataParallel(model, 240 | device_ids=[args.local_rank], 241 | output_device=args.local_rank, 242 | find_unused_parameters=True) 243 | elif n_gpu > 1: 244 | model = torch.nn.DataParallel(model) 245 | 246 | global_step = 0 
247 | nb_tr_steps = 0 248 | tr_loss = 0 249 | 250 | if args.do_train: 251 | # Prepare data loader 252 | train_examples = processor.get_train_examples(args.train_data) 253 | random.seed(args.seed) 254 | random.shuffle(train_examples) 255 | train_features = convert_examples_to_features( 256 | train_examples, domain_map, intent_map, slots_map, args.max_seq_length, tokenizer) 257 | 258 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 259 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 260 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 261 | all_domain_ids = torch.tensor([f.domain_id for f in train_features], dtype=torch.long) 262 | all_intent_ids = torch.tensor([f.intent_id for f in train_features], dtype=torch.long) 263 | all_slots_ids = torch.tensor([f.slots_id for f in train_features], dtype=torch.long) 264 | 265 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_domain_ids, all_intent_ids, all_slots_ids) 266 | if args.local_rank == -1: 267 | train_sampler = RandomSampler(train_data) 268 | else: 269 | train_sampler = DistributedSampler(train_data) 270 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 271 | 272 | num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 273 | 274 | # Prepare optimizer 275 | 276 | param_optimizer = list(model.named_parameters()) 277 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 278 | optimizer_grouped_parameters = [ 279 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 280 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 281 | ] 282 | if args.fp16: 283 | try: 284 | from apex.optimizers import FP16_Optimizer 285 | from apex.optimizers import FusedAdam 286 | except ImportError: 287 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 288 | 289 | optimizer = FusedAdam(optimizer_grouped_parameters, 290 | lr=args.learning_rate, 291 | bias_correction=False, 292 | max_grad_norm=1.0) 293 | if args.loss_scale == 0: 294 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 295 | else: 296 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 297 | warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, 298 | t_total=num_train_optimization_steps) 299 | 300 | else: 301 | optimizer = BertAdam(optimizer_grouped_parameters, 302 | lr=args.learning_rate, 303 | warmup=args.warmup_proportion, 304 | t_total=num_train_optimization_steps) 305 | 306 | logger.info("***** Running training *****") 307 | logger.info("Num examples = %d", len(train_examples)) 308 | logger.info("Batch size = %d", args.train_batch_size) 309 | logger.info("Num steps = %d", num_train_optimization_steps) 310 | 311 | model.train() 312 | for _ in range(int(args.num_train_epochs)): 313 | tr_loss = 0 314 | nb_tr_examples, nb_tr_steps = 0, 0 315 | for step, batch in enumerate(train_dataloader): 316 | batch = tuple(t.to(device) for t in batch) 317 | input_ids, input_mask, segment_ids, domain_id, intent_id, slots_id = batch 318 | 319 | # define a new function to compute loss values for both output_modes 320 | domain_logits, intent_logits, slots_logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) 
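# Joint loss for the three NLU heads (see lines 322-325 below): the domain and
# intent logits are utterance-level single-label predictions, so each contributes
# one cross-entropy term, while the slot head is a per-token labelling, so the
# loop accumulates a cross-entropy over the token positions of every example in
# the batch (slots_logits is assumed to have shape [batch, seq_len, num_slot_labels]
# and slots_id shape [batch, seq_len], matching what BertForTaskNLU returns here).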
321 | 322 | loss_fct = CrossEntropyLoss() 323 | loss = loss_fct(domain_logits, domain_id) + loss_fct(intent_logits, intent_id) 324 | for i in range(len(slots_id)): 325 | loss += loss_fct(slots_logits[i], slots_id[i]) 326 | 327 | if n_gpu > 1: 328 | loss = loss.mean() # mean() to average on multi-gpu. 329 | if args.gradient_accumulation_steps > 1: 330 | loss = loss / args.gradient_accumulation_steps 331 | 332 | if args.fp16: 333 | optimizer.backward(loss) 334 | else: 335 | loss.backward() 336 | 337 | tr_loss += loss.item() 338 | nb_tr_examples += input_ids.size(0) 339 | nb_tr_steps += 1 340 | if (step + 1) % args.gradient_accumulation_steps == 0: 341 | if args.fp16: 342 | # modify learning rate with special warm up BERT uses 343 | # if args.fp16 is False, BertAdam is used that handles this automatically 344 | lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion) 345 | for param_group in optimizer.param_groups: 346 | param_group['lr'] = lr_this_step 347 | optimizer.step() 348 | optimizer.zero_grad() 349 | global_step += 1 350 | if args.local_rank in [-1, 0] and nb_tr_steps % 20 == 0: 351 | # logger.info("lr = {}, global_step = {}".format(optimizer.get_lr()[0], global_step)) 352 | logger.info("loss = {:.6f}, global_step = {}".format(tr_loss/global_step, global_step)) 353 | 354 | ### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() 355 | ### Example: 356 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 357 | # Save a trained model, configuration and tokenizer 358 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 359 | 360 | # If we save using the predefined names, we can load using `from_pretrained` 361 | output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) 362 | torch.save(model_to_save.state_dict(), output_model_file) 363 | else: 364 | model = BertForTaskNLU.from_pretrained(args.bert_model, args.config, label_list=label_list, max_seq_len=args.max_seq_length) 365 | 366 | model.to(device) 367 | 368 | ### prediction 369 | if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 370 | pred_examples = processor.get_test_examples(args.test_data) 371 | pred_features = convert_examples_to_features( 372 | pred_examples, domain_map, intent_map, slots_map, args.max_seq_length, tokenizer) 373 | 374 | logger.info("***** Running prediction *****") 375 | logger.info("Num examples = %d", len(pred_examples)) 376 | logger.info("Batch size = %d", args.pred_batch_size) 377 | all_input_ids = torch.tensor([f.input_ids for f in pred_features], dtype=torch.long) 378 | all_input_mask = torch.tensor([f.input_mask for f in pred_features], dtype=torch.long) 379 | all_segment_ids = torch.tensor([f.segment_ids for f in pred_features], dtype=torch.long) 380 | 381 | pred_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids) 382 | # Run prediction for full data 383 | if args.local_rank == -1: 384 | pred_sampler = SequentialSampler(pred_data) 385 | else: 386 | pred_sampler = DistributedSampler(pred_data) # Note that this sampler samples randomly 387 | pred_dataloader = DataLoader(pred_data, sampler=pred_sampler, batch_size=args.pred_batch_size) 388 | 389 | model.eval() 390 | preds = [] 391 | 392 | for input_ids, input_mask, segment_ids in pred_dataloader: 393 | input_ids = input_ids.to(device) 394 | input_mask = input_mask.to(device) 395 | segment_ids = segment_ids.to(device) 396 | 397 | with 
torch.no_grad(): 398 | domain_logits, intent_logits, slots_logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) 399 | domain = domain_logits.detach().cpu().numpy() 400 | intent = intent_logits.detach().cpu().numpy() 401 | slots = slots_logits.detach().cpu().numpy() 402 | for i in range(domain.shape[0]): 403 | preds.append({"domain":domain[i], "intent":intent[i], "slots":slots[i]}) 404 | 405 | output_predict_file = os.path.join(args.output_dir, args.result_file) 406 | write_result(output_predict_file, args.dic_dir, preds, pred_examples, domain_map, intent_map, slots_map) 407 | 408 | if __name__ == "__main__": 409 | main() 410 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import copy 21 | import json 22 | import logging 23 | import math 24 | import os 25 | import sys 26 | 27 | import torch 28 | from torch import nn 29 | from torch.nn import CrossEntropyLoss 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | def prune_linear_layer(layer, index, dim=0): 34 | """ Prune a linear layer (a model parameters) to keep only entries in index. 35 | Return the pruned layer as a new layer with requires_grad=True. 36 | Used to remove heads. 37 | """ 38 | index = index.to(layer.weight.device) 39 | W = layer.weight.index_select(dim, index).clone().detach() 40 | if layer.bias is not None: 41 | if dim == 1: 42 | b = layer.bias.clone().detach() 43 | else: 44 | b = layer.bias[index].clone().detach() 45 | new_size = list(layer.weight.size()) 46 | new_size[dim] = len(index) 47 | new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) 48 | new_layer.weight.requires_grad = False 49 | new_layer.weight.copy_(W.contiguous()) 50 | new_layer.weight.requires_grad = True 51 | if layer.bias is not None: 52 | new_layer.bias.requires_grad = False 53 | new_layer.bias.copy_(b.contiguous()) 54 | new_layer.bias.requires_grad = True 55 | return new_layer 56 | 57 | 58 | def gelu(x): 59 | """Implementation of the gelu activation function. 
60 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 61 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 62 | Also see https://arxiv.org/abs/1606.08415 63 | """ 64 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 65 | 66 | 67 | def swish(x): 68 | return x * torch.sigmoid(x) 69 | 70 | 71 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 72 | 73 | 74 | class BertConfig(object): 75 | """Configuration class to store the configuration of a `BertModel`. 76 | """ 77 | def __init__(self, 78 | vocab_size_or_config_json_file, 79 | hidden_size=768, 80 | num_hidden_layers=12, 81 | num_attention_heads=12, 82 | intermediate_size=3072, 83 | hidden_act="gelu", 84 | hidden_dropout_prob=0.1, 85 | attention_probs_dropout_prob=0.1, 86 | max_position_embeddings=512, 87 | type_vocab_size=2, 88 | initializer_range=0.02, 89 | layer_norm_eps=1e-12): 90 | """Constructs BertConfig. 91 | 92 | Args: 93 | vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`. 94 | hidden_size: Size of the encoder layers and the pooler layer. 95 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 96 | num_attention_heads: Number of attention heads for each attention layer in 97 | the Transformer encoder. 98 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 99 | layer in the Transformer encoder. 100 | hidden_act: The non-linear activation function (function or string) in the 101 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 102 | hidden_dropout_prob: The dropout probability for all fully connected 103 | layers in the embeddings, encoder, and pooler. 104 | attention_probs_dropout_prob: The dropout ratio for the attention 105 | probabilities. 106 | max_position_embeddings: The maximum sequence length that this model might 107 | ever be used with. Typically set this to something large just in case 108 | (e.g., 512 or 1024 or 2048). 109 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 110 | `BertModel`. 111 | initializer_range: The stddev of the truncated_normal_initializer for 112 | initializing all weight matrices. 113 | layer_norm_eps: The epsilon used by LayerNorm.
114 | """ 115 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 116 | and isinstance(vocab_size_or_config_json_file, unicode)): 117 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 118 | json_config = json.loads(reader.read()) 119 | for key, value in json_config.items(): 120 | self.__dict__[key] = value 121 | elif isinstance(vocab_size_or_config_json_file, int): 122 | self.vocab_size = vocab_size_or_config_json_file 123 | self.hidden_size = hidden_size 124 | self.num_hidden_layers = num_hidden_layers 125 | self.num_attention_heads = num_attention_heads 126 | self.hidden_act = hidden_act 127 | self.intermediate_size = intermediate_size 128 | self.hidden_dropout_prob = hidden_dropout_prob 129 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 130 | self.max_position_embeddings = max_position_embeddings 131 | self.type_vocab_size = type_vocab_size 132 | self.initializer_range = initializer_range 133 | self.layer_norm_eps = layer_norm_eps 134 | else: 135 | raise ValueError("First argument must be either a vocabulary size (int)" 136 | "or the path to a pretrained model config file (str)") 137 | 138 | @classmethod 139 | def from_dict(cls, json_object): 140 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 141 | config = BertConfig(vocab_size_or_config_json_file=-1) 142 | for key, value in json_object.items(): 143 | config.__dict__[key] = value 144 | return config 145 | 146 | @classmethod 147 | def from_json_file(cls, json_file): 148 | """Constructs a `BertConfig` from a json file of parameters.""" 149 | with open(json_file, "r", encoding='utf-8') as reader: 150 | text = reader.read() 151 | return cls.from_dict(json.loads(text)) 152 | 153 | def __repr__(self): 154 | return str(self.to_json_string()) 155 | 156 | def to_dict(self): 157 | """Serializes this instance to a Python dictionary.""" 158 | output = copy.deepcopy(self.__dict__) 159 | return output 160 | 161 | def to_json_string(self): 162 | """Serializes this instance to a JSON string.""" 163 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 164 | 165 | def to_json_file(self, json_file_path): 166 | """ Save this instance to a json file.""" 167 | with open(json_file_path, "w", encoding='utf-8') as writer: 168 | writer.write(self.to_json_string()) 169 | 170 | try: 171 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 172 | except ImportError: 173 | logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") 174 | class BertLayerNorm(nn.Module): 175 | def __init__(self, hidden_size, eps=1e-12): 176 | """Construct a layernorm module in the TF style (epsilon inside the square root). 177 | """ 178 | super(BertLayerNorm, self).__init__() 179 | self.weight = nn.Parameter(torch.ones(hidden_size)) 180 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 181 | self.variance_epsilon = eps 182 | 183 | def forward(self, x): 184 | u = x.mean(-1, keepdim=True) 185 | s = (x - u).pow(2).mean(-1, keepdim=True) 186 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 187 | return self.weight * x + self.bias 188 | 189 | class BertEmbeddings(nn.Module): 190 | """Construct the embeddings from word, position and token_type embeddings. 
191 | """ 192 | def __init__(self, config): 193 | super(BertEmbeddings, self).__init__() 194 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 195 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 196 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 197 | 198 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 199 | # any TensorFlow checkpoint file 200 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 201 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 202 | 203 | def forward(self, input_ids, token_type_ids=None): 204 | seq_length = input_ids.size(1) 205 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 206 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 207 | if token_type_ids is None: 208 | token_type_ids = torch.zeros_like(input_ids) 209 | 210 | words_embeddings = self.word_embeddings(input_ids) 211 | position_embeddings = self.position_embeddings(position_ids) 212 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 213 | 214 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 215 | embeddings = self.LayerNorm(embeddings) 216 | embeddings = self.dropout(embeddings) 217 | return embeddings 218 | 219 | 220 | class BertSelfAttention(nn.Module): 221 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 222 | super(BertSelfAttention, self).__init__() 223 | if config.hidden_size % config.num_attention_heads != 0: 224 | raise ValueError( 225 | "The hidden size (%d) is not a multiple of the number of attention " 226 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 227 | self.output_attentions = output_attentions 228 | self.keep_multihead_output = keep_multihead_output 229 | self.multihead_output = None 230 | 231 | self.num_attention_heads = config.num_attention_heads 232 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 233 | self.all_head_size = self.num_attention_heads * self.attention_head_size 234 | 235 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 236 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 237 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 238 | 239 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 240 | 241 | def transpose_for_scores(self, x): 242 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 243 | x = x.view(*new_x_shape) 244 | return x.permute(0, 2, 1, 3) 245 | 246 | def forward(self, hidden_states, attention_mask, head_mask=None): 247 | mixed_query_layer = self.query(hidden_states) 248 | mixed_key_layer = self.key(hidden_states) 249 | mixed_value_layer = self.value(hidden_states) 250 | 251 | query_layer = self.transpose_for_scores(mixed_query_layer) 252 | key_layer = self.transpose_for_scores(mixed_key_layer) 253 | value_layer = self.transpose_for_scores(mixed_value_layer) 254 | 255 | # Take the dot product between "query" and "key" to get the raw attention scores. 
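# In other words, raw scores = Q · K^T / sqrt(d_head), where d_head =
# hidden_size / num_attention_heads; the additive attention mask is then added
# and a softmax over the key dimension (dim=-1) turns the scores into attention
# probabilities, as implemented in the lines that follow.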
256 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 257 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 258 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 259 | attention_scores = attention_scores + attention_mask 260 | 261 | # Normalize the attention scores to probabilities. 262 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 263 | 264 | # This is actually dropping out entire tokens to attend to, which might 265 | # seem a bit unusual, but is taken from the original Transformer paper. 266 | attention_probs = self.dropout(attention_probs) 267 | 268 | # Mask heads if we want to 269 | if head_mask is not None: 270 | attention_probs = attention_probs * head_mask 271 | 272 | context_layer = torch.matmul(attention_probs, value_layer) 273 | if self.keep_multihead_output: 274 | self.multihead_output = context_layer 275 | self.multihead_output.retain_grad() 276 | 277 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 278 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 279 | context_layer = context_layer.view(*new_context_layer_shape) 280 | if self.output_attentions: 281 | return attention_probs, context_layer 282 | return context_layer 283 | 284 | 285 | class BertSelfOutput(nn.Module): 286 | def __init__(self, config): 287 | super(BertSelfOutput, self).__init__() 288 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 289 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 290 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 291 | 292 | def forward(self, hidden_states, input_tensor): 293 | hidden_states = self.dense(hidden_states) 294 | hidden_states = self.dropout(hidden_states) 295 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 296 | return hidden_states 297 | 298 | 299 | class BertAttention(nn.Module): 300 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 301 | super(BertAttention, self).__init__() 302 | self.output_attentions = output_attentions 303 | self.self = BertSelfAttention(config, output_attentions=output_attentions, 304 | keep_multihead_output=keep_multihead_output) 305 | self.output = BertSelfOutput(config) 306 | 307 | def prune_heads(self, heads): 308 | if len(heads) == 0: 309 | return 310 | mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) 311 | for head in heads: 312 | mask[head] = 0 313 | mask = mask.view(-1).contiguous().eq(1) 314 | index = torch.arange(len(mask))[mask].long() 315 | # Prune linear layers 316 | self.self.query = prune_linear_layer(self.self.query, index) 317 | self.self.key = prune_linear_layer(self.self.key, index) 318 | self.self.value = prune_linear_layer(self.self.value, index) 319 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 320 | # Update hyper params 321 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 322 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 323 | 324 | def forward(self, input_tensor, attention_mask, head_mask=None): 325 | self_output = self.self(input_tensor, attention_mask, head_mask) 326 | if self.output_attentions: 327 | attentions, self_output = self_output 328 | attention_output = self.output(self_output, input_tensor) 329 | if self.output_attentions: 330 | return attentions, attention_output 331 | return attention_output 332 | 333 | 334 | class 
BertIntermediate(nn.Module): 335 | def __init__(self, config): 336 | super(BertIntermediate, self).__init__() 337 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 338 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 339 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 340 | else: 341 | self.intermediate_act_fn = config.hidden_act 342 | 343 | def forward(self, hidden_states): 344 | hidden_states = self.dense(hidden_states) 345 | hidden_states = self.intermediate_act_fn(hidden_states) 346 | return hidden_states 347 | 348 | 349 | class BertOutput(nn.Module): 350 | def __init__(self, config): 351 | super(BertOutput, self).__init__() 352 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 353 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 354 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 355 | 356 | def forward(self, hidden_states, input_tensor): 357 | hidden_states = self.dense(hidden_states) 358 | hidden_states = self.dropout(hidden_states) 359 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 360 | return hidden_states 361 | 362 | 363 | class BertLayer(nn.Module): 364 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 365 | super(BertLayer, self).__init__() 366 | self.output_attentions = output_attentions 367 | self.attention = BertAttention(config, output_attentions=output_attentions, 368 | keep_multihead_output=keep_multihead_output) 369 | self.intermediate = BertIntermediate(config) 370 | self.output = BertOutput(config) 371 | 372 | def forward(self, hidden_states, attention_mask, head_mask=None): 373 | attention_output = self.attention(hidden_states, attention_mask, head_mask) 374 | if self.output_attentions: 375 | attentions, attention_output = attention_output 376 | intermediate_output = self.intermediate(attention_output) 377 | layer_output = self.output(intermediate_output, attention_output) 378 | if self.output_attentions: 379 | return attentions, layer_output 380 | return layer_output 381 | 382 | 383 | class BertEncoder(nn.Module): 384 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 385 | super(BertEncoder, self).__init__() 386 | self.output_attentions = output_attentions 387 | layer = BertLayer(config, output_attentions=output_attentions, 388 | keep_multihead_output=keep_multihead_output) 389 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 390 | 391 | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None): 392 | all_encoder_layers = [] 393 | all_attentions = [] 394 | for i, layer_module in enumerate(self.layer): 395 | hidden_states = layer_module(hidden_states, attention_mask, head_mask[i]) 396 | if self.output_attentions: 397 | attentions, hidden_states = hidden_states 398 | all_attentions.append(attentions) 399 | if output_all_encoded_layers: 400 | all_encoder_layers.append(hidden_states) 401 | if not output_all_encoded_layers: 402 | all_encoder_layers.append(hidden_states) 403 | if self.output_attentions: 404 | return all_attentions, all_encoder_layers 405 | return all_encoder_layers 406 | 407 | 408 | class BertPooler(nn.Module): 409 | def __init__(self, config): 410 | super(BertPooler, self).__init__() 411 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 412 | self.activation = nn.Tanh() 413 | 414 | def forward(self, hidden_states): 415 | # 
We "pool" the model by simply taking the hidden state corresponding 416 | # to the first token. 417 | first_token_tensor = hidden_states[:, 0] 418 | pooled_output = self.dense(first_token_tensor) 419 | pooled_output = self.activation(pooled_output) 420 | return pooled_output 421 | 422 | 423 | class BertPredictionHeadTransform(nn.Module): 424 | def __init__(self, config): 425 | super(BertPredictionHeadTransform, self).__init__() 426 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 427 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 428 | self.transform_act_fn = ACT2FN[config.hidden_act] 429 | else: 430 | self.transform_act_fn = config.hidden_act 431 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 432 | 433 | def forward(self, hidden_states): 434 | hidden_states = self.dense(hidden_states) 435 | hidden_states = self.transform_act_fn(hidden_states) 436 | hidden_states = self.LayerNorm(hidden_states) 437 | return hidden_states 438 | 439 | 440 | class BertLMPredictionHead(nn.Module): 441 | def __init__(self, config, bert_model_embedding_weights): 442 | super(BertLMPredictionHead, self).__init__() 443 | self.transform = BertPredictionHeadTransform(config) 444 | 445 | # The output weights are the same as the input embeddings, but there is 446 | # an output-only bias for each token. 447 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 448 | bert_model_embedding_weights.size(0), 449 | bias=False) 450 | self.decoder.weight = bert_model_embedding_weights 451 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) 452 | 453 | def forward(self, hidden_states): 454 | hidden_states = self.transform(hidden_states) 455 | hidden_states = self.decoder(hidden_states) + self.bias 456 | return hidden_states 457 | 458 | 459 | class BertOnlyMLMHead(nn.Module): 460 | def __init__(self, config, bert_model_embedding_weights): 461 | super(BertOnlyMLMHead, self).__init__() 462 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 463 | 464 | def forward(self, sequence_output): 465 | prediction_scores = self.predictions(sequence_output) 466 | return prediction_scores 467 | 468 | 469 | class BertOnlyNSPHead(nn.Module): 470 | def __init__(self, config): 471 | super(BertOnlyNSPHead, self).__init__() 472 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 473 | 474 | def forward(self, pooled_output): 475 | seq_relationship_score = self.seq_relationship(pooled_output) 476 | return seq_relationship_score 477 | 478 | 479 | class BertPreTrainingHeads(nn.Module): 480 | def __init__(self, config, bert_model_embedding_weights): 481 | super(BertPreTrainingHeads, self).__init__() 482 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 483 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 484 | 485 | def forward(self, sequence_output, pooled_output): 486 | prediction_scores = self.predictions(sequence_output) 487 | seq_relationship_score = self.seq_relationship(pooled_output) 488 | return prediction_scores, seq_relationship_score 489 | 490 | 491 | class BertPreTrainedModel(nn.Module): 492 | """ An abstract class to handle weights initialization and 493 | a simple interface for dowloading and loading pretrained models. 
494 | """ 495 | def __init__(self, config, *inputs, **kwargs): 496 | super(BertPreTrainedModel, self).__init__() 497 | if not isinstance(config, BertConfig): 498 | raise ValueError( 499 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " 500 | "To create a model from a Google pretrained model use " 501 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 502 | self.__class__.__name__, self.__class__.__name__ 503 | )) 504 | self.config = config 505 | 506 | def init_bert_weights(self, module): 507 | """ Initialize the weights. 508 | """ 509 | if isinstance(module, (nn.Linear, nn.Embedding)): 510 | # Slightly different from the TF version which uses truncated_normal for initialization 511 | # cf https://github.com/pytorch/pytorch/pull/5617 512 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 513 | elif isinstance(module, BertLayerNorm): 514 | module.bias.data.zero_() 515 | module.weight.data.fill_(1.0) 516 | if isinstance(module, nn.Linear) and module.bias is not None: 517 | module.bias.data.zero_() 518 | 519 | @classmethod 520 | def from_pretrained(cls, pretrained_model_path, config_path, *inputs, **kwargs): 521 | """ 522 | Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. 523 | Download and cache the pre-trained model file if needed. 524 | 525 | Params: 526 | pretrained_model_path: a path or url to a pretrained model 527 | config_path: a configuration file for the model 528 | state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models 529 | *inputs, **kwargs: additional input for the specific Bert class 530 | (ex: num_labels for BertForSequenceClassification) 531 | """ 532 | state_dict = kwargs.get('state_dict', None) 533 | kwargs.pop('state_dict', None) 534 | 535 | logger.info("loading weights file {}".format(pretrained_model_path)) 536 | logger.info("loading configuration file {}".format(config_path)) 537 | # Load config 538 | config = BertConfig.from_json_file(config_path) 539 | # logger.info("Model config {}".format(config)) 540 | # Instantiate model. 
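# Typical call site in this repo (run_classifier.py): the extra positional and
# keyword arguments are forwarded to the subclass constructor along with the
# parsed config, e.g.
#   model = BertForTaskNLU.from_pretrained(args.bert_model, args.config,
#                                          label_list=label_list,
#                                          max_seq_len=args.max_seq_length)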
541 | model = cls(config, *inputs, **kwargs) 542 | if state_dict is None: 543 | state_dict = torch.load(pretrained_model_path, map_location='cpu') 544 | 545 | # Load from a PyTorch state_dict 546 | old_keys = [] 547 | new_keys = [] 548 | for key in state_dict.keys(): 549 | new_key = None 550 | if 'gamma' in key: 551 | new_key = key.replace('gamma', 'weight') 552 | if 'beta' in key: 553 | new_key = key.replace('beta', 'bias') 554 | if new_key: 555 | old_keys.append(key) 556 | new_keys.append(new_key) 557 | for old_key, new_key in zip(old_keys, new_keys): 558 | state_dict[new_key] = state_dict.pop(old_key) 559 | 560 | missing_keys = [] 561 | unexpected_keys = [] 562 | error_msgs = [] 563 | # copy state_dict so _load_from_state_dict can modify it 564 | metadata = getattr(state_dict, '_metadata', None) 565 | state_dict = state_dict.copy() 566 | if metadata is not None: 567 | state_dict._metadata = metadata 568 | 569 | def load(module, prefix=''): 570 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 571 | module._load_from_state_dict( 572 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 573 | for name, child in module._modules.items(): 574 | if child is not None: 575 | load(child, prefix + name + '.') 576 | start_prefix = '' 577 | if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): 578 | start_prefix = 'bert.' 579 | load(model, prefix=start_prefix) 580 | if len(missing_keys) > 0: 581 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 582 | model.__class__.__name__, missing_keys)) 583 | if len(unexpected_keys) > 0: 584 | logger.info("Weights from pretrained model not used in {}: {}".format( 585 | model.__class__.__name__, unexpected_keys)) 586 | if len(error_msgs) > 0: 587 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 588 | model.__class__.__name__, "\n\t".join(error_msgs))) 589 | return model 590 | 591 | 592 | class BertModel(BertPreTrainedModel): 593 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 594 | 595 | Params: 596 | `config`: a BertConfig class instance with the configuration to build a new model 597 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 598 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 599 | This can be used to compute head importance metrics. Default: False 600 | 601 | Inputs: 602 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 603 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 604 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 605 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 606 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 607 | a `sentence B` token (see BERT paper for more details). 608 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 609 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 610 | input sequence length in the current batch. It's the mask that we typically use for attention when 611 | a batch has varying length sentences. 612 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. 
Default: `True`. 613 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 614 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 615 | 616 | 617 | Outputs: Tuple of (encoded_layers, pooled_output) 618 | `encoded_layers`: controled by `output_all_encoded_layers` argument: 619 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end 620 | of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each 621 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], 622 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding 623 | to the last attention block of shape [batch_size, sequence_length, hidden_size], 624 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a 625 | classifier pretrained on top of the hidden state associated to the first character of the 626 | input (`CLS`) to train on the Next-Sentence task (see BERT's paper). 627 | 628 | Example usage: 629 | ```python 630 | # Already been converted into WordPiece token ids 631 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 632 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 633 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 634 | 635 | config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 636 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 637 | 638 | model = modeling.BertModel(config=config) 639 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 640 | ``` 641 | """ 642 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 643 | super(BertModel, self).__init__(config) 644 | self.output_attentions = output_attentions 645 | self.embeddings = BertEmbeddings(config) 646 | self.encoder = BertEncoder(config, output_attentions=output_attentions, 647 | keep_multihead_output=keep_multihead_output) 648 | self.pooler = BertPooler(config) 649 | self.apply(self.init_bert_weights) 650 | 651 | def prune_heads(self, heads_to_prune): 652 | """ Prunes heads of the model. 653 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 654 | """ 655 | for layer, heads in heads_to_prune.items(): 656 | self.encoder.layer[layer].attention.prune_heads(heads) 657 | 658 | def get_multihead_outputs(self): 659 | """ Gather all multi-head outputs. 660 | Return: list (layers) of multihead module outputs with gradients 661 | """ 662 | return [layer.attention.self.multihead_output for layer in self.encoder.layer] 663 | 664 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None): 665 | if attention_mask is None: 666 | attention_mask = torch.ones_like(input_ids) 667 | if token_type_ids is None: 668 | token_type_ids = torch.zeros_like(input_ids) 669 | 670 | # We create a 3D attention mask from a 2D tensor mask. 671 | # Sizes are [batch_size, 1, 1, to_seq_length] 672 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 673 | # this attention mask is more simple than the triangular masking of causal attention 674 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
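# For example, a padding mask [1, 1, 0] for a three-token input is reshaped to
# [batch, 1, 1, 3] and mapped to the additive bias [0.0, 0.0, -10000.0] by the
# (1.0 - mask) * -10000.0 step below, which broadcasts over all heads and query
# positions when it is added to the raw attention scores.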
675 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 676 | 677 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 678 | # masked positions, this operation will create a tensor which is 0.0 for 679 | # positions we want to attend and -10000.0 for masked positions. 680 | # Since we are adding it to the raw scores before the softmax, this is 681 | # effectively the same as removing these entirely. 682 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 683 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 684 | 685 | # Prepare head mask if needed 686 | # 1.0 in head_mask indicate we keep the head 687 | # attention_probs has shape bsz x n_heads x N x N 688 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 689 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 690 | if head_mask is not None: 691 | if head_mask.dim() == 1: 692 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 693 | head_mask = head_mask.expand_as(self.config.num_hidden_layers, -1, -1, -1, -1) 694 | elif head_mask.dim() == 2: 695 | head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer 696 | head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility 697 | else: 698 | head_mask = [None] * self.config.num_hidden_layers 699 | 700 | embedding_output = self.embeddings(input_ids, token_type_ids) 701 | encoded_layers = self.encoder(embedding_output, 702 | extended_attention_mask, 703 | output_all_encoded_layers=output_all_encoded_layers, 704 | head_mask=head_mask) 705 | if self.output_attentions: 706 | all_attentions, encoded_layers = encoded_layers 707 | sequence_output = encoded_layers[-1] 708 | pooled_output = self.pooler(sequence_output) 709 | if not output_all_encoded_layers: 710 | encoded_layers = encoded_layers[-1] 711 | if self.output_attentions: 712 | return all_attentions, encoded_layers, pooled_output 713 | return encoded_layers, pooled_output 714 | 715 | 716 | class BertForPreTraining(BertPreTrainedModel): 717 | """BERT model with pre-training heads. 718 | This module comprises the BERT model followed by the two pre-training heads: 719 | - the masked language modeling head, and 720 | - the next sentence classification head. 721 | 722 | Params: 723 | `config`: a BertConfig class instance with the configuration to build a new model 724 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 725 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 726 | This can be used to compute head importance metrics. Default: False 727 | 728 | Inputs: 729 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 730 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 731 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 732 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 733 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 734 | a `sentence B` token (see BERT paper for more details). 735 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 736 | selected in [0, 1]. 
It's a mask to be used if the input sequence length is smaller than the max 737 | input sequence length in the current batch. It's the mask that we typically use for attention when 738 | a batch has varying length sentences. 739 | `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 740 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 741 | is only computed for the labels set in [0, ..., vocab_size] 742 | `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] 743 | with indices selected in [0, 1]. 744 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 745 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 746 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 747 | 748 | Outputs: 749 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 750 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 751 | sentence classification loss. 752 | if `masked_lm_labels` or `next_sentence_label` is `None`: 753 | Outputs a tuple comprising 754 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 755 | - the next sentence classification logits of shape [batch_size, 2]. 756 | 757 | Example usage: 758 | ```python 759 | # Already been converted into WordPiece token ids 760 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 761 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 762 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 763 | 764 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 765 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 766 | 767 | model = BertForPreTraining(config) 768 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 769 | ``` 770 | """ 771 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 772 | super(BertForPreTraining, self).__init__(config) 773 | self.output_attentions = output_attentions 774 | self.bert = BertModel(config, output_attentions=output_attentions, 775 | keep_multihead_output=keep_multihead_output) 776 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) 777 | self.apply(self.init_bert_weights) 778 | 779 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None): 780 | outputs = self.bert(input_ids, token_type_ids, attention_mask, 781 | output_all_encoded_layers=False, head_mask=head_mask) 782 | if self.output_attentions: 783 | all_attentions, sequence_output, pooled_output = outputs 784 | else: 785 | sequence_output, pooled_output = outputs 786 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 787 | 788 | if masked_lm_labels is not None and next_sentence_label is not None: 789 | loss_fct = CrossEntropyLoss(ignore_index=-1) 790 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 791 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 792 | total_loss = masked_lm_loss + next_sentence_loss 793 | return total_loss 794 | elif self.output_attentions: 
795 | return all_attentions, prediction_scores, seq_relationship_score 796 | return prediction_scores, seq_relationship_score 797 | 798 | 799 | class BertForMaskedLM(BertPreTrainedModel): 800 | """BERT model with the masked language modeling head. 801 | This module comprises the BERT model followed by the masked language modeling head. 802 | 803 | Params: 804 | `config`: a BertConfig class instance with the configuration to build a new model 805 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 806 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 807 | This can be used to compute head importance metrics. Default: False 808 | 809 | Inputs: 810 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 811 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 812 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 813 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 814 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 815 | a `sentence B` token (see BERT paper for more details). 816 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 817 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 818 | input sequence length in the current batch. It's the mask that we typically use for attention when 819 | a batch has varying length sentences. 820 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 821 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 822 | is only computed for the labels set in [0, ..., vocab_size] 823 | `head_mask`: an optional torch.LongTensor of shape [num_heads] with indices 824 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 825 | input sequence length in the current batch. It's the mask that we typically use for attention when 826 | a batch has varying length sentences. 827 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 828 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 829 | 830 | Outputs: 831 | if `masked_lm_labels` is not `None`: 832 | Outputs the masked language modeling loss. 833 | if `masked_lm_labels` is `None`: 834 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. 
835 | 836 | Example usage: 837 | ```python 838 | # Already been converted into WordPiece token ids 839 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 840 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 841 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 842 | 843 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 844 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 845 | 846 | model = BertForMaskedLM(config) 847 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) 848 | ``` 849 | """ 850 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 851 | super(BertForMaskedLM, self).__init__(config) 852 | self.output_attentions = output_attentions 853 | self.bert = BertModel(config, output_attentions=output_attentions, 854 | keep_multihead_output=keep_multihead_output) 855 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 856 | self.apply(self.init_bert_weights) 857 | 858 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None): 859 | outputs = self.bert(input_ids, token_type_ids, attention_mask, 860 | output_all_encoded_layers=False, 861 | head_mask=head_mask) 862 | if self.output_attentions: 863 | all_attentions, sequence_output, _ = outputs 864 | else: 865 | sequence_output, _ = outputs 866 | prediction_scores = self.cls(sequence_output) 867 | 868 | if masked_lm_labels is not None: 869 | loss_fct = CrossEntropyLoss(ignore_index=-1) 870 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 871 | return masked_lm_loss 872 | elif self.output_attentions: 873 | return all_attentions, prediction_scores 874 | return prediction_scores 875 | 876 | 877 | class BertForNextSentencePrediction(BertPreTrainedModel): 878 | """BERT model with next sentence prediction head. 879 | This module comprises the BERT model followed by the next sentence classification head. 880 | 881 | Params: 882 | `config`: a BertConfig class instance with the configuration to build a new model 883 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 884 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 885 | This can be used to compute head importance metrics. Default: False 886 | 887 | Inputs: 888 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 889 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 890 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 891 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 892 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 893 | a `sentence B` token (see BERT paper for more details). 894 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 895 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 896 | input sequence length in the current batch. It's the mask that we typically use for attention when 897 | a batch has varying length sentences. 898 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] 899 | with indices selected in [0, 1]. 
900 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 901 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 902 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 903 | 904 | Outputs: 905 | if `next_sentence_label` is not `None`: 906 | Outputs the next sentence classification loss, computed as the CrossEntropy of the 907 | [batch_size, 2] relationship scores against `next_sentence_label`. 908 | if `next_sentence_label` is `None`: 909 | Outputs the next sentence classification logits of shape [batch_size, 2]. 910 | 911 | Example usage: 912 | ```python 913 | # Already been converted into WordPiece token ids 914 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 915 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 916 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 917 | 918 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 919 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 920 | 921 | model = BertForNextSentencePrediction(config) 922 | seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 923 | ``` 924 | """ 925 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 926 | super(BertForNextSentencePrediction, self).__init__(config) 927 | self.output_attentions = output_attentions 928 | self.bert = BertModel(config, output_attentions=output_attentions, 929 | keep_multihead_output=keep_multihead_output) 930 | self.cls = BertOnlyNSPHead(config) 931 | self.apply(self.init_bert_weights) 932 | 933 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None): 934 | outputs = self.bert(input_ids, token_type_ids, attention_mask, 935 | output_all_encoded_layers=False, 936 | head_mask=head_mask) 937 | if self.output_attentions: 938 | all_attentions, _, pooled_output = outputs 939 | else: 940 | _, pooled_output = outputs 941 | seq_relationship_score = self.cls(pooled_output) 942 | 943 | if next_sentence_label is not None: 944 | loss_fct = CrossEntropyLoss(ignore_index=-1) 945 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 946 | return next_sentence_loss 947 | elif self.output_attentions: 948 | return all_attentions, seq_relationship_score 949 | return seq_relationship_score 950 | 951 | 952 | class BertForSequenceClassification(BertPreTrainedModel): 953 | """BERT model for classification. 954 | This module is composed of the BERT model with a linear layer on top of 955 | the pooled output. 956 | 957 | Params: 958 | `config`: a BertConfig class instance with the configuration to build a new model 959 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 960 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 961 | This can be used to compute head importance metrics. Default: False 962 | `num_labels`: the number of classes for the classifier. Default = 2. 963 | 964 | Inputs: 965 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 966 | with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token.
(see the tokens preprocessing logic in the scripts 967 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 968 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 969 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 970 | a `sentence B` token (see BERT paper for more details). 971 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 972 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 973 | input sequence length in the current batch. It's the mask that we typically use for attention when 974 | a batch has varying length sentences. 975 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 976 | with indices selected in [0, ..., num_labels]. 977 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 978 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 979 | 980 | Outputs: 981 | if `labels` is not `None`: 982 | Outputs the CrossEntropy classification loss of the output with the labels. 983 | if `labels` is `None`: 984 | Outputs the classification logits of shape [batch_size, num_labels]. 985 | 986 | Example usage: 987 | ```python 988 | # Already been converted into WordPiece token ids 989 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 990 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 991 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 992 | 993 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 994 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 995 | 996 | num_labels = 2 997 | 998 | model = BertForSequenceClassification(config, num_labels) 999 | logits = model(input_ids, token_type_ids, input_mask) 1000 | ``` 1001 | """ 1002 | def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False): 1003 | super(BertForSequenceClassification, self).__init__(config) 1004 | self.output_attentions = output_attentions 1005 | self.num_labels = num_labels 1006 | self.bert = BertModel(config, output_attentions=output_attentions, 1007 | keep_multihead_output=keep_multihead_output) 1008 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1009 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1010 | self.apply(self.init_bert_weights) 1011 | 1012 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): 1013 | outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask) 1014 | if self.output_attentions: 1015 | all_attentions, _, pooled_output = outputs 1016 | else: 1017 | _, pooled_output = outputs 1018 | pooled_output = self.dropout(pooled_output) 1019 | logits = self.classifier(pooled_output) 1020 | 1021 | if labels is not None: 1022 | loss_fct = CrossEntropyLoss() 1023 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1024 | return loss 1025 | elif self.output_attentions: 1026 | return all_attentions, logits 1027 | return logits 1028 | 1029 | 1030 | class BertForMultipleChoice(BertPreTrainedModel): 1031 | """BERT model for multiple choice tasks. 1032 | This module is composed of the BERT model with a linear layer on top of 1033 | the pooled output. 
1034 | 1035 | Params: 1036 | `config`: a BertConfig class instance with the configuration to build a new model 1037 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1038 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1039 | This can be used to compute head importance metrics. Default: False 1040 | `num_choices`: the number of classes for the classifier. Default = 2. 1041 | 1042 | Inputs: 1043 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1044 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1045 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1046 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1047 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 1048 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 1049 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 1050 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1051 | input sequence length in the current batch. It's the mask that we typically use for attention when 1052 | a batch has varying length sentences. 1053 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 1054 | with indices selected in [0, ..., num_choices]. 1055 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1056 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1057 | 1058 | Outputs: 1059 | if `labels` is not `None`: 1060 | Outputs the CrossEntropy classification loss of the output with the labels. 1061 | if `labels` is `None`: 1062 | Outputs the classification logits of shape [batch_size, num_choices].
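For the training path, a minimal sketch assuming the toy tensors from the example usage below; `labels[i]` is the index of the correct choice for example i, and the call returns the CrossEntropy loss over the per-choice scores:
```python
# Sketch only: reuses `model`, `input_ids`, `token_type_ids`, `input_mask` from the
# example below (batch of 2 examples, 2 choices each).
labels = torch.LongTensor([0, 1])  # correct choice index per example
loss = model(input_ids, token_type_ids, input_mask, labels=labels)
loss.backward()
```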
1063 | 1064 | Example usage: 1065 | ```python 1066 | # Already been converted into WordPiece token ids 1067 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) 1068 | input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) 1069 | token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) 1070 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1071 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1072 | 1073 | num_choices = 2 1074 | 1075 | model = BertForMultipleChoice(config, num_choices) 1076 | logits = model(input_ids, token_type_ids, input_mask) 1077 | ``` 1078 | """ 1079 | def __init__(self, config, num_choices=2, output_attentions=False, keep_multihead_output=False): 1080 | super(BertForMultipleChoice, self).__init__(config) 1081 | self.output_attentions = output_attentions 1082 | self.num_choices = num_choices 1083 | self.bert = BertModel(config, output_attentions=output_attentions, 1084 | keep_multihead_output=keep_multihead_output) 1085 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1086 | self.classifier = nn.Linear(config.hidden_size, 1) 1087 | self.apply(self.init_bert_weights) 1088 | 1089 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): 1090 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 1091 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 1092 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 1093 | outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, head_mask=head_mask) 1094 | if self.output_attentions: 1095 | all_attentions, _, pooled_output = outputs 1096 | else: 1097 | _, pooled_output = outputs 1098 | pooled_output = self.dropout(pooled_output) 1099 | logits = self.classifier(pooled_output) 1100 | reshaped_logits = logits.view(-1, self.num_choices) 1101 | 1102 | if labels is not None: 1103 | loss_fct = CrossEntropyLoss() 1104 | loss = loss_fct(reshaped_logits, labels) 1105 | return loss 1106 | elif self.output_attentions: 1107 | return all_attentions, reshaped_logits 1108 | return reshaped_logits 1109 | 1110 | 1111 | class BertForTokenClassification(BertPreTrainedModel): 1112 | """BERT model for token-level classification. 1113 | This module is composed of the BERT model with a linear layer on top of 1114 | the full hidden state of the last layer. 1115 | 1116 | Params: 1117 | `config`: a BertConfig class instance with the configuration to build a new model 1118 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1119 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1120 | This can be used to compute head importance metrics. Default: False 1121 | `num_labels`: the number of classes for the classifier. Default = 2. 1122 | 1123 | Inputs: 1124 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1125 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1126 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1127 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1128 | types indices selected in [0, 1]. 
Type 0 corresponds to a `sentence A` and type 1 corresponds to 1129 | a `sentence B` token (see BERT paper for more details). 1130 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1131 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1132 | input sequence length in the current batch. It's the mask that we typically use for attention when 1133 | a batch has varying length sentences. 1134 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] 1135 | with indices selected in [0, ..., num_labels]. 1136 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1137 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1138 | 1139 | Outputs: 1140 | if `labels` is not `None`: 1141 | Outputs the CrossEntropy classification loss of the output with the labels. 1142 | if `labels` is `None`: 1143 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. 1144 | 1145 | Example usage: 1146 | ```python 1147 | # Already been converted into WordPiece token ids 1148 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1149 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1150 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1151 | 1152 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1153 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1154 | 1155 | num_labels = 2 1156 | 1157 | model = BertForTokenClassification(config, num_labels) 1158 | logits = model(input_ids, token_type_ids, input_mask) 1159 | ``` 1160 | """ 1161 | def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False): 1162 | super(BertForTokenClassification, self).__init__(config) 1163 | self.output_attentions = output_attentions 1164 | self.num_labels = num_labels 1165 | self.bert = BertModel(config, output_attentions=output_attentions, 1166 | keep_multihead_output=keep_multihead_output) 1167 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1168 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1169 | self.apply(self.init_bert_weights) 1170 | 1171 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): 1172 | outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask) 1173 | if self.output_attentions: 1174 | all_attentions, sequence_output, _ = outputs 1175 | else: 1176 | sequence_output, _ = outputs 1177 | sequence_output = self.dropout(sequence_output) 1178 | logits = self.classifier(sequence_output) 1179 | 1180 | if labels is not None: 1181 | loss_fct = CrossEntropyLoss() 1182 | # Only keep active parts of the loss 1183 | if attention_mask is not None: 1184 | active_loss = attention_mask.view(-1) == 1 1185 | active_logits = logits.view(-1, self.num_labels)[active_loss] 1186 | active_labels = labels.view(-1)[active_loss] 1187 | loss = loss_fct(active_logits, active_labels) 1188 | else: 1189 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1190 | return loss 1191 | elif self.output_attentions: 1192 | return all_attentions, logits 1193 | return logits 1194 | 1195 | 1196 | class BertForQuestionAnswering(BertPreTrainedModel): 1197 | """BERT model for Question Answering (span extraction). 
1198 | This module is composed of the BERT model with a linear layer on top of 1199 | the sequence output that computes start_logits and end_logits. 1200 | 1201 | Params: 1202 | `config`: a BertConfig class instance with the configuration to build a new model 1203 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1204 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1205 | This can be used to compute head importance metrics. Default: False 1206 | 1207 | Inputs: 1208 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1209 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1210 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1211 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1212 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1213 | a `sentence B` token (see BERT paper for more details). 1214 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1215 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1216 | input sequence length in the current batch. It's the mask that we typically use for attention when 1217 | a batch has varying length sentences. 1218 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. 1219 | Positions are clamped to the length of the sequence and positions outside of the sequence are not taken 1220 | into account for computing the loss. 1221 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. 1222 | Positions are clamped to the length of the sequence and positions outside of the sequence are not taken 1223 | into account for computing the loss. 1224 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1225 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1226 | 1227 | Outputs: 1228 | if `start_positions` and `end_positions` are not `None`: 1229 | Outputs the total_loss which is the average of the CrossEntropy losses for the start and end token positions. 1230 | if `start_positions` or `end_positions` is `None`: 1231 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end 1232 | position tokens of shape [batch_size, sequence_length].
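For the training path, a minimal sketch assuming the toy tensors from the example usage below; out-of-range positions are clamped and then ignored by the loss:
```python
# Sketch only: reuses `model`, `input_ids`, `token_type_ids`, `input_mask` from the
# example below (sequence length 3, so valid span positions are 0-2).
start_positions = torch.LongTensor([1, 0])
end_positions = torch.LongTensor([2, 1])
total_loss = model(input_ids, token_type_ids, input_mask,
                   start_positions=start_positions, end_positions=end_positions)
total_loss.backward()
```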
1233 | 1234 | Example usage: 1235 | ```python 1236 | # Already been converted into WordPiece token ids 1237 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1238 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1239 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1240 | 1241 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1242 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1243 | 1244 | model = BertForQuestionAnswering(config) 1245 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 1246 | ``` 1247 | """ 1248 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 1249 | super(BertForQuestionAnswering, self).__init__(config) 1250 | self.output_attentions = output_attentions 1251 | self.bert = BertModel(config, output_attentions=output_attentions, 1252 | keep_multihead_output=keep_multihead_output) 1253 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 1254 | self.apply(self.init_bert_weights) 1255 | 1256 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, 1257 | end_positions=None, head_mask=None): 1258 | outputs = self.bert(input_ids, token_type_ids, attention_mask, 1259 | output_all_encoded_layers=False, 1260 | head_mask=head_mask) 1261 | if self.output_attentions: 1262 | all_attentions, sequence_output, _ = outputs 1263 | else: 1264 | sequence_output, _ = outputs 1265 | logits = self.qa_outputs(sequence_output) 1266 | start_logits, end_logits = logits.split(1, dim=-1) 1267 | start_logits = start_logits.squeeze(-1) 1268 | end_logits = end_logits.squeeze(-1) 1269 | 1270 | if start_positions is not None and end_positions is not None: 1271 | # If we are on multi-GPU, the split can add an extra dimension; squeeze it away 1272 | if len(start_positions.size()) > 1: 1273 | start_positions = start_positions.squeeze(-1) 1274 | if len(end_positions.size()) > 1: 1275 | end_positions = end_positions.squeeze(-1) 1276 | # sometimes the start/end positions are outside our model inputs; we ignore these terms 1277 | ignored_index = start_logits.size(1) 1278 | start_positions.clamp_(0, ignored_index) 1279 | end_positions.clamp_(0, ignored_index) 1280 | 1281 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1282 | start_loss = loss_fct(start_logits, start_positions) 1283 | end_loss = loss_fct(end_logits, end_positions) 1284 | total_loss = (start_loss + end_loss) / 2 1285 | return total_loss 1286 | elif self.output_attentions: 1287 | return all_attentions, start_logits, end_logits 1288 | return start_logits, end_logits 1289 | 1290 | 1291 | class BertForTaskNLU(BertPreTrainedModel): 1292 | """BERT model for task-oriented NLU: joint domain and intent classification plus slot tagging.""" 1293 | def __init__(self, config, output_attentions=False, keep_multihead_output=False, *inputs, **kwargs): 1294 | label_list = kwargs['label_list'] 1295 | super(BertForTaskNLU, self).__init__(config) 1296 | self.domain_num = len(label_list["domain"]) 1297 | self.intent_num = len(label_list["intent"]) 1298 | self.slots_num = len(label_list["slots"]) 1299 | self.max_seq_length = kwargs['max_seq_len'] 1300 | self.bert = BertModel(config, output_attentions=output_attentions, 1301 | keep_multihead_output=keep_multihead_output) 1302 | self.hidden_size = config.hidden_size 1303 | self.domain_outputs = nn.Linear(config.hidden_size, self.domain_num)  # domain classifier over the pooled [CLS] output 1304 | self.intent_outputs = nn.Linear(config.hidden_size, self.intent_num)  # intent classifier over the pooled [CLS] output 1305 | self.slots_outputs = nn.Linear(config.hidden_size, self.slots_num)  # per-token slot classifier over the sequence output 1306 | self.dropout = nn.Dropout(p=0.1)
1307 | self.apply(self.init_bert_weights) 1308 | 1309 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, head_mask=None): 1310 | outputs = self.bert(input_ids, token_type_ids, attention_mask, 1311 | output_all_encoded_layers=False, 1312 | head_mask=head_mask) 1313 | sequence_output, cls_output = outputs  # per-token hidden states and pooled [CLS] representation 1314 | cls_output = self.dropout(cls_output) 1315 | sequence_output = self.dropout(sequence_output) 1316 | 1317 | domain_logits = self.domain_outputs(cls_output)  # sentence-level domain scores 1318 | intent_logits = self.intent_outputs(cls_output)  # sentence-level intent scores 1319 | 1320 | sequence_output = sequence_output.view(-1, self.hidden_size)  # flatten to [batch * seq_len, hidden] 1321 | slots_logits = self.slots_outputs(sequence_output) 1322 | slots_logits = slots_logits.view(-1, self.max_seq_length, self.slots_num)  # token-level slot scores 1323 | 1324 | return domain_logits, intent_logits, slots_logits --------------------------------------------------------------------------------
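Unlike the other heads in modeling.py, `BertForTaskNLU` has no usage example in its docstring. Below is a minimal sketch, assuming `label_list` is a dict with `domain`, `intent` and `slots` inventories and that inputs are padded to `max_seq_len`, as the constructor's use of `kwargs['label_list']` and `kwargs['max_seq_len']` suggests; the label names and the joint loss at the end are illustrative assumptions, not code taken from `run_classifier.py`.
```python
# Minimal usage sketch for BertForTaskNLU. Assumes modeling.py is importable from the
# repo root; the label names below are hypothetical placeholders.
import torch
from modeling import BertConfig, BertForTaskNLU

max_seq_len = 32
label_list = {
    "domain": ["music", "weather", "train"],   # sentence-level domain inventory
    "intent": ["PLAY", "QUERY"],               # sentence-level intent inventory
    "slots": ["O", "B-city", "I-city"],        # token-level slot tag inventory
}

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
# Randomly initialized here for illustration; a real run would load pretrained weights.
model = BertForTaskNLU(config, label_list=label_list, max_seq_len=max_seq_len)

# Toy batch of two utterances, already converted to WordPiece ids and padded to max_seq_len.
input_ids = torch.zeros(2, max_seq_len, dtype=torch.long)
input_mask = torch.ones(2, max_seq_len, dtype=torch.long)
token_type_ids = torch.zeros(2, max_seq_len, dtype=torch.long)

domain_logits, intent_logits, slots_logits = model(input_ids, token_type_ids, input_mask)
# domain_logits: [2, 3]; intent_logits: [2, 2]; slots_logits: [2, max_seq_len, 3]

# One plausible joint objective (an assumption; the actual training loss lives in
# run_classifier.py, which is not shown here): sentence-level cross-entropy for domain
# and intent plus token-level cross-entropy for slots.
domain_labels = torch.LongTensor([0, 2])
intent_labels = torch.LongTensor([1, 0])
slot_labels = torch.zeros(2, max_seq_len, dtype=torch.long)  # all "O" tags

loss_fct = torch.nn.CrossEntropyLoss()
loss = (loss_fct(domain_logits, domain_labels)
        + loss_fct(intent_logits, intent_labels)
        + loss_fct(slots_logits.view(-1, len(label_list["slots"])), slot_labels.view(-1)))
loss.backward()
```
Domain and intent are read off the pooled [CLS] vector while slots are predicted per token, matching the three linear heads defined in `__init__`.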