├── find_entity ├── __init__.py ├── probable_acmation.py ├── exacter_acmation.py └── acmation.py ├── __init__.py ├── official_transformer ├── __init__.py ├── README ├── model_params.py ├── ffn_layer.py ├── embedding_layer.py ├── model_utils.py ├── tpu.py └── attention_layer.py ├── entity_files ├── slot-dictionaries │ ├── custom_destination.txt │ ├── instrument.txt │ ├── language.txt │ ├── toplist.txt │ ├── style.txt │ ├── emotion.txt │ ├── scene.txt │ ├── theme.txt │ └── age.txt ├── custom_destination.json ├── scene.json ├── age.json ├── instrument.json ├── language.json ├── toplist.json ├── origin.json ├── emotion.json ├── theme.json ├── style.json └── frequentSinger.json ├── configs ├── lasertagger_config.json └── lasertagger_config-tiny.json ├── requirements.txt ├── README.md ├── curLine_file.py ├── .gitignore ├── score_main.py ├── utils.py ├── confusion_words_duoyin.py ├── split_corpus.py ├── start.sh ├── score_lib.py ├── preprocess_main.py ├── confusion_words.py ├── predict_param.py ├── predict_utils.py ├── sari_hook.py ├── LICENSE ├── run_lasertagger.py ├── run_lasertagger_utils.py ├── predict_main.py └── confusion_words_danyin.py /find_entity/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import compute_lcs -------------------------------------------------------------------------------- /official_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /entity_files/slot-dictionaries/custom_destination.txt: -------------------------------------------------------------------------------- 1 | 你家 2 | 公司 3 | 家 4 | -------------------------------------------------------------------------------- /entity_files/custom_destination.json: -------------------------------------------------------------------------------- 1 | { 2 | "公司": [ 3 | "公司", 4 | 28 5 | ], 6 | "家": [ 7 | "家", 8 | 81 9 | ] 10 | } -------------------------------------------------------------------------------- /entity_files/scene.json: -------------------------------------------------------------------------------- 1 | { 2 | "酒吧": [ 3 | "酒吧", 4 | 3 5 | ], 6 | "咖啡厅": [ 7 | "咖啡厅", 8 | 3 9 | ], 10 | "k歌": [ 11 | "k歌", 12 | 1 13 | ] 14 | } -------------------------------------------------------------------------------- /entity_files/age.json: -------------------------------------------------------------------------------- 1 | { 2 | "儿歌": [ 3 | "儿歌", 4 | 36 5 | ], 6 | "二零一一": [ 7 | "二零一一", 8 | 1 9 | ], 10 | "零零后": [ 11 | "零零后", 12 | 1 13 | ] 14 | } -------------------------------------------------------------------------------- /entity_files/instrument.json: -------------------------------------------------------------------------------- 1 | { 2 | "古筝": [ 3 | "古筝", 4 | 5 5 | ], 6 | "锁那": [ 7 | "唢呐", 8 | 2 9 | ], 10 | "唢呐": [ 11 | "唢呐", 12 | 3 13 | ], 14 | "萨克斯": [ 15 | "萨克斯", 16 | 4 17 | ] 18 | } -------------------------------------------------------------------------------- /official_transformer/README: -------------------------------------------------------------------------------- 1 | This directory contains Transformer related code files. 
These are copied from: 2 | https://github.com/tensorflow/models/tree/master/official/transformer 3 | to make it possible to install all LaserTagger dependencies with pip (the 4 | official Transformer implementation doesn't support pip installation). -------------------------------------------------------------------------------- /entity_files/slot-dictionaries/instrument.txt: -------------------------------------------------------------------------------- 1 | 钢琴曲 2 | 琵琶曲 3 | 吉他 4 | 管弦 5 | 吉他曲 6 | 管弦乐 7 | 缸琴 8 | 交响 9 | harmonica 10 | 葫芦丝 11 | 唢呐 12 | 竖琴 13 | piano 14 | 小提琴 15 | 笛子 16 | 二胡 17 | 钢琴 18 | 大提琴 19 | 吉它 20 | 风琴 21 | 手风琴 22 | 口琴 23 | 笛音 24 | guitar 25 | 琵琶 26 | 交响乐 27 | 古筝 28 | 萨克斯 29 | 刚琴 30 | 港琴 31 | -------------------------------------------------------------------------------- /entity_files/slot-dictionaries/language.txt: -------------------------------------------------------------------------------- 1 | 日语 2 | 粤语 3 | 韩国 4 | 华语 5 | 那美克星语 6 | 德语 7 | 印度语 8 | 韩语 9 | 英文 10 | 巫族语 11 | 泰国 12 | 外语 13 | 普通话 14 | 葡萄牙 15 | 阿拉伯语 16 | 汉语 17 | 香港话 18 | 中国 19 | 欧美 20 | 法文 21 | 外国 22 | 泰语 23 | 法国 24 | 日本 25 | 中文 26 | 闽南语 27 | 因语 28 | 法语 29 | 小语种 30 | 国语 31 | 藏族 32 | 闽南 33 | 蒙语 34 | 欧美英语 35 | 日文 36 | 葡萄牙语 37 | 阴语 38 | 印度 39 | 英语 40 | 广东话 41 | 韩文 42 | -------------------------------------------------------------------------------- /configs/lasertagger_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 11, 11 | "type_vocab_size": 2, 12 | "vocab_size": 21128, 13 | "decoder_num_hidden_layers": 1, 14 | "decoder_hidden_size": 768, 15 | "decoder_num_attention_heads": 4, 16 | "decoder_filter_size": 3072, 17 | "use_full_attention": false 18 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.15.2 # CPU Version of TensorFlow. 2 | # tensorflow-gpu==1.15.0 # GPU version of TensorFlow. 
3 | absl-py==0.8.1 4 | astor==0.8.0 5 | bert-tensorflow==1.0.1 6 | gast==0.2.2 7 | google-pasta==0.1.7 8 | grpcio==1.24.3 9 | h5py==2.10.0 10 | Keras-Applications==1.0.8 11 | Keras-Preprocessing==1.1.0 12 | Markdown==3.1.1 13 | numpy==1.17.3 14 | pkg-resources==0.0.0 15 | protobuf==3.10.0 16 | scipy==1.3.1 17 | six==1.12.0 18 | tensorboard==1.15.0 19 | tensorflow-estimator==1.15.1 20 | termcolor==1.1.0 21 | Werkzeug==0.16.0 22 | wrapt==1.11.2 23 | flask 24 | -------------------------------------------------------------------------------- /entity_files/language.json: -------------------------------------------------------------------------------- 1 | { 2 | "藏族": [ 3 | "藏族", 4 | 1 5 | ], 6 | "华语": [ 7 | "华语", 8 | 1 9 | ], 10 | "中文": [ 11 | "中文", 12 | 7 13 | ], 14 | "国语": [ 15 | "国语", 16 | 3 17 | ], 18 | "闽南语": [ 19 | "闽南语", 20 | 4 21 | ], 22 | "粤语": [ 23 | "粤语", 24 | 7 25 | ], 26 | "欧美": [ 27 | "欧美", 28 | 1 29 | ], 30 | "闽南": [ 31 | "闽南", 32 | 2 33 | ], 34 | "韩国": [ 35 | "韩国", 36 | 1 37 | ] 38 | } -------------------------------------------------------------------------------- /configs/lasertagger_config-tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 312, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 1248, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 4, 11 | "type_vocab_size": 2, 12 | "vocab_size": 8021, 13 | "use_t2t_decoder": true, 14 | "decoder_num_hidden_layers": 1, 15 | "decoder_hidden_size": 768, 16 | "decoder_num_attention_heads": 4, 17 | "decoder_filter_size": 3072, 18 | "use_full_attention": false 19 | } 20 | -------------------------------------------------------------------------------- /entity_files/toplist.json: -------------------------------------------------------------------------------- 1 | { 2 | "流行": [ 3 | "流行", 4 | 8 5 | ], 6 | "最新": [ 7 | "最新", 8 | 7 9 | ], 10 | "好歌": [ 11 | "好歌", 12 | 2 13 | ], 14 | "好听": [ 15 | "好听", 16 | 5 17 | ], 18 | "网络": [ 19 | "网络", 20 | 8 21 | ], 22 | "最近": [ 23 | "最近", 24 | 1 25 | ], 26 | "最流行": [ 27 | "最流行", 28 | 1 29 | ], 30 | "热门": [ 31 | "热门", 32 | 2 33 | ], 34 | "热热热": [ 35 | "热热热", 36 | 1 37 | ], 38 | "新歌": [ 39 | "新歌", 40 | 2 41 | ], 42 | "港台歌曲": [ 43 | "港台歌曲", 44 | 1 45 | ] 46 | } -------------------------------------------------------------------------------- /entity_files/slot-dictionaries/toplist.txt: -------------------------------------------------------------------------------- 1 | 摩登 2 | 榜单 3 | 日韩 4 | 时髦 5 | 港台歌曲 6 | 风靡 7 | 第一首 8 | 出道时 9 | pop 10 | 热门 11 | 高品质音乐榜 12 | 最好听 13 | 最老 14 | 时兴 15 | 美国itunes榜 16 | vivo手机高品质音乐榜 17 | billboard 18 | 好听 19 | 怀旧华语 20 | 抢手 21 | 盛行 22 | qq音乐榜单 23 | 巅峰榜 24 | 风行 25 | 第一个 26 | 最老一个 27 | 最老一张 28 | 网络流行 29 | 网易音乐 30 | 第一张 31 | 新 32 | 最新 33 | 香港商台榜 34 | 大陆 35 | 受欢迎 36 | 韩国mnet榜 37 | 最受欢迎 38 | 榜首 39 | 日本公信榜 40 | 美国公告牌榜 41 | 热歌 42 | 最流行 43 | 流行 44 | 热热热 45 | 台湾幽浮榜 46 | 最近 47 | 动听 48 | 怀旧粤语 49 | 英国uk榜 50 | 新歌 51 | 播放最多 52 | 首张 53 | 时新 54 | 公告牌 55 | 香港电台榜 56 | 最火 57 | qq音乐 58 | 最热门 59 | 网络 60 | 火热 61 | 最热 62 | 网络热歌 63 | 云音乐 64 | 网易云音乐 65 | 怀旧英语 66 | 好歌 67 | 港台 68 | 好 69 | 内地 70 | -------------------------------------------------------------------------------- /entity_files/origin.json: -------------------------------------------------------------------------------- 1 | { 2 | "江苏": [ 3 | "江苏", 4 | 1 5 | ], 6 | "袁浦中学": [ 7 | "袁浦中学", 8 | 1 9 | ], 10 | "购物公园": [ 11 | "购物公园", 12 | 1 
13 | ], 14 | "红安": [ 15 | "红安", 16 | 1 17 | ], 18 | "海思厂": [ 19 | "海思厂", 20 | 1 21 | ], 22 | "凯仕厂": [ 23 | "凯仕厂", 24 | 1 25 | ], 26 | "万达花园": [ 27 | "万达花园", 28 | 1 29 | ], 30 | "遵义": [ 31 | "遵义", 32 | 1 33 | ], 34 | "石马巷": [ 35 | "石马巷", 36 | 1 37 | ], 38 | "宾馆路": [ 39 | "宾馆路", 40 | 1 41 | ], 42 | "海椒市": [ 43 | "海椒市", 44 | 1 45 | ], 46 | "205国道": [ 47 | "205国道", 48 | 1 49 | ] 50 | }
-------------------------------------------------------------------------------- /entity_files/emotion.json: --------------------------------------------------------------------------------
1 | { 2 | "励志": [ 3 | "励志", 4 | 1 5 | ], 6 | "伤感": [ 7 | "伤感", 8 | 2 9 | ], 10 | "嗨曲": [ 11 | "嗨曲", 12 | 1 13 | ], 14 | "怀旧": [ 15 | "怀旧", 16 | 1 17 | ], 18 | "快歌": [ 19 | "快歌", 20 | 3 21 | ], 22 | "寂寞": [ 23 | "寂寞", 24 | 1 25 | ], 26 | "劲爆": [ 27 | "劲爆", 28 | 6 29 | ], 30 | "柔情": [ 31 | "柔情", 32 | 1 33 | ], 34 | "火爆": [ 35 | "火爆", 36 | 1 37 | ], 38 | "温柔": [ 39 | "温柔", 40 | 1 41 | ], 42 | "欢快": [ 43 | "欢快", 44 | 4 45 | ], 46 | "抒情": [ 47 | "抒情", 48 | 3 49 | ], 50 | "慢歌": [ 51 | "慢歌", 52 | 1 53 | ], 54 | "猛烈": [ 55 | "猛烈", 56 | 1 57 | ] 58 | }
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Intent detection and slot filling for the music domain, with a slot-value correction scheme
2 | First place in the AI Yanxishe (AI研习社) competition "Spoken Language Understanding in Dialogue Systems".
3 | 
4 | 1. Intent detection and slot filling
5 | The model is RoBERTa + BiLSTM.
6 | An Aho-Corasick (AC) automaton finds suspected entities in the input sentence, and information about these suspected entities is fused into the input features, which improves the model's generalization.
7 | 2. Slot-value correction
8 | Some requests contain miswritten characters. When such a character falls inside a slot value, the backend service cannot find the corresponding song, even though the slot-filling model recognized the value from the sentence pattern and semantics.
9 | The recognized slot values therefore need to be corrected using the entity library together with pinyin features.
10 | 
11 | Analysis shows that the erroneous slot values in the corpus fall into the following cases, where the two sides of || are the slot value before and after correction:
12 | 1) Homophones, e.g. 妈妈巴巴巴||妈妈爸爸; 江湖大盗||江湖大道; 摩托谣||摩托摇
13 | 2) Wrong characters: 苍天一声笑||沧海一声笑; 火锅调料||火锅底料
14 | 3) Wrong initial (shengmu): 体验||体面; 道别的爱给特别的你||特别的爱给特别的你
15 | 4) Wrong final (yunmu): 唱之歌||蝉之歌; 说三就剩||说散就散; 没有你陪伴真的很孤单||没有你陪伴真的好孤单
16 | 5) Missing characters: 爱你千百回||一生爱你千百回; 九百九十朵玫瑰||九百九十九朵玫瑰
17 | 6) Inconsistent notation: 1234||一二三四
18 | 7) Stop characters or punctuation: 兄弟想你啦||兄弟想你了
19 | 8) Large differences: 我想带你去土耳其||带你去旅行
20 | Edit distance over pinyin handles case 1); replacing easily confused initial pairs after conversion to pinyin handles case 3);
21 | replacing easily confused final pairs after conversion to pinyin handles case 4); converting digits in entities or slot values to the corresponding Chinese characters handles case 6);
22 | removing stop characters from entities or slot values handles case 7). The remaining cases 2), 5) and 8) are hard to correct; whether they can be fixed depends on the pinyin edit distance and the chosen threshold (a small sketch of this pinyin matching is given at the end of this README).
23 | 
24 | How to run the experiments:
25 | bash start.sh host_name
26 | Note: the experiments were run on Linux; host_name is the user name.
27 | Please adapt the file paths in the start.sh script to your own setup.
28 | 
29 | ## License
30 | 
31 | Apache 2.0; see [LICENSE](LICENSE) for details. 
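## Appendix: pinyin-matching sketch

A minimal sketch of the pinyin-edit-distance matching described above. It assumes the pypinyin and python-Levenshtein packages that confusion_words_duoyin.py imports; the function name, entity list and 0.6 threshold here are illustrative only, not the project's exact implementation (the real version in confusion_words_duoyin.py additionally enumerates heteronym pronunciations).

```python
from pypinyin import lazy_pinyin   # Chinese text -> list of pinyin syllables
from Levenshtein import distance   # edit distance between two strings


def correct_slot(slot_value, entity_list, threshold=0.6):
    """Return the library entity whose pinyin best matches the slot value's pinyin."""
    slot_py = "".join(lazy_pinyin(slot_value))
    # Keep the original value unless some entity beats the similarity threshold.
    best_entity, best_score = slot_value, threshold
    for entity in entity_list:
        entity_py = "".join(lazy_pinyin(entity))
        # similarity = 1 - edit distance normalized by the query's pinyin length
        score = 1.0 - distance(slot_py, entity_py) / max(len(slot_py), 1)
        if score > best_score:
            best_entity, best_score = entity, score
    return best_entity


# Homophone case 1): "江湖大盗" and "江湖大道" share the same pinyin, so the
# library entry wins: correct_slot("江湖大盗", ["江湖大道", "摩托摇"]) -> "江湖大道"
```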
32 | -------------------------------------------------------------------------------- /entity_files/slot-dictionaries/style.txt: -------------------------------------------------------------------------------- 1 | b box 2 | 动感炫舞 3 | 含蓄 4 | 狂野 5 | 电子 6 | 文雅 7 | 乡间 8 | 拉丁 9 | 典雅 10 | 蓝草 11 | hip-hop 12 | 迪斯可 13 | 嘻哈风 14 | 嘻哈说唱 15 | 舞曲 16 | 风雅 17 | 斯卡 18 | 山歌 19 | 节奏布鲁斯 20 | 嘻哈 21 | 俚歌 22 | 大合唱 23 | 跳舞 24 | 广场舞 25 | 高雅 26 | 街舞 27 | 古风 28 | 男女合唱 29 | 大雅 30 | hip hop 31 | 民族风 32 | 高贵 33 | 精雅 34 | 摇滚乐 35 | 轻音乐 36 | 农村 37 | 舞蹈 38 | 英伦风 39 | 对唱 40 | 黑暗 41 | 电音 42 | 古老 43 | 乡村 44 | r&b 45 | 爵士 46 | 草原风 47 | 男女对唱 48 | 歌谣 49 | rapping 50 | 风谣 51 | 拉锯电音 52 | 摇滚 53 | 民谣 54 | 拉丁舞 55 | 劲歌 56 | 对歌 57 | 饶舌 58 | 后摇 59 | 朋克 60 | 串烧 61 | 雅致 62 | 狂放 63 | 民间 64 | 草原 65 | beatboxing 66 | 嘻嘻哈哈 67 | 巴萨诺瓦 68 | bossa nova 69 | 民俗 70 | 狂热 71 | 合唱 72 | 比波普 73 | 民族 74 | 爵士乐 75 | 蓝调 76 | 炫舞 77 | 新世纪 78 | 大草原 79 | 暗潮 80 | 金属 81 | 豪放 82 | 优雅 83 | rap 84 | 热舞 85 | 打碟 86 | 草地 87 | breaking 88 | 鬼步舞 89 | 陕北民歌 90 | 英伦 91 | 民歌 92 | 电子乐 93 | bebop 94 | 慢摇舞曲 95 | 现代蓝调 96 | 古典 97 | 淡雅 98 | 慢摇 99 | 乡下 100 | 说唱 101 | 布鲁斯 102 | 嘻嘻嘻哈哈哈 103 | -------------------------------------------------------------------------------- /entity_files/theme.json: -------------------------------------------------------------------------------- 1 | { 2 | "dj": [ 3 | "dj", 4 | 83 5 | ], 6 | "经典老歌": [ 7 | "经典老歌", 8 | 6 9 | ], 10 | "车载": [ 11 | "车载", 12 | 2 13 | ], 14 | "老歌": [ 15 | "老歌", 16 | 3 17 | ], 18 | "情歌": [ 19 | "情歌", 20 | 10 21 | ], 22 | "我是歌手": [ 23 | "我是歌手", 24 | 3 25 | ], 26 | "经典": [ 27 | "经典", 28 | 7 29 | ], 30 | "红歌": [ 31 | "红歌", 32 | 4 33 | ], 34 | "网络热歌": [ 35 | "网络热歌", 36 | 1 37 | ], 38 | "中国好声音": [ 39 | "中国好声音", 40 | 2 41 | ], 42 | "纯音乐": [ 43 | "纯音乐", 44 | 5 45 | ], 46 | "神曲": [ 47 | "神曲", 48 | 2 49 | ], 50 | "基督教": [ 51 | "基督教", 52 | 1 53 | ], 54 | "成名曲": [ 55 | "成名曲", 56 | 2 57 | ], 58 | "喊麦": [ 59 | "喊麦", 60 | 1 61 | ], 62 | "新年": [ 63 | "新年", 64 | 2 65 | ], 66 | "过年": [ 67 | "过年", 68 | 1 69 | ], 70 | "串烧": [ 71 | "串烧", 72 | 3 73 | ], 74 | "伤感情歌": [ 75 | "伤感情歌", 76 | 3 77 | ] 78 | } -------------------------------------------------------------------------------- /entity_files/style.json: -------------------------------------------------------------------------------- 1 | { 2 | "串烧": [ 3 | "串烧", 4 | 2 5 | ], 6 | "舞曲": [ 7 | "舞曲", 8 | 11 9 | ], 10 | "轻音乐": [ 11 | "轻音乐", 12 | 6 13 | ], 14 | "蓝调": [ 15 | "蓝调", 16 | 2 17 | ], 18 | "乡村": [ 19 | "乡村", 20 | 1 21 | ], 22 | "民歌": [ 23 | "民歌", 24 | 3 25 | ], 26 | "摇滚": [ 27 | "摇滚", 28 | 12 29 | ], 30 | "合唱": [ 31 | "合唱", 32 | 2 33 | ], 34 | "慢摇": [ 35 | "慢摇", 36 | 2 37 | ], 38 | "草原": [ 39 | "草原", 40 | 3 41 | ], 42 | "鬼步舞": [ 43 | "鬼步舞", 44 | 1 45 | ], 46 | "嘻哈": [ 47 | "嘻哈", 48 | 2 49 | ], 50 | "电音": [ 51 | "电音", 52 | 2 53 | ], 54 | "拉锯电音": [ 55 | "拉锯电音", 56 | 1 57 | ], 58 | "古典": [ 59 | "古典", 60 | 2 61 | ], 62 | "b box": [ 63 | "b box", 64 | 1 65 | ], 66 | "嘻哈说唱": [ 67 | "嘻哈说唱", 68 | 1 69 | ], 70 | "陕北民歌": [ 71 | "陕北民歌", 72 | 3 73 | ], 74 | "民族": [ 75 | "民族", 76 | 1 77 | ], 78 | "民谣": [ 79 | "民谣", 80 | 1 81 | ] 82 | } -------------------------------------------------------------------------------- /curLine_file.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | def curLine(): 4 | file_path = sys._getframe().f_back.f_code.co_filename # 获取调用函数的路径 5 | file_name=file_path[file_path.rfind("/") + 1:] # 获取调用函数所在的文件名 6 | lineno=sys._getframe().f_back.f_lineno#当前行号 7 | str="[%s:%s] "%(file_name,lineno) 8 | return str 9 
| 10 | 11 | stop_words = set(('\n', '\r', '\t', ' ')) 12 | # """全角转半角""" 13 | def quanjiao2banjiao(query_lower): 14 | rstring = list() 15 | for uchar in query_lower: 16 | # if uchar in stop_words: # 因为word piece会忽略空格引发问题,故当前版本都是忽略空格的,会引发错误 17 | # continue 18 | inside_code = ord(uchar) 19 | if inside_code == 12288: # 全角空格直接转换 20 | inside_code = 32 21 | elif (inside_code >= 65281 and inside_code <= 65374): # 全角字符(除空格)根据关系转化 22 | inside_code -= 65248 23 | uchar = chr(inside_code) 24 | # if uchar not in stop_mark: 25 | rstring.append(uchar) 26 | rerturn_str = "".join(rstring) 27 | return rerturn_str 28 | # 大写变小写, 全角变半角, 去标点 注意训练语料和模板也要改 29 | def normal_transformer(query): 30 | query_lower = query # .strip("’,?!,。?!�") #.lower() 31 | rstring = quanjiao2banjiao(query_lower).strip() #.replace("您","你") 32 | return rstring 33 | 34 | other_tag = "OTHERS" -------------------------------------------------------------------------------- /entity_files/slot-dictionaries/emotion.txt: -------------------------------------------------------------------------------- 1 | 酷一点 2 | 休闲 3 | 励志 4 | 嚣张 5 | 幸福快乐 6 | 兴奋 7 | 清幽 8 | 疗伤 9 | 减压 10 | 治愈系 11 | 愉快 12 | 凄恻 13 | 忆旧 14 | 快活 15 | 欢乐 16 | 激动人心 17 | 欢悦 18 | 凄怆 19 | 打动人 20 | 快歌 21 | 心潮难平 22 | 哀戚 23 | 孤单 24 | 清新 25 | 安谧 26 | 激动不已 27 | 气盛 28 | 最high 29 | 清静 30 | 悲愁 31 | 浪漫 32 | 热血 33 | 喜悦 34 | 舒缓 35 | 性感 36 | 眷恋 37 | 悲伤 38 | 罗曼蒂克 39 | 孤孤单单 40 | 凄惶 41 | 舒服 42 | 柔情 43 | 开心 44 | 恋旧 45 | 慢一点 46 | 昂奋 47 | 凄然 48 | 忧伤 49 | 欢娱 50 | 分手 51 | 伤悲 52 | 猛歌 53 | 令人感动 54 | 幸福 55 | 慢歌 56 | 最嗨 57 | 温柔 58 | 静谧 59 | 宣泄 60 | 下雨天 61 | 喜欢 62 | high歌 63 | 嗨歌 64 | 相思 65 | 难过 66 | 孤单单 67 | 催人泪下 68 | 安静 69 | 静寂 70 | 欣悦 71 | 欢快 72 | 宁静 73 | 寂寞 74 | 抑郁自杀 75 | 抚平 76 | 动人 77 | 哀愁 78 | 激情澎湃 79 | 放松 80 | 表白 81 | 嗨曲 82 | 治愈心灵 83 | 甜蜜 84 | 欣喜 85 | 心痛 86 | 快乐 87 | 悲哀 88 | 称快 89 | 轻快 90 | 猛烈 91 | 动感 92 | 哀伤 93 | 悲怆 94 | 难受 95 | 正能量 96 | 伤感 97 | 治愈 98 | 欢喜 99 | 念旧 100 | 欢欢喜喜 101 | 愉悦 102 | 火爆 103 | 激情 104 | 心潮起伏 105 | 悲戚 106 | 劲爆 107 | 快快乐乐 108 | 感动 109 | 伤心 110 | 抑郁 111 | 奋斗 112 | 孤独 113 | 熬心 114 | 欢欣 115 | 美满 116 | 怀念 117 | 催人奋进 118 | 轻松 119 | 甜美 120 | 妖媚 121 | 思念 122 | 喜庆 123 | 肉麻 124 | 怀旧 125 | 怀古 126 | 嗨 127 | 激动 128 | 欢愉 129 | 清爽 130 | 抒情 131 | 激励 132 | 发泄 133 | 心潮澎湃 134 | 令人鼓舞 135 | 高兴 136 | -------------------------------------------------------------------------------- /entity_files/slot-dictionaries/scene.txt: -------------------------------------------------------------------------------- 1 | 宅 2 | 做事 3 | 出车 4 | 夜晚 5 | 咖啡吧 6 | 午觉 7 | 早饭 8 | 睡懒觉 9 | 求婚 10 | 公车 11 | 喝咖啡 12 | 晚饭 13 | 清早 14 | 催眠 15 | 写作业 16 | 转转 17 | 小资 18 | 幽会 19 | 旅游 20 | 赖床 21 | 馆子 22 | 旅行 23 | 健身 24 | 校车 25 | 聚集 26 | 午饭 27 | 喜宴 28 | 跑步 29 | 发车 30 | 码代码 31 | 开车 32 | 婚礼 33 | 早餐 34 | 来一杯咖啡 35 | 夜总会 36 | 商场 37 | 团聚 38 | 饭庄 39 | 喜酒 40 | 晚上 41 | 深夜 42 | 下午茶 43 | 饭店 44 | 坐车 45 | 网吧 46 | 驾车 47 | 喜筵 48 | 好好工作 49 | 地铁 50 | 商店 51 | 睡前 52 | 回家 53 | 相聚 54 | 早上 55 | 餐馆 56 | 阴天 57 | 迪厅 58 | 午睡 59 | 咖啡店 60 | 夜店 61 | 酒家 62 | 公交车 63 | 小憩 64 | 喝酒 65 | 午餐 66 | 午休 67 | 上网 68 | k歌 69 | 学习 70 | 写书法 71 | 食堂 72 | 吃饭 73 | 下雨天 74 | 酒馆 75 | 午眠 76 | 店铺 77 | 饭铺 78 | 家务 79 | 游玩 80 | 看书 81 | 绕弯儿 82 | 散步 83 | 菜馆 84 | 早晨 85 | 做生日 86 | 迪吧 87 | 走走 88 | 远足 89 | 洗澡 90 | 睡觉 91 | 家事 92 | 游历 93 | 起床 94 | 工作 95 | 婚宴 96 | 瑜伽 97 | 写代码 98 | 宅在家里 99 | 大巴 100 | 寿诞 101 | 超市 102 | 酒吧 103 | 读书 104 | 牵手 105 | 活动 106 | 自习 107 | 冲凉 108 | 做清洁 109 | 溜达 110 | 约会 111 | 做寿 112 | 驱车 113 | 生日 114 | 饭馆 115 | 家务事 116 | 在路上 117 | 观光 118 | 雨天 119 | 聚会 120 | 家务时间 121 | 做作业 122 | 晚餐 123 | 作事 124 | 休息 125 | 夜间 126 | 清晨 127 | 家务活 128 | 咖啡馆 129 | 咖啡厅 130 | 睡觉前 131 | 胎教 132 | 劳动 133 | 泡澡 134 | 
睡眠 135 | 忙碌 136 | 遛弯儿 137 | 集会 138 | 运动 139 | 结婚 140 | 公交 141 | 班车 142 | 转悠 143 | 酒店 144 | 做卫生 145 | 过生日 146 | -------------------------------------------------------------------------------- /entity_files/slot-dictionaries/theme.txt: -------------------------------------------------------------------------------- 1 | 格莱美奖 2 | 翻唱 3 | 最经典 4 | 成名歌曲 5 | 基督 6 | 翻唱歌曲 7 | 招牌歌曲 8 | 红歌 9 | 超强音浪 10 | 电影 11 | 沙发歌曲 12 | 天籁之战 13 | 基督教 14 | 情歌 15 | dj 16 | 广场舞 17 | 车载音乐 18 | 幼儿园 19 | 纯人声 20 | 沙发 21 | 男生歌曲 22 | 中国风 23 | 车载低音炮 24 | 歌厅必点歌 25 | 神曲 26 | 单纯 27 | 电视剧 28 | 奥斯卡奖 29 | 圣诞 30 | 铃音 31 | 清新 32 | 音乐剧 33 | 革命 34 | 网游 35 | 冷门 36 | 学校 37 | 影视原声 38 | 梦想的声音 39 | 电视 40 | 佛经 41 | 浪漫韩剧 42 | 游戏 43 | 经典情歌 44 | 成名作 45 | 情侣 46 | 春节 47 | 现场版 48 | 中国新歌声 49 | 喊麦 50 | 奥奖 51 | 动画 52 | 格莱美 53 | 世界名曲 54 | 我是歌手 55 | 男人听的歌 56 | dj碟 57 | 动漫 58 | 影视 59 | 劲歌 60 | 老情歌 61 | 个性化 62 | 格奖 63 | 沙发音乐 64 | 新歌声 65 | 网络热歌 66 | 纯音乐 67 | 网络伤感情歌 68 | 经典老歌 69 | 串烧 70 | 老一点的歌 71 | mv喊麦 72 | 漫画 73 | 男人情歌 74 | 演唱会版 75 | 成名曲 76 | 佛教 77 | 武侠 78 | dj舞曲 79 | 电视连续剧 80 | 韩剧 81 | dj热碟 82 | 蒙面唱将 83 | 二次元 84 | 平安夜 85 | live 86 | 新春 87 | 武侠风 88 | 歌手2018 89 | 年会 90 | 动画片 91 | 男生听的歌 92 | 经典现场 93 | 次元 94 | 新年 95 | 中国好歌曲 96 | ktv 97 | 中国好声音 98 | live版 99 | 一人一首招牌歌 100 | 代表歌曲 101 | 铃声 102 | 演唱会 103 | ktv必点歌 104 | 小清新 105 | 老歌 106 | 个性电台 107 | ktv歌曲 108 | 爱情 109 | 车载歌曲 110 | 洗脑 111 | 纯旋律 112 | 代表作 113 | 武侠情节 114 | 奥斯卡 115 | 大清新 116 | 圣诞节 117 | 冷僻 118 | 爱国 119 | 伤情中国风 120 | 传统 121 | 伤感的情歌 122 | 好歌曲 123 | 校园 124 | 车载 125 | mc喊麦 126 | 伴奏 127 | 个性歌曲 128 | 军旅 129 | 经典 130 | 个性 131 | 伤感情歌 132 | 没有歌词 133 | 过年 134 | 好声音 135 | 男性歌曲 136 | 元旦 137 | 流金岁月 138 | 经典翻唱 139 | 蒙面唱将猜猜猜 140 | 超赞纯人声 141 | -------------------------------------------------------------------------------- /entity_files/slot-dictionaries/age.txt: -------------------------------------------------------------------------------- 1 | 小孩 2 | 小学 3 | 女神 4 | 2016年 5 | 1994年 6 | 1978年 7 | 儿童 8 | 1986年 9 | 女人 10 | 中老年 11 | 2004年 12 | 六十后老年人 13 | 1989年 14 | 1965年 15 | 七零年代 16 | 1977年 17 | 两千后 18 | 1949年 19 | 娃娃 20 | 2017年 21 | 1954年 22 | 1971年 23 | 1995年 24 | 1979年 25 | 1981年 26 | 美女 27 | 男人 28 | 2007年 29 | 10年代 30 | 1964年 31 | 八十年代 32 | 1951年 33 | 女孩子 34 | 2222年 35 | 一岁半 36 | 1968年 37 | 1996年 38 | 2010年 39 | 80后 40 | 九零年代 41 | 2002年 42 | 九零后 43 | 90年代 44 | 70后 45 | 1956年 46 | 老年 47 | 1967年 48 | 1984年 49 | 九十年代 50 | 2020年 51 | 80年代 52 | 八零年代 53 | 1997年 54 | 2011年 55 | 90后 56 | 1987年 57 | 2019年 58 | 2005年 59 | 1970年 60 | 帅哥 61 | 1957年 62 | 20年代 63 | 男孩子 64 | 1966年 65 | 2012年 66 | 父母辈 67 | 1953年 68 | 2008年 69 | 1962年 70 | 1982年 71 | 2000年 72 | 一岁 73 | 七十后 74 | 女生 75 | 1990年 76 | 一零年代 77 | 八零后 78 | 1999年 79 | 1960年 80 | 1958年 81 | 1975年 82 | 1950年 83 | 1985年 84 | 2010年之前 85 | 幼儿 86 | 1961年 87 | 八十后 88 | 童谣 89 | 亲子儿歌 90 | 1969年 91 | 00年代 92 | 1991年 93 | 2013年 94 | 91世纪 95 | 2003年 96 | 六零后 97 | 1976年 98 | 暖男 99 | 1959年 100 | 小朋友 101 | 22世纪 102 | 1992年 103 | 2014年 104 | 21世纪 105 | 1972年 106 | 1955年 107 | 男孩 108 | 文艺青年 109 | 1980年 110 | 2006年 111 | 七十年代 112 | 2010之前 113 | 1963年 114 | 零零后 115 | 儿歌 116 | 男神 117 | 60后 118 | 零零年代 119 | 1993年 120 | 2015年 121 | 70年代 122 | 1952年 123 | 1973年 124 | 2009年 125 | 2018年 126 | 少儿 127 | 1983年 128 | 2001年 129 | 上班族 130 | 九十后 131 | 二零一一 132 | 00后 133 | 两岁 134 | 七零后 135 | 女孩 136 | 1988年 137 | 1998年 138 | 男生 139 | 1974年 140 | -------------------------------------------------------------------------------- /entity_files/frequentSinger.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": 0.1622604899345137, 3 | "0": 0.06766917293233082, 4 
| "10": 0.04640633842671194, 5 | "9": 0.041959738054814455, 6 | "N": 0.03128789716226049, 7 | "M": 0.025305198480071147, 8 | "未来": 0.024173336567224514, 9 | "P": 0.023930794728757376, 10 | "左右": 0.01447166302853909, 11 | "安": 0.01228878648233487, 12 | "M2": 0.011399466407955373, 13 | "风": 0.009782520818174469, 14 | "15": 0.008650658905327836, 15 | "19": 0.007437949712992158, 16 | "21": 0.007114560595035977, 17 | "张航": 0.00606354596167839, 18 | "凌晨": 0.004608294930875576, 19 | "31": 0.004284905812919395, 20 | "信": 0.003718974856496079, 21 | "气象预报": 0.0031530439000727628, 22 | "美": 0.0031530439000727628, 23 | "城市": 0.002991349341094672, 24 | "C": 0.002991349341094672, 25 | "-": 0.0028296547821165816, 26 | "吉": 0.0028296547821165816, 27 | "泉": 0.0026679602231384912, 28 | "流行歌曲": 0.0025871129436494463, 29 | "in": 0.0025871129436494463, 30 | "夏": 0.002425418384671356, 31 | "S": 0.0023445711051823105, 32 | "B": 0.00218287654620422, 33 | "A": 0.00218287654620422, 34 | "乌兰": 0.0020211819872261298, 35 | "R": 0.0019403347077370846, 36 | "D": 0.0017786401487589942, 37 | "H": 0.0017786401487589942, 38 | "W": 0.001697792869269949, 39 | "红": 0.001697792869269949, 40 | "晴天": 0.0015360983102918587, 41 | "T": 0.0015360983102918587, 42 | "星": 0.0015360983102918587, 43 | "十一": 0.0015360983102918587, 44 | "ch": 0.0012935564718247231, 45 | "李白": 0.001212709192335678, 46 | "G": 0.001212709192335678, 47 | "杨": 0.001212709192335678, 48 | "零": 0.0011318619128466328, 49 | "I": 0.0011318619128466328, 50 | "J": 0.0011318619128466328, 51 | "小雨": 0.0010510146333575876, 52 | "L": 0.0010510146333575876, 53 | "O": 0.0009701673538685423, 54 | 55 | "m2":0.8, 56 | "什么": 0.8, 57 | "倩女幽魂": 0.8, 58 | "两只老虎": 0.8, 59 | "中华": 0.8, 60 | "文歌": 0.8, 61 | "岳云鹏": 0.8, 62 | "郭德纲": 0.8, 63 | "大美": 0.8, 64 | "大王叫我来巡山": 0.8, 65 | "大海": 0.8, 66 | "大小姐": 0.8, 67 | "白天": 0.8, 68 | "江南": 0.8, 69 | "凌晨": 0.8, 70 | "金刚": 0.8, 71 | "苍蝇": 0.8, 72 | "海洋": 0.8, 73 | "相思": 0.8, 74 | "李白": 0.8, 75 | "dj": 0.8, 76 | "you": 0.8, 77 | "曹操": 0.8, 78 | "曲歌": 0.5, 79 | "云朵": 0.5, 80 | "小女子": 0.8 81 | } 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | *-notice.txt 3 | *log.txt 4 | *score.txt 5 | *.jsonl 6 | *.pickle 7 | *results.txt 8 | hunxiao_words.txt 9 | *.csv 10 | 11 | # Distribution / packaging 12 | .Python 13 | models/ 14 | .idea/ 15 | __pycache__/ 16 | *.pyc 17 | *.py[cod] 18 | *$py.class 19 | *.so 20 | 21 | 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /score_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Calculates evaluation scores for a prediction TSV file. 
18 | 19 | The prediction file is produced by predict_main.py and should contain 3 or more 20 | columns: 21 | 1: sources (concatenated) 22 | 2: prediction 23 | 3-n: targets (1 or more) 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | 29 | from __future__ import print_function 30 | 31 | from absl import app, flags, logging 32 | import score_lib 33 | from bert import tokenization 34 | from curLine_file import curLine 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | flags.DEFINE_string( 39 | 'prediction_file', None, 40 | 'TSV file containing source, prediction, and target columns.') 41 | flags.DEFINE_bool( 42 | 'do_lower_case', True, 43 | 'Whether score computation should be case insensitive (in the LaserTagger ' 44 | 'paper this was set to True).') 45 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.') 46 | flags.DEFINE_string( 47 | 'domain_name', None, 48 | 'Whether to lower case the input text. Should be True for uncased ' 49 | 'models and False for cased models.') 50 | def main(argv): 51 | if len(argv) > 1: 52 | raise app.UsageError('Too many command-line arguments.') 53 | flags.mark_flag_as_required('prediction_file') 54 | target_domain_name = FLAGS.domain_name 55 | print(curLine(), "target_domain_name:", target_domain_name) 56 | 57 | predDomain_list, predIntent_list, domain_list, right_intent_num, right_slot_num, exact_num = score_lib.read_data( 58 | FLAGS.prediction_file, FLAGS.do_lower_case, target_domain_name=target_domain_name) 59 | logging.info(f'Read file: {FLAGS.prediction_file}') 60 | all_num = len(domain_list) 61 | domain_acc = score_lib.compute_exact_score(predDomain_list, domain_list) 62 | intent_acc = float(right_intent_num) / all_num 63 | slot_acc = float(right_slot_num) / all_num 64 | exact_score = float(exact_num) / all_num 65 | print('Num=%d, domain_acc=%.4f, intent_acc=%.4f, slot_acc=%.5f, exact_score=%.4f' 66 | % (all_num, domain_acc, intent_acc, slot_acc, exact_score)) 67 | 68 | 69 | if __name__ == '__main__': 70 | app.run(main) 71 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Utility functions for LaserTagger.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | 22 | from __future__ import print_function 23 | 24 | import json 25 | import csv 26 | 27 | import tensorflow as tf 28 | from curLine_file import curLine, normal_transformer, other_tag 29 | 30 | 31 | 32 | def yield_sources_and_targets(input_file, input_format, domain_name): 33 | """Reads and yields source lists and targets from the input file. 34 | 35 | Args: 36 | input_file: Path to the input file. 37 | input_format: Format of the input file. 
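domain_name: If not None, only examples whose domain equals this value are yielded.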
38 | 39 | Yields: 40 | Tuple with (list of source texts, target text). 41 | """ 42 | if input_format == 'nlu': 43 | yield_example_fn = _nlu_examples 44 | else: 45 | raise ValueError('Unsupported input_format: {}'.format(input_format)) 46 | for sources, target in yield_example_fn(input_file, domain_name): 47 | yield sources, target 48 | 49 | def _nlu_examples(input_file, target_domain_name=None): 50 | with tf.gfile.GFile(input_file) as f: 51 | reader = csv.reader(f) 52 | session_list = [] 53 | for row_id, (sessionId, raw_query, domain_intent, param) in enumerate(reader): 54 | query = normal_transformer(raw_query) 55 | param = normal_transformer(param) 56 | sources = [] 57 | if row_id > 0 and sessionId == session_list[row_id - 1][0]: 58 | sources.append(session_list[row_id-1][1]) # last query 59 | sources.append(query) 60 | if domain_intent == other_tag: 61 | domain = other_tag 62 | else: 63 | domain, intent = domain_intent.split(".") 64 | session_list.append((sessionId, query)) 65 | if target_domain_name is not None and target_domain_name != domain: 66 | continue 67 | yield sources, (intent,param) 68 | 69 | 70 | def read_label_map(path): 71 | """Returns label map read from the given path.""" 72 | with tf.gfile.GFile(path) as f: 73 | if path.endswith('.json'): 74 | return json.load(f) 75 | else: 76 | label_map = {} 77 | empty_line_encountered = False 78 | for tag in f: 79 | tag = tag.strip() 80 | if tag: 81 | label_map[tag] = len(label_map) 82 | else: 83 | if empty_line_encountered: 84 | raise ValueError( 85 | 'There should be no empty lines in the middle of the label map ' 86 | 'file.' 87 | ) 88 | empty_line_encountered = True 89 | return label_map 90 | -------------------------------------------------------------------------------- /official_transformer/model_params.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Defines Transformer model parameters.""" 17 | 18 | from collections import defaultdict 19 | 20 | 21 | BASE_PARAMS = defaultdict( 22 | lambda: None, # Set default value to None. 23 | 24 | # Input params 25 | default_batch_size=2048, # Maximum number of tokens per batch of examples. 26 | default_batch_size_tpu=32768, 27 | max_length=256, # Maximum number of tokens per example. 28 | 29 | # Model params 30 | initializer_gain=1.0, # Used in trainable variable initialization. 31 | vocab_size=33708, # Number of tokens defined in the vocabulary file. 32 | hidden_size=512, # Model dimension in the hidden layers. 33 | num_hidden_layers=6, # Number of layers in the encoder and decoder stacks. 34 | num_heads=8, # Number of heads to use in multi-headed attention. 35 | filter_size=2048, # Inner layer dimension in the feedforward network. 
36 | 37 | # Dropout values (only used when training) 38 | layer_postprocess_dropout=0.1, 39 | attention_dropout=0.1, 40 | relu_dropout=0.1, 41 | 42 | # Training params 43 | label_smoothing=0.1, 44 | learning_rate=2.0, 45 | learning_rate_decay_rate=1.0, 46 | learning_rate_warmup_steps=16000, 47 | 48 | # Optimizer params 49 | optimizer_adam_beta1=0.9, 50 | optimizer_adam_beta2=0.997, 51 | optimizer_adam_epsilon=1e-09, 52 | 53 | # Default prediction params 54 | extra_decode_length=50, 55 | beam_size=4, 56 | alpha=0.6, # used to calculate length normalization in beam search 57 | 58 | # TPU specific parameters 59 | use_tpu=False, 60 | static_batch=False, 61 | allow_ffn_pad=True, 62 | ) 63 | 64 | BIG_PARAMS = BASE_PARAMS.copy() 65 | BIG_PARAMS.update( 66 | default_batch_size=4096, 67 | 68 | # default batch size is smaller than for BASE_PARAMS due to memory limits. 69 | default_batch_size_tpu=16384, 70 | 71 | hidden_size=1024, 72 | filter_size=4096, 73 | num_heads=16, 74 | ) 75 | 76 | # Parameters for running the model in multi gpu. These should not change the 77 | # params that modify the model shape (such as the hidden_size or num_heads). 78 | BASE_MULTI_GPU_PARAMS = BASE_PARAMS.copy() 79 | BASE_MULTI_GPU_PARAMS.update( 80 | learning_rate_warmup_steps=8000 81 | ) 82 | 83 | BIG_MULTI_GPU_PARAMS = BIG_PARAMS.copy() 84 | BIG_MULTI_GPU_PARAMS.update( 85 | layer_postprocess_dropout=0.3, 86 | learning_rate_warmup_steps=8000 87 | ) 88 | 89 | # Parameters for testing the model 90 | TINY_PARAMS = BASE_PARAMS.copy() 91 | TINY_PARAMS.update( 92 | default_batch_size=1024, 93 | default_batch_size_tpu=1024, 94 | hidden_size=32, 95 | num_heads=4, 96 | filter_size=256, 97 | ) 98 | -------------------------------------------------------------------------------- /official_transformer/ffn_layer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Implementation of fully connected network.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | class FeedFowardNetwork(tf.layers.Layer): 26 | """Fully connected feedforward network.""" 27 | 28 | def __init__(self, hidden_size, filter_size, relu_dropout, train, allow_pad): 29 | super(FeedFowardNetwork, self).__init__() 30 | self.hidden_size = hidden_size 31 | self.filter_size = filter_size 32 | self.relu_dropout = relu_dropout 33 | self.train = train 34 | self.allow_pad = allow_pad 35 | 36 | self.filter_dense_layer = tf.layers.Dense( 37 | filter_size, use_bias=True, activation=tf.nn.relu, name="filter_layer") 38 | self.output_dense_layer = tf.layers.Dense( 39 | hidden_size, use_bias=True, name="output_layer") 40 | 41 | def call(self, x, padding=None): 42 | """Return outputs of the feedforward network. 43 | 44 | Args: 45 | x: tensor with shape [batch_size, length, hidden_size] 46 | padding: (optional) If set, the padding values are temporarily removed 47 | from x (provided self.allow_pad is set). The padding values are placed 48 | back in the output tensor in the same locations. 49 | shape [batch_size, length] 50 | 51 | Returns: 52 | Output of the feedforward network. 53 | tensor with shape [batch_size, length, hidden_size] 54 | """ 55 | padding = None if not self.allow_pad else padding 56 | 57 | # Retrieve dynamically known shapes 58 | batch_size = tf.shape(x)[0] 59 | length = tf.shape(x)[1] 60 | 61 | if padding is not None: 62 | with tf.name_scope("remove_padding"): 63 | # Flatten padding to [batch_size*length] 64 | pad_mask = tf.reshape(padding, [-1]) 65 | 66 | nonpad_ids = tf.to_int32(tf.where(pad_mask < 1e-9)) 67 | 68 | # Reshape x to [batch_size*length, hidden_size] to remove padding 69 | x = tf.reshape(x, [-1, self.hidden_size]) 70 | x = tf.gather_nd(x, indices=nonpad_ids) 71 | 72 | # Reshape x from 2 dimensions to 3 dimensions. 
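# After gather_nd, the non-padding tokens from the whole batch are packed together; expand_dims below turns them into a pseudo-batch of shape [1, num_nonpad_tokens, hidden_size], so the dense layers never process padding positions.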
73 | x.set_shape([None, self.hidden_size]) 74 | x = tf.expand_dims(x, axis=0) 75 | 76 | output = self.filter_dense_layer(x) 77 | if self.train: 78 | output = tf.nn.dropout(output, 1.0 - self.relu_dropout) 79 | output = self.output_dense_layer(output) 80 | 81 | if padding is not None: 82 | with tf.name_scope("re_add_padding"): 83 | output = tf.squeeze(output, axis=0) 84 | output = tf.scatter_nd( 85 | indices=nonpad_ids, 86 | updates=output, 87 | shape=[batch_size * length, self.hidden_size] 88 | ) 89 | output = tf.reshape(output, [batch_size, length, self.hidden_size]) 90 | return output 91 | -------------------------------------------------------------------------------- /find_entity/probable_acmation.py: -------------------------------------------------------------------------------- 1 | # 发现疑似实体,辅助训练 2 | # 用ac自动机构建发现疑似实体的工具 3 | import os 4 | from collections import defaultdict 5 | import json 6 | import re 7 | from .acmation import KeywordTree, add_to_ac, entity_files_folder, entity_folder 8 | from curLine_file import curLine, normal_transformer 9 | 10 | domain2entity_map = {} 11 | domain2entity_map["music"] = ["age", "singer", "song", "toplist", "theme", "style", "scene", "language", "emotion", "instrument"] 12 | domain2entity_map["navigation"] = ["custom_destination", "city"] # place city 13 | domain2entity_map["phone_call"] = ["phone_num", "contact_name"] 14 | 15 | re_phoneNum = re.compile("[0-9一二三四五六七八九十拾]+") # 编译 16 | 17 | 18 | # 也许直接读取下载的xls文件更方便,但那样需要安装xlrd模块 19 | self_entity_trie_tree = {} # 总的实体字典 自己建立的某些实体类型的实体树 20 | for domain, entity_type_list in domain2entity_map.items(): 21 | print(curLine(), domain, entity_type_list) 22 | for entity_type in entity_type_list: 23 | if entity_type not in self_entity_trie_tree: 24 | ac = KeywordTree(case_insensitive=True) 25 | else: 26 | ac = self_entity_trie_tree[entity_type] 27 | 28 | # TODO 29 | if entity_type == "city": 30 | # for current_entity_type in ["city", "province"]: 31 | # entity_file = waibu_folder + "%s.json" % current_entity_type 32 | # with open(entity_file, "r") as f: 33 | # current_entity_dict = json.load(f) 34 | # print(curLine(), "get %d %s from %s" % 35 | # (len(current_entity_dict), current_entity_type, entity_file)) 36 | # for entity_before, entity_times in current_entity_dict.items(): 37 | # entity_after = entity_before 38 | # add_to_ac(ac, entity_type, entity_before, entity_after, pri=1) 39 | 40 | ## 从标注语料中挖掘得到的地名 41 | for current_entity_type in ["destination", "origin"]: 42 | entity_file = os.path.join(entity_files_folder, "%s.json" % current_entity_type) 43 | with open(entity_file, "r") as f: 44 | current_entity_dict = json.load(f) 45 | print(curLine(), "get %d %s from %s" % 46 | (len(current_entity_dict), current_entity_type, entity_file)) 47 | for entity_before, entity_after_times in current_entity_dict.items(): 48 | entity_after = entity_after_times[0] 49 | add_to_ac(ac, entity_type, entity_before, entity_after, pri=2) 50 | input(curLine()) 51 | 52 | # 给的实体库,最高优先级 53 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type) 54 | if os.path.exists(entity_file): 55 | with open(entity_file, "r") as fr: 56 | lines = fr.readlines() 57 | print(curLine(), "get %d %s from %s" % (len(lines), entity_type, entity_file)) 58 | for line in lines: 59 | entity_after = line.strip() 60 | entity_before = entity_after # TODO 61 | pri = 3 62 | if entity_type in ["song"]: 63 | pri -= 0.5 64 | add_to_ac(ac, entity_type, entity_before, entity_after, pri=pri) 65 | ac.finalize() 66 | self_entity_trie_tree[entity_type] = ac 67 | 
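# Illustrative sketch of how the tries built above are queried (the query string
# and the "song" type are examples only); get_all_entity below does this for each
# requested entity type:
#   results = self_entity_trie_tree["song"].search("播放一首江湖大道")
#   for res in results:
#       after, priority = res.meta_data  # canonical entity form and its priority
#       print(res.keyword, "->", after, priority)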
68 | 69 | def get_all_entity(corpus, useEntityTypeList): 70 | self_entityTypeMap = defaultdict(list) 71 | for entity_type in useEntityTypeList: 72 | result = self_entity_trie_tree[entity_type].search(corpus) 73 | for res in result: 74 | after, priority = res.meta_data 75 | self_entityTypeMap[entity_type].append({'before': res.keyword, 'after': after, "priority":priority}) 76 | if "phone_num" in useEntityTypeList: 77 | token_numbers = re_phoneNum.findall(corpus) 78 | for number in token_numbers: 79 | self_entityTypeMap["phone_num"].append({'before':number, 'after':number, 'priority': 2}) 80 | return self_entityTypeMap -------------------------------------------------------------------------------- /confusion_words_duoyin.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 混淆词提取算法 3 | from pypinyin import lazy_pinyin, pinyin, Style 4 | # pinyin 的方法默认带声调,而 lazy_pinyin 方法不带声调,但是会得到不常见的多音字 5 | # pinyin(c, heteronym=True, style=0) 不考虑音调情况下的多音字 6 | 7 | from Levenshtein import distance # python-Levenshtein 8 | # 编辑距离 (Levenshtein Distance算法) 9 | import os 10 | import copy 11 | from curLine_file import curLine 12 | from find_entity.acmation import entity_folder 13 | 14 | number_map = {"0":"零", "1":"一", "2":"二", "3":"三", "4":"四", "5":"五", "6":"六", "7":"七", "8":"八", "9":"九"} 15 | def get_pinyin_combination(entity): 16 | all_combination = [""] 17 | for zi in entity: 18 | if zi in number_map: 19 | zi = number_map[zi] 20 | zi_pinyin_list = pinyin(zi, heteronym=True, style=0)[0] 21 | pre_combination = copy.deepcopy(all_combination) 22 | all_combination = [] 23 | for index, zi_pinyin in enumerate(zi_pinyin_list): # 逐个添加每种发音 24 | for com in pre_combination: # com代表一种情况 25 | # com.append(zi_pinyin) 26 | # new = com + [zi_pinyin] 27 | # new = "".join(new) 28 | all_combination.append("%s%s" % (com, zi_pinyin)) # TODO 要不要加分隔符 29 | return all_combination 30 | 31 | def get_entityType_pinyin(entity_type): 32 | entity_info_dict = {} 33 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type) 34 | with open(entity_file, "r") as fr: 35 | lines = fr.readlines() 36 | pri = 3 37 | if entity_type in ["song"]: 38 | pri -= 0.5 39 | print(curLine(), "get %d %s from %s, pri=%f" % (len(lines), entity_type, entity_file, pri)) 40 | for line in lines: 41 | entity_after = line.strip() 42 | all_combination = get_pinyin_combination(entity=entity_after) 43 | # for combination in all_combination: 44 | if entity_after not in entity_info_dict: # 新的发音 45 | entity_info_dict[entity_after] = (all_combination, pri) 46 | else: # 一般不会重复,最多重复一次 47 | _, old_pri = entity_info_dict[entity_after] 48 | if pri > old_pri: 49 | entity_info_dict[entity_after] = (all_combination, pri) 50 | return entity_info_dict 51 | 52 | # 用编辑距离度量拼音字符串之间的相似度 53 | def pinyin_similar_word(entity_info_dict, word): 54 | similar_word = None 55 | if word in entity_info_dict: # 存在实体,无需纠错 56 | return 0, word 57 | all_combination = get_pinyin_combination(entity=word) 58 | top_similar_score = 0 59 | for current_combination in all_combination: # 当前的各种发音 60 | current_distance = 10000 61 | 62 | for entity_after,(com_list, pri) in entity_info_dict.items(): 63 | for com in com_list: 64 | d = distance(com, current_combination) 65 | if d < current_distance: 66 | current_distance = d 67 | similar_word = entity_after 68 | current_similar_score = 1.0 - float(current_distance) / len(current_combination) 69 | print(curLine(), "current_combination:%s, %f" % (current_combination, current_similar_score), 
similar_word, current_distance) 70 | 71 | 72 | 73 | singer_pinyin = get_entityType_pinyin(entity_type="singer") 74 | print(curLine(), len(singer_pinyin), "singer_pinyin") 75 | 76 | song_pinyin = get_entityType_pinyin(entity_type="song") 77 | print(curLine(), len(song_pinyin), "song_pinyin") 78 | 79 | 80 | 81 | if __name__ == "__main__": 82 | s = "你为什123c单身" 83 | for c in s: 84 | c_pinyin = lazy_pinyin(s, errors="ignore") 85 | print(curLine(), c,c_pinyin) 86 | print(curLine(), pinyin(s, heteronym=True, style=0)) 87 | # # print(pinyin(c, heteronym=True, style=0)) 88 | 89 | # all_combination = get_pinyin_combination(entity=s) 90 | # 91 | # for index, combination in enumerate(all_combination): 92 | # print(curLine(), index, combination) 93 | # for s in ["abc", "abs1", "fabc"]: 94 | # ed = distance("abs", s) 95 | # print(s, ed) 96 | 97 | pinyin_similar_word(singer_pinyin, "周杰") 98 | pinyin_similar_word(singer_pinyin, "前任") 99 | -------------------------------------------------------------------------------- /split_corpus.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 分割语料 3 | import os 4 | import csv, json 5 | from collections import defaultdict 6 | import re 7 | from curLine_file import curLine, other_tag 8 | 9 | all_entity_dict = defaultdict(dict) 10 | before2after = {"父亲":"父亲", "赵磊":"赵磊", "甜蜜":"甜蜜", "大王叫我来":"大王叫我来巡山"} 11 | def get_slot(param): 12 | slot = [] 13 | if "<" not in param: 14 | return slot 15 | if ">" not in param: 16 | print(curLine(), "param:", param) 17 | return slot 18 | if "", param) 21 | end_segment = re.findall("", param) 22 | if len(start_segment) != len(end_segment): 23 | print(curLine(), "start_segment:", start_segment) 24 | print(curLine(), "end_segment:", end_segment) 25 | search_location = 0 26 | for s,e in zip(start_segment, end_segment): 27 | entityType = s[1:-1] 28 | assert "" % entityType == e 29 | start_index = param[search_location:].index(s) + len(s) 30 | end_index = param[search_location:].index(e) 31 | entity_info = param[search_location:][start_index:end_index] 32 | search_location += end_index + len(e) 33 | before,after = entity_info, entity_info 34 | if "||" in entity_info: 35 | before, after = entity_info.split("||") 36 | if before in before2after: 37 | after = before2after[before] 38 | if before not in all_entity_dict[entityType]: 39 | all_entity_dict[entityType][before] = [after, 1] 40 | else: 41 | if after != all_entity_dict[entityType][before][0]: 42 | print(curLine(), entityType, before, after, all_entity_dict[entityType][before]) 43 | assert after == all_entity_dict[entityType][before][0] 44 | all_entity_dict[entityType][before][1] += 1 45 | if before != after: 46 | before = after 47 | if before not in all_entity_dict[entityType]: 48 | all_entity_dict[entityType][before] = [after, 1] 49 | else: 50 | assert after == all_entity_dict[entityType][before][0] 51 | all_entity_dict[entityType][before][1] += 1 52 | 53 | 54 | # 预处理 55 | def process(source_file, train_file, dev_file): 56 | dev_lines = [] 57 | train_num = 0 58 | intent_distribution = defaultdict(dict) 59 | with open(source_file, "r") as f, open(train_file, "w") as f_train: 60 | reader = csv.reader(f) 61 | train_write = csv.writer(f_train, dialect='excel') 62 | for row_id, line in enumerate(reader): 63 | if row_id==0: 64 | continue 65 | (sessionId, raw_query, domain_intent, param) = line 66 | get_slot(param) 67 | 68 | if domain_intent == other_tag: 69 | domain = other_tag 70 | intent = other_tag 71 | else: 72 | domain, intent = 
domain_intent.split(".") 73 | if intent in intent_distribution[domain]: 74 | intent_distribution[domain][intent] += 1 75 | else: 76 | intent_distribution[domain][intent] = 0 77 | if row_id == 0: 78 | continue 79 | sessionId = int(sessionId) 80 | if sessionId % 10>0: 81 | train_write.writerow(line) 82 | train_num += 1 83 | else: 84 | dev_lines.append(line) 85 | with open(dev_file, "w") as f_dev: 86 | write = csv.writer(f_dev, dialect='excel') 87 | for line in dev_lines: 88 | write.writerow(line) 89 | print(curLine(), "dev=%d, train=%d" % (len(dev_lines), train_num)) 90 | for domain, intent_num in intent_distribution.items(): 91 | print(curLine(), domain, intent_num) 92 | 93 | 94 | if __name__ == "__main__": 95 | corpus_folder = "/home/wzk/Mywork/corpus/compe/69" 96 | source_file = os.path.join(corpus_folder, "train.csv") 97 | train_file = os.path.join(corpus_folder, "train.txt") 98 | dev_file = os.path.join(corpus_folder, "dev.txt") 99 | process(source_file, train_file, dev_file) 100 | 101 | for entityType, entityDict in all_entity_dict.items(): 102 | json_file = os.path.join(corpus_folder, "%s.json" % entityType) 103 | with open(json_file, "w") as f: 104 | json.dump(entityDict, f, ensure_ascii=False, indent=4) 105 | print(curLine(), "save %d %s to %s" % (len(entityDict), entityType, json_file)) 106 | 107 | -------------------------------------------------------------------------------- /official_transformer/embedding_layer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Implementation of embedding layer with shared weights.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf # pylint: disable=g-bad-import-order 23 | 24 | from official_transformer import tpu as tpu_utils 25 | 26 | 27 | class EmbeddingSharedWeights(tf.layers.Layer): 28 | """Calculates input embeddings and pre-softmax linear with shared weights.""" 29 | 30 | def __init__(self, vocab_size, hidden_size, method="gather"): 31 | """Specify characteristic parameters of embedding layer. 32 | 33 | Args: 34 | vocab_size: Number of tokens in the embedding. (Typically ~32,000) 35 | hidden_size: Dimensionality of the embedding. (Typically 512 or 1024) 36 | method: Strategy for performing embedding lookup. "gather" uses tf.gather 37 | which performs well on CPUs and GPUs, but very poorly on TPUs. "matmul" 38 | one-hot encodes the indicies and formulates the embedding as a sparse 39 | matrix multiplication. The matmul formulation is wasteful as it does 40 | extra work, however matrix multiplication is very fast on TPUs which 41 | makes "matmul" considerably faster than "gather" on TPUs. 
42 | """ 43 | super(EmbeddingSharedWeights, self).__init__() 44 | self.vocab_size = vocab_size 45 | self.hidden_size = hidden_size 46 | if method not in ("gather", "matmul"): 47 | raise ValueError("method {} must be 'gather' or 'matmul'".format(method)) 48 | self.method = method 49 | 50 | def build(self, _): 51 | with tf.variable_scope("embedding_and_softmax", reuse=tf.AUTO_REUSE): 52 | # Create and initialize weights. The random normal initializer was chosen 53 | # randomly, and works well. 54 | self.shared_weights = tf.get_variable( 55 | "weights", [self.vocab_size, self.hidden_size], 56 | initializer=tf.random_normal_initializer( 57 | 0., self.hidden_size ** -0.5)) 58 | 59 | self.built = True 60 | 61 | def call(self, x): 62 | """Get token embeddings of x. 63 | 64 | Args: 65 | x: An int64 tensor with shape [batch_size, length] 66 | Returns: 67 | embeddings: float32 tensor with shape [batch_size, length, embedding_size] 68 | padding: float32 tensor with shape [batch_size, length] indicating the 69 | locations of the padding tokens in x. 70 | """ 71 | with tf.name_scope("embedding"): 72 | # Create binary mask of size [batch_size, length] 73 | mask = tf.to_float(tf.not_equal(x, 0)) 74 | 75 | if self.method == "gather": 76 | embeddings = tf.gather(self.shared_weights, x) 77 | embeddings *= tf.expand_dims(mask, -1) 78 | else: # matmul 79 | embeddings = tpu_utils.embedding_matmul( 80 | embedding_table=self.shared_weights, 81 | values=tf.cast(x, dtype=tf.int32), 82 | mask=mask 83 | ) 84 | # embedding_matmul already zeros out masked positions, so 85 | # `embeddings *= tf.expand_dims(mask, -1)` is unnecessary. 86 | 87 | 88 | # Scale embedding by the sqrt of the hidden size 89 | embeddings *= self.hidden_size ** 0.5 90 | 91 | return embeddings 92 | 93 | 94 | def linear(self, x): 95 | """Computes logits by running x through a linear layer. 96 | 97 | Args: 98 | x: A float32 tensor with shape [batch_size, length, hidden_size] 99 | Returns: 100 | float32 tensor with shape [batch_size, length, vocab_size]. 101 | """ 102 | with tf.name_scope("presoftmax_linear"): 103 | batch_size = tf.shape(x)[0] 104 | length = tf.shape(x)[1] 105 | 106 | x = tf.reshape(x, [-1, self.hidden_size]) 107 | logits = tf.matmul(x, self.shared_weights, transpose_b=True) 108 | 109 | return tf.reshape(logits, [batch_size, length, self.vocab_size]) 110 | -------------------------------------------------------------------------------- /official_transformer/model_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Transformer model helper methods.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import math 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | # Very low numbers to represent -infinity. We do not actually use -Inf, since we 28 | # want to be able to multiply these values by zero to get zero. (-Inf * 0 = NaN) 29 | _NEG_INF_FP32 = -1e9 30 | _NEG_INF_FP16 = np.finfo(np.float16).min 31 | 32 | 33 | def get_position_encoding( 34 | length, hidden_size, min_timescale=1.0, max_timescale=1.0e2): # TODO previous max_timescale=1.0e4 35 | """Return positional encoding. 36 | 37 | Calculates the position encoding as a mix of sine and cosine functions with 38 | geometrically increasing wavelengths. 39 | Defined and formulized in Attention is All You Need, section 3.5. 40 | 41 | Args: 42 | length: Sequence length. 43 | hidden_size: Size of the 44 | min_timescale: Minimum scale that will be applied at each position 45 | max_timescale: Maximum scale that will be applied at each position 46 | 47 | Returns: 48 | Tensor with shape [length, hidden_size] 49 | """ 50 | # We compute the positional encoding in float32 even if the model uses 51 | # float16, as many of the ops used, like log and exp, are numerically unstable 52 | # in float16. 53 | position = tf.cast(tf.range(length), tf.float32) 54 | num_timescales = hidden_size // 2 55 | log_timescale_increment = ( 56 | math.log(float(max_timescale) / float(min_timescale)) / 57 | (tf.cast(num_timescales, tf.float32) - 1)) 58 | inv_timescales = min_timescale * tf.exp( 59 | tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment) 60 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) 61 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 62 | return signal 63 | 64 | 65 | def get_decoder_self_attention_bias(length, dtype=tf.float32): 66 | """Calculate bias for decoder that maintains model's autoregressive property. 67 | 68 | Creates a tensor that masks out locations that correspond to illegal 69 | connections, so prediction at position i cannot draw information from future 70 | positions. 71 | 72 | Args: 73 | length: int length of sequences in batch. 74 | dtype: The dtype of the return value. 75 | 76 | Returns: 77 | float tensor of shape [1, 1, length, length] 78 | """ 79 | neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32 80 | with tf.name_scope("decoder_self_attention_bias"): 81 | valid_locs = tf.linalg.band_part(input=tf.ones([length, length], dtype=dtype), # 2个length分别是输入和输出的最大长度 82 | num_lower=-1, num_upper=0) # Lower triangular part is 1, other is 0 83 | valid_locs = tf.reshape(valid_locs, [1, 1, length, length]) 84 | decoder_bias = neg_inf * (1.0 - valid_locs) 85 | return decoder_bias 86 | 87 | 88 | def get_padding(x, padding_value=0, dtype=tf.float32): 89 | """Return float tensor representing the padding values in x. 90 | 91 | Args: 92 | x: int tensor with any shape 93 | padding_value: int value that 94 | dtype: The dtype of the return value. 95 | 96 | Returns: 97 | float tensor with same shape as x containing values 0 or 1. 98 | 0 -> non-padding, 1 -> padding 99 | """ 100 | with tf.name_scope("padding"): 101 | return tf.cast(tf.equal(x, padding_value), dtype) 102 | 103 | 104 | def get_padding_bias(x): 105 | """Calculate bias tensor from padding values in tensor. 
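    For example (illustrative): for x = [[3, 7, 0, 0]] the returned bias is
    [[[[0, 0, -1e9, -1e9]]]], i.e. zero at real tokens and -1e9 at padding positions.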
106 | 107 | Bias tensor that is added to the pre-softmax multi-headed attention logits, 108 | which has shape [batch_size, num_heads, length, length]. The tensor is zero at 109 | non-padding locations, and -1e9 (negative infinity) at padding locations. 110 | 111 | Args: 112 | x: int tensor with shape [batch_size, length] 113 | 114 | Returns: 115 | Attention bias tensor of shape [batch_size, 1, 1, length]. 116 | """ 117 | with tf.name_scope("attention_bias"): 118 | padding = get_padding(x) 119 | attention_bias = padding * _NEG_INF_FP32 120 | attention_bias = tf.expand_dims( 121 | tf.expand_dims(attention_bias, axis=1), axis=1) 122 | return attention_bias 123 | -------------------------------------------------------------------------------- /official_transformer/tpu.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Functions specific to running TensorFlow on TPUs.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | # "local" is a magic word in the TPU cluster resolver; it informs the resolver 22 | # to use the local CPU as the compute device. This is useful for testing and 23 | # debugging; the code flow is ostensibly identical, but without the need to 24 | # actually have a TPU on the other end. 25 | LOCAL = "local" 26 | 27 | 28 | def construct_scalar_host_call(metric_dict, model_dir, prefix=""): 29 | """Construct a host call to log scalars when training on TPU. 30 | 31 | Args: 32 | metric_dict: A dict of the tensors to be logged. 33 | model_dir: The location to write the summary. 34 | prefix: The prefix (if any) to prepend to the metric names. 35 | 36 | Returns: 37 | A tuple of (function, args_to_be_passed_to_said_function) 38 | """ 39 | # type: (dict, str) -> (function, list) 40 | metric_names = list(metric_dict.keys()) 41 | 42 | def host_call_fn(global_step, *args): 43 | """Training host call. Creates scalar summaries for training metrics. 44 | 45 | This function is executed on the CPU and should not directly reference 46 | any Tensors in the rest of the `model_fn`. To pass Tensors from the 47 | model to the `metric_fn`, provide as part of the `host_call`. See 48 | https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec 49 | for more information. 50 | 51 | Arguments should match the list of `Tensor` objects passed as the second 52 | element in the tuple passed to `host_call`. 53 | 54 | Args: 55 | global_step: `Tensor with shape `[batch]` for the global_step 56 | *args: Remaining tensors to log. 57 | 58 | Returns: 59 | List of summary ops to run on the CPU host. 
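      For example (illustrative): with metric_dict = {"learning_rate": lr}, this
      function is invoked as host_call_fn(global_step_tensor, lr_tensor), where
      both tensors have been reshaped to shape [1] below.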
60 | """ 61 | step = global_step[0] 62 | with tf.contrib.summary.create_file_writer( 63 | logdir=model_dir, filename_suffix=".host_call").as_default(): 64 | with tf.contrib.summary.always_record_summaries(): 65 | for i, name in enumerate(metric_names): 66 | tf.contrib.summary.scalar(prefix + name, args[i][0], step=step) 67 | 68 | return tf.contrib.summary.all_summary_ops() 69 | 70 | # To log the current learning rate, and gradient norm for Tensorboard, the 71 | # summary op needs to be run on the host CPU via host_call. host_call 72 | # expects [batch_size, ...] Tensors, thus reshape to introduce a batch 73 | # dimension. These Tensors are implicitly concatenated to 74 | # [params['batch_size']]. 75 | global_step_tensor = tf.reshape( 76 | tf.compat.v1.train.get_or_create_global_step(), [1]) 77 | other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names] 78 | 79 | return host_call_fn, [global_step_tensor] + other_tensors 80 | 81 | 82 | def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"): 83 | """Performs embedding lookup via a matmul. 84 | 85 | The matrix to be multiplied by the embedding table Tensor is constructed 86 | via an implementation of scatter based on broadcasting embedding indices 87 | and performing an equality comparison against a broadcasted 88 | range(num_embedding_table_rows). All masked positions will produce an 89 | embedding vector of zeros. 90 | 91 | Args: 92 | embedding_table: Tensor of embedding table. 93 | Rank 2 (table_size x embedding dim) 94 | values: Tensor of embedding indices. Rank 2 (batch x n_indices) 95 | mask: Tensor of mask / weights. Rank 2 (batch x n_indices) 96 | name: Optional name scope for created ops 97 | 98 | Returns: 99 | Rank 3 tensor of embedding vectors. 100 | """ 101 | 102 | with tf.name_scope(name): 103 | n_embeddings = embedding_table.get_shape().as_list()[0] 104 | batch_size, padded_size = values.shape.as_list() 105 | 106 | emb_idcs = tf.tile( 107 | tf.reshape(values, (batch_size, padded_size, 1)), (1, 1, n_embeddings)) 108 | emb_weights = tf.tile( 109 | tf.reshape(mask, (batch_size, padded_size, 1)), (1, 1, n_embeddings)) 110 | col_idcs = tf.tile( 111 | tf.reshape(tf.range(n_embeddings), (1, 1, n_embeddings)), 112 | (batch_size, padded_size, 1)) 113 | one_hot = tf.where( 114 | tf.equal(emb_idcs, col_idcs), emb_weights, 115 | tf.zeros((batch_size, padded_size, n_embeddings))) 116 | 117 | return tf.tensordot(one_hot, embedding_table, 1) 118 | -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | # 意图识别和填槽模型 2 | start_tm=`date +%s%N`; 3 | 4 | export HOST_NAME=$1 5 | if [[ "wzk" == "$HOST_NAME" ]] 6 | then 7 | # set gpu id to use 8 | export CUDA_VISIBLE_DEVICES=0 9 | else 10 | # not use gpu 11 | export CUDA_VISIBLE_DEVICES="" 12 | fi 13 | 14 | ### Optional parameters ### 15 | # To quickly test that model training works, set the number of epochs to a 16 | # smaller value (e.g. 0.01). 
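# Illustrative smoke-test overrides (example values only, not tuned settings):
#   export num_train_epochs=0.01
#   export NUM_TRAIN_EXAMPLES=1000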
17 | export DOMAIN_NAME="music" 18 | export input_format="nlu" 19 | export num_train_epochs=5 20 | export TRAIN_BATCH_SIZE=100 21 | export learning_rate=1e-4 22 | export warmup_proportion=0.1 23 | export max_seq_length=45 24 | export drop_keep_prob=0.9 25 | export MAX_INPUT_EXAMPLES=1000000 26 | export SAVE_CHECKPOINT_STEPS=1000 27 | export CORPUS_DIR="/home/${HOST_NAME}/Mywork/corpus/compe/69" 28 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/chinese_L-12_H-768_A-12" 29 | export CONFIG_FILE=configs/lasertagger_config.json 30 | export OUTPUT_DIR="${CORPUS_DIR}/${DOMAIN_NAME}_output" 31 | export MODEL_DIR="${OUTPUT_DIR}/${DOMAIN_NAME}_models" 32 | export do_lower_case=true 33 | export kernel_size=3 34 | export label_map_file=${OUTPUT_DIR}/label_map.json 35 | export slot_label_map_file=${OUTPUT_DIR}/slot_label_map.json 36 | export SUBMIT_FILE=${MODEL_DIR}/submit.csv 37 | export entity_type_list_file=${OUTPUT_DIR}/entity_type_list.json 38 | 39 | # Check these numbers from the "*.num_examples" files created in step 2. 40 | export NUM_TRAIN_EXAMPLES=300000 41 | export NUM_EVAL_EXAMPLES=5000 42 | 43 | python preprocess_main.py \ 44 | --input_file=${CORPUS_DIR}/train.txt \ 45 | --input_format=${input_format} \ 46 | --output_tfrecord_train=${OUTPUT_DIR}/train.tf_record \ 47 | --output_tfrecord_dev=${OUTPUT_DIR}/dev.tf_record \ 48 | --label_map_file=${label_map_file} \ 49 | --slot_label_map_file=${slot_label_map_file} \ 50 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 51 | --max_seq_length=${max_seq_length} \ 52 | --do_lower_case=${do_lower_case} \ 53 | --domain_name=${DOMAIN_NAME} \ 54 | --entity_type_list_file=${entity_type_list_file} 55 | 56 | 57 | 58 | echo "Train the model." 59 | python run_lasertagger.py \ 60 | --training_file=${OUTPUT_DIR}/train.tf_record \ 61 | --eval_file=${OUTPUT_DIR}/dev.tf_record \ 62 | --label_map_file=${label_map_file} \ 63 | --slot_label_map_file=${slot_label_map_file} \ 64 | --model_config_file=${CONFIG_FILE} \ 65 | --output_dir=${MODEL_DIR} \ 66 | --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \ 67 | --do_train=true \ 68 | --do_eval=true \ 69 | --num_train_epochs=${num_train_epochs} \ 70 | --train_batch_size=${TRAIN_BATCH_SIZE} \ 71 | --learning_rate=${learning_rate} \ 72 | --warmup_proportion=${warmup_proportion} \ 73 | --drop_keep_prob=${drop_keep_prob} \ 74 | --kernel_size=${kernel_size} \ 75 | --save_checkpoints_steps=${SAVE_CHECKPOINT_STEPS} \ 76 | --max_seq_length=${max_seq_length} \ 77 | --num_train_examples=${NUM_TRAIN_EXAMPLES} \ 78 | --num_eval_examples=${NUM_EVAL_EXAMPLES} \ 79 | --domain_name=${DOMAIN_NAME} \ 80 | --entity_type_list_file=${entity_type_list_file} 81 | 82 | 83 | ## 4. Prediction 84 | 85 | ### Export the model. 86 | echo "Export the model." 87 | python run_lasertagger.py \ 88 | --label_map_file=${label_map_file} \ 89 | --slot_label_map_file=${slot_label_map_file} \ 90 | --model_config_file=${CONFIG_FILE} \ 91 | --max_seq_length=${max_seq_length} \ 92 | --kernel_size=${kernel_size} \ 93 | --output_dir=${MODEL_DIR} \ 94 | --do_export=true \ 95 | --export_path="${MODEL_DIR}/export" \ 96 | --domain_name=${DOMAIN_NAME} \ 97 | --entity_type_list_file=${entity_type_list_file} 98 | 99 | ######### Get the most recently exported model directory. 
100 | TIMESTAMP=$(ls "${MODEL_DIR}/export/" | \ 101 | grep -v "temp-" | sort -r | head -1) 102 | SAVED_MODEL_DIR=${MODEL_DIR}/export/${TIMESTAMP} 103 | PREDICTION_FILE=${MODEL_DIR}/pred.tsv 104 | 105 | 106 | echo "predict_main.py for eval" 107 | python predict_main.py \ 108 | --input_file=${OUTPUT_DIR}/pred.tsv \ 109 | --input_format=${input_format} \ 110 | --output_file=${PREDICTION_FILE} \ 111 | --label_map_file=${label_map_file} \ 112 | --slot_label_map_file=${slot_label_map_file} \ 113 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 114 | --max_seq_length=${max_seq_length} \ 115 | --do_lower_case=${do_lower_case} \ 116 | --saved_model=${SAVED_MODEL_DIR} \ 117 | --domain_name=${DOMAIN_NAME} \ 118 | --entity_type_list_file=${entity_type_list_file} 119 | 120 | ####### 5. Evaluation 121 | echo "python score_main.py --prediction_file=" ${PREDICTION_FILE} 122 | python score_main.py --prediction_file=${PREDICTION_FILE} --vocab_file=${BERT_BASE_DIR}/vocab.txt --do_lower_case=true --domain_name=${DOMAIN_NAME} 123 | 124 | 125 | #echo "predict_main.py for test" 126 | #python predict_main.py \ 127 | # --input_file=${OUTPUT_DIR}/submit.csv \ 128 | # --input_format=${input_format} \ 129 | # --output_file=${PREDICTION_FILE} \ 130 | # --submit_file=${SUBMIT_FILE} \ 131 | # --label_map_file=${label_map_file} \ 132 | # --slot_label_map_file=${slot_label_map_file} \ 133 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 134 | # --max_seq_length=${max_seq_length} \ 135 | # --do_lower_case=${do_lower_case} \ 136 | # --saved_model=${SAVED_MODEL_DIR} \ 137 | # --domain_name=${DOMAIN_NAME} \ 138 | # --entity_type_list_file=${entity_type_list_file} 139 | 140 | end_tm=`date +%s%N`; 141 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'` 142 | echo "cost time" $use_tm "h" 143 | -------------------------------------------------------------------------------- /score_lib.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Utility functions for computing evaluation metrics.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | 22 | from __future__ import print_function 23 | 24 | import re 25 | from typing import List, Text, Tuple 26 | 27 | import sari_hook 28 | import utils 29 | 30 | import tensorflow as tf 31 | from curLine_file import normal_transformer, curLine 32 | 33 | def read_data( 34 | path, 35 | lowercase, 36 | target_domain_name): 37 | """Reads data from prediction TSV file. 38 | 39 | The prediction file should contain 3 or more columns: 40 | 1: sources (concatenated) 41 | 2: prediction 42 | 3-n: targets (1 or more) 43 | 44 | Args: 45 | path: Path to the prediction file. 46 | lowercase: Whether to lowercase the data (to compute case insensitive 47 | scores). 
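      target_domain_name: Name of the domain being evaluated; rows whose
        predicted and gold domains both equal this value have their slot
        mismatches printed for inspection.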
48 | 49 | Returns: 50 | Tuple (list of sources, list of predictions, list of target lists) 51 | """ 52 | sources = [] 53 | predDomain_list = [] 54 | predIntent_list = [] 55 | domain_list = [] 56 | right_intent_num = 0 57 | right_slot_num = 0 58 | exact_num = 0 59 | with tf.gfile.GFile(path) as f: 60 | for lineId, line in enumerate(f): 61 | if "sessionId" in line and "pred" in line: 62 | continue 63 | sessionId, query, predDomain,predIntent,predSlot, domain,intent,Slot = line.rstrip('\n').split('\t') 64 | # if target_domain_name == predDomain and target_domain_name != domain: # domain 错了 65 | # print(curLine(), lineId, predSlot, "Slot:", Slot, "predDomain:%s, domain:%s" % (predDomain, domain), "predIntent:", 66 | # predIntent) 67 | if predIntent == intent: 68 | right_intent_num += 1 69 | if predSlot == Slot: 70 | exact_num += 1 71 | 72 | if predSlot == Slot: 73 | right_slot_num += 1 74 | else: 75 | if target_domain_name == predDomain and target_domain_name == domain: # 76 | print(curLine(), predSlot, "Slot:", Slot, "predDomain:%s, domain:%s" % (predDomain, domain), "predIntent:", predIntent) 77 | predDomain_list.append(predDomain) 78 | predIntent_list.append(predIntent) 79 | domain_list.append(domain) 80 | return predDomain_list, predIntent_list, domain_list, right_intent_num, right_slot_num, exact_num 81 | 82 | 83 | def compute_exact_score(predictions, domain_list): 84 | """Computes the Exact score (accuracy) of the predictions. 85 | 86 | Exact score is defined as the percentage of predictions that match at least 87 | one of the targets. 88 | 89 | Args: 90 | predictions: List of predictions. 91 | target_lists: List of targets (1 or more per prediction). 92 | 93 | Returns: 94 | Exact score between [0, 1]. 95 | """ 96 | 97 | correct_domain = 0 98 | for p, d in zip(predictions, domain_list): 99 | if p==d: 100 | correct_domain += 1 101 | return correct_domain / max(len(predictions), 0.1) # Avoids 0/0. 102 | 103 | 104 | def compute_sari_scores( 105 | sources, 106 | predictions, 107 | target_lists, 108 | ignore_wikisplit_separators = True, 109 | tokenizer=None): 110 | """Computes SARI scores. 111 | 112 | Wraps the t2t implementation of SARI computation. 113 | 114 | Args: 115 | sources: List of sources. 116 | predictions: List of predictions. 117 | target_lists: List of targets (1 or more per prediction). 118 | ignore_wikisplit_separators: Whether to ignore "<::::>" tokens, used as 119 | sentence separators in Wikisplit, when evaluating. For the numbers 120 | reported in the paper, we accidentally ignored those tokens. Ignoring them 121 | does not affect the Exact score (since there's usually always a period 122 | before the separator to indicate sentence break), but it decreases the 123 | SARI score (since the Addition score goes down as the model doesn't get 124 | points for correctly adding <::::> anymore). 125 | 126 | Returns: 127 | Tuple (SARI score, keep score, addition score, deletion score). 
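    (This implementation additionally appends the average and the maximum
    predicted token length, so the returned tuple actually has six elements.)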
128 | """ 129 | sari_sum = 0 130 | keep_sum = 0 131 | add_sum = 0 132 | del_sum = 0 133 | length_sum = 0 134 | length_max = 0 135 | for source, pred, targets in zip(sources, predictions, target_lists): 136 | if ignore_wikisplit_separators: 137 | source = re.sub(' <::::> ', ' ', source) 138 | pred = re.sub(' <::::> ', ' ', pred) 139 | targets = [re.sub(' <::::> ', ' ', t) for t in targets] 140 | source_ids = tokenizer.tokenize(source) # utils.get_token_list(source) 141 | pred_ids = tokenizer.tokenize(pred) # utils.get_token_list(pred) 142 | list_of_targets = [tokenizer.tokenize(t) for t in targets] 143 | length_sum += len(pred_ids) 144 | length_max = max(length_max, len(pred_ids)) 145 | sari, keep, addition, deletion = sari_hook.get_sari_score( 146 | source_ids, pred_ids, list_of_targets, beta_for_deletion=1) 147 | sari_sum += sari 148 | keep_sum += keep 149 | add_sum += addition 150 | del_sum += deletion 151 | n = max(len(sources), 0.1) # Avoids 0/0. 152 | return (sari_sum / n, keep_sum / n, add_sum / n, del_sum / n, length_sum/n, length_max) -------------------------------------------------------------------------------- /find_entity/exacter_acmation.py: -------------------------------------------------------------------------------- 1 | # 发现疑似实体,辅助预测 2 | import os 3 | from collections import defaultdict 4 | import re 5 | import json 6 | from .acmation import KeywordTree, add_to_ac, entity_files_folder, entity_folder 7 | from curLine_file import curLine, normal_transformer 8 | 9 | re_phoneNum = re.compile("[0-9一二三四五六七八九十拾]+") # 编译 10 | 11 | # AC自动机, similar to trie tree 12 | # 也许直接读取下载的xls文件更方便,但那样需要安装xlrd模块 13 | 14 | domain2entity_map = {} 15 | domain2entity_map["music"] = ["age", "singer", "song", "toplist", "theme", "style", "scene", "language", "emotion", "instrument"] 16 | # domain2entity_map["navigation"] = ["custom_destination", "city"] # "destination", "origin"] 17 | # domain2entity_map["phone_call"] = ["phone_num", "contact_name"] 18 | self_entity_trie_tree = {} # 总的实体字典 自己建立的某些实体类型的实体树 19 | for domain, entity_type_list in domain2entity_map.items(): 20 | print(curLine(), domain, entity_type_list) 21 | for entity_type in entity_type_list: 22 | if entity_type not in self_entity_trie_tree: 23 | ac = KeywordTree(case_insensitive=True) 24 | else: 25 | ac = self_entity_trie_tree[entity_type] 26 | 27 | ### 从标注语料中挖掘得到 28 | entity_file = os.path.join(entity_files_folder, "%s.json" % entity_type) 29 | with open(entity_file, "r") as fr: 30 | current_entity_dict = json.load(fr) 31 | print(curLine(), "get %d %s from %s" % (len(current_entity_dict), entity_type, entity_file)) 32 | for entity_before, entity_after_times in current_entity_dict.items(): 33 | entity_after = entity_after_times[0] 34 | pri = 2 35 | if entity_type in ["song"]: 36 | pri -= 0.5 37 | add_to_ac(ac, entity_type, entity_before, entity_after, pri=pri) 38 | if entity_type == "song": 39 | add_to_ac(ac, entity_type, "花的画", "花的话", 1.5) 40 | 41 | # 给的实体库,最高优先级 42 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type) 43 | if os.path.exists(entity_file): 44 | with open(entity_file, "r") as fr: 45 | lines = fr.readlines() 46 | print(curLine(), "get %d %s from %s" % (len(lines), entity_type, entity_file)) 47 | for line in lines: 48 | entity_after = line.strip() 49 | entity_before = entity_after # TODO 50 | pri = 3 51 | if entity_type in ["song"]: 52 | pri -= 0.5 53 | add_to_ac(ac, entity_type, entity_before, entity_after, pri=pri) 54 | ac.finalize() 55 | self_entity_trie_tree[entity_type] = ac 56 | 57 | 58 | def 
get_all_entity(corpus, useEntityTypeList): 59 | self_entityTypeMap = defaultdict(list) 60 | for entity_type in useEntityTypeList: 61 | result = self_entity_trie_tree[entity_type].search(corpus) 62 | for res in result: 63 | after, priority = res.meta_data 64 | self_entityTypeMap[entity_type].append({'before': res.keyword, 'after': after, "priority":priority}) 65 | if "phone_num" in useEntityTypeList: 66 | token_numbers = re_phoneNum.findall(corpus) 67 | for number in token_numbers: 68 | self_entityTypeMap["phone_num"].append({'before':number, 'after':number, 'priority': 2}) 69 | return self_entityTypeMap 70 | 71 | 72 | def get_slot_info(query, domain): 73 | useEntityTypeList = domain2entity_map[domain] 74 | entityTypeMap = get_all_entity(query, useEntityTypeList=useEntityTypeList) 75 | entity_list_all = [] # 汇总所有实体 76 | for entity_type, entity_list in entityTypeMap.items(): 77 | for entity in entity_list: 78 | entity_before = entity['before'] 79 | ignore_flag = False 80 | if entity_type != "song" and len(entity_before) < 2 and entity_before not in ["家","妈"]: 81 | ignore_flag = True 82 | if entity_type == "song" and len(entity_before) < 2 and \ 83 | entity_before not in {"鱼", "云", "逃", "退", "陶", "美", "图", "默"}: 84 | ignore_flag = True 85 | if entity_before in {"什么歌", "一首", "小花","叮当","傻逼", "给你听", "现在", "当我"}: 86 | ignore_flag = True 87 | if ignore_flag: 88 | if entity_before not in "好点没走伤": 89 | print(curLine(), "ignore entity_type:%s, entity:%s, query:%s" 90 | % (entity_type, entity_before, query)) 91 | else: 92 | entity_list_all.append((entity_type, entity_before, entity['after'], entity['priority'])) 93 | entity_list_all = sorted(entity_list_all, key=lambda item: len(item[1])*100+item[3], 94 | reverse=True) # new_entity_map 中key是实体,value是实体类型 95 | slot_info = query 96 | exist_entityType_set = set() 97 | replace_mask = [0] * len(query) 98 | for entity_type, entity_before, entity_after, priority in entity_list_all: 99 | if entity_before not in query: 100 | continue 101 | if entity_type in exist_entityType_set: 102 | continue # 已经有这个类型了,忽略 # TODO 103 | start_location = slot_info.find(entity_before) 104 | if start_location > -1: 105 | exist_entityType_set.add(entity_type) 106 | if entity_after == entity_before: 107 | entity_info_str = "<%s>%s" % (entity_type, entity_after, entity_type) 108 | else: 109 | entity_info_str = "<%s>%s||%s" % (entity_type, entity_before, entity_after, entity_type) 110 | slot_info = slot_info.replace(entity_before, entity_info_str) 111 | query = query.replace(entity_before, "") 112 | else: 113 | print(curLine(), replace_mask, slot_info, "entity_type:", entity_type, entity_before) 114 | return slot_info 115 | 116 | if __name__ == '__main__': 117 | for query in ["拨打10086", "打电话给100十五", "打电话给一二三拾"]: 118 | res = get_slot_info(query, domain="phone_call") 119 | print(curLine(), query, res) 120 | 121 | for query in ["节奏来一首一朵鲜花送给心上人", "陈瑞的歌曲", "放一首销愁"]: 122 | res = get_slot_info(query, domain="music") 123 | print(curLine(), query, res) -------------------------------------------------------------------------------- /official_transformer/attention_layer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Implementation of multiheaded attention and self-attention layers.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | class Attention(tf.layers.Layer): 26 | """Multi-headed attention layer.""" 27 | 28 | def __init__(self, hidden_size, num_heads, attention_dropout, train): 29 | if hidden_size % num_heads != 0: 30 | raise ValueError("Hidden size must be evenly divisible by the number of " 31 | "heads.") 32 | 33 | super(Attention, self).__init__() 34 | self.hidden_size = hidden_size 35 | self.num_heads = num_heads 36 | self.attention_dropout = attention_dropout 37 | self.train = train 38 | 39 | # Layers for linearly projecting the queries, keys, and values. 40 | self.q_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="q") 41 | self.k_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="k") 42 | self.v_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="v") 43 | 44 | self.output_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, 45 | name="output_transform") 46 | 47 | def split_heads(self, x): 48 | """Split x into different heads, and transpose the resulting value. 49 | 50 | The tensor is transposed to insure the inner dimensions hold the correct 51 | values during the matrix multiplication. 52 | 53 | Args: 54 | x: A tensor with shape [batch_size, length, hidden_size] 55 | 56 | Returns: 57 | A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads] 58 | """ 59 | with tf.name_scope("split_heads"): 60 | batch_size = tf.shape(x)[0] 61 | length = tf.shape(x)[1] 62 | 63 | # Calculate depth of last dimension after it has been split. 64 | depth = (self.hidden_size // self.num_heads) 65 | 66 | # Split the last dimension 67 | x = tf.reshape(x, [batch_size, length, self.num_heads, depth]) 68 | 69 | # Transpose the result 70 | return tf.transpose(x, [0, 2, 1, 3]) 71 | 72 | def combine_heads(self, x): 73 | """Combine tensor that has been split. 74 | 75 | Args: 76 | x: A tensor [batch_size, num_heads, length, hidden_size/num_heads] 77 | 78 | Returns: 79 | A tensor with shape [batch_size, length, hidden_size] 80 | """ 81 | with tf.name_scope("combine_heads"): 82 | batch_size = tf.shape(x)[0] 83 | length = tf.shape(x)[2] 84 | x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth] 85 | return tf.reshape(x, [batch_size, length, self.hidden_size]) 86 | 87 | def call(self, x, y, bias, cache=None): 88 | """Apply attention mechanism to x and y. 89 | 90 | Args: 91 | x: a tensor with shape [batch_size, length_x, hidden_size] 92 | y: a tensor with shape [batch_size, length_y, hidden_size] 93 | bias: attention bias that will be added to the result of the dot product. 94 | cache: (Used during prediction) dictionary with tensors containing results 95 | of previous attentions. 
The dictionary must have the items: 96 | {"k": tensor with shape [batch_size, i, key_channels], 97 | "v": tensor with shape [batch_size, i, value_channels]} 98 | where i is the current decoded length. 99 | 100 | Returns: 101 | Attention layer output with shape [batch_size, length_x, hidden_size] 102 | """ 103 | # Linearly project the query (q), key (k) and value (v) using different 104 | # learned projections. This is in preparation of splitting them into 105 | # multiple heads. Multi-head attention uses multiple queries, keys, and 106 | # values rather than regular attention (which uses a single q, k, v). 107 | q = self.q_dense_layer(x) 108 | k = self.k_dense_layer(y) 109 | v = self.v_dense_layer(y) 110 | 111 | if cache is not None: 112 | # Combine cached keys and values with new keys and values. 113 | k = tf.concat([cache["k"], k], axis=1) 114 | v = tf.concat([cache["v"], v], axis=1) 115 | 116 | # Update cache 117 | cache["k"] = k 118 | cache["v"] = v 119 | 120 | # Split q, k, v into heads. 121 | q = self.split_heads(q) 122 | k = self.split_heads(k) 123 | v = self.split_heads(v) 124 | 125 | # Scale q to prevent the dot product between q and k from growing too large. 126 | depth = (self.hidden_size // self.num_heads) 127 | q *= depth ** -0.5 128 | 129 | # Calculate dot product attention 130 | logits = tf.matmul(q, k, transpose_b=True) 131 | logits += bias 132 | weights = tf.nn.softmax(logits, name="attention_weights") 133 | if self.train: 134 | weights = tf.nn.dropout(weights, 1.0 - self.attention_dropout) 135 | attention_output = tf.matmul(weights, v) 136 | 137 | # Recombine heads --> [batch_size, length, hidden_size] 138 | attention_output = self.combine_heads(attention_output) 139 | 140 | # Run the combined outputs through another linear projection layer. 141 | attention_output = self.output_dense_layer(attention_output) 142 | return attention_output 143 | 144 | 145 | class SelfAttention(Attention): 146 | """Multiheaded self-attention layer.""" 147 | 148 | def call(self, x, bias, cache=None): 149 | return super(SelfAttention, self).call(x, x, bias, cache) 150 | -------------------------------------------------------------------------------- /preprocess_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Convert a dataset into the TFRecord format. 18 | 19 | The resulting TFRecord file will be used when training a LaserTagger model. 
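Example invocation (illustrative; paths are placeholders, see start.sh for the
flag values actually used):

  python preprocess_main.py \
    --input_file=train.txt \
    --input_format=nlu \
    --output_tfrecord_train=train.tf_record \
    --output_tfrecord_dev=dev.tf_record \
    --label_map_file=label_map.json \
    --slot_label_map_file=slot_label_map.json \
    --vocab_file=vocab.txt \
    --max_seq_length=45 \
    --do_lower_case=true \
    --domain_name=music \
    --entity_type_list_file=entity_type_list.json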
20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | 25 | from __future__ import print_function 26 | 27 | from typing import Text 28 | 29 | from absl import app 30 | from absl import flags 31 | from absl import logging 32 | from bert import tokenization 33 | import bert_example 34 | import json 35 | import utils 36 | 37 | import tensorflow as tf 38 | from find_entity.probable_acmation import domain2entity_map, get_all_entity 39 | from curLine_file import curLine 40 | 41 | 42 | FLAGS = flags.FLAGS 43 | 44 | flags.DEFINE_string( 45 | 'input_file', None, 46 | 'Path to the input file containing examples to be converted to ' 47 | 'tf.Examples.') 48 | flags.DEFINE_enum( 49 | 'input_format', None, ['nlu'], 50 | 'Format which indicates how to parse the input_file.') 51 | flags.DEFINE_string('output_tfrecord_train', None, 52 | 'Path to the resulting TFRecord file.') 53 | flags.DEFINE_string('output_tfrecord_dev', None, 54 | 'Path to the resulting TFRecord file.') 55 | flags.DEFINE_string( 56 | 'label_map_file', None, 57 | 'Path to the label map file. Either a JSON file ending with ".json", that ' 58 | 'maps each possible tag to an ID, or a text file that has one tag per line.') 59 | flags.DEFINE_string( 60 | 'slot_label_map_file', None, 61 | 'Path to the label map file. Either a JSON file ending with ".json", that ' 62 | 'maps each possible tag to an ID, or a text file that has one tag per line.') 63 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.') 64 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.') 65 | flags.DEFINE_bool( 66 | 'do_lower_case', False, 67 | 'Whether to lower case the input text. Should be True for uncased ' 68 | 'models and False for cased models.') 69 | flags.DEFINE_string( 70 | 'domain_name', None, 'domain_name') 71 | flags.DEFINE_string( 72 | 'entity_type_list_file', None, 'path of entity_type_list_file') 73 | 74 | def _write_example_count(count: int, output_file: str) -> Text: 75 | """Saves the number of converted examples to a file. 76 | 77 | This count is used when determining the number of training steps. 78 | 79 | Args: 80 | count: The number of converted examples. 81 | 82 | Returns: 83 | The filename to which the count is saved. 
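    For example (illustrative), output_file='train.tf_record' yields the file
    'train.tf_record.num_examples.txt'.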
84 | """ 85 | count_fname = output_file + '.num_examples.txt' 86 | with tf.gfile.GFile(count_fname, 'w') as count_writer: 87 | count_writer.write(str(count)) 88 | return count_fname 89 | 90 | 91 | 92 | def main(argv): 93 | if len(argv) > 1: 94 | raise app.UsageError('Too many command-line arguments.') 95 | flags.mark_flag_as_required('input_file') 96 | flags.mark_flag_as_required('input_format') 97 | flags.mark_flag_as_required('output_tfrecord_train') 98 | flags.mark_flag_as_required('output_tfrecord_dev') 99 | flags.mark_flag_as_required('vocab_file') 100 | target_domain_name = FLAGS.domain_name 101 | entity_type_list = domain2entity_map[target_domain_name] 102 | 103 | print(curLine(), "target_domain_name:", target_domain_name, len(entity_type_list), "entity_type_list:", entity_type_list) 104 | builder = bert_example.BertExampleBuilder({}, FLAGS.vocab_file, 105 | FLAGS.max_seq_length, 106 | FLAGS.do_lower_case, slot_label_map={}, 107 | entity_type_list=entity_type_list, get_entity_func=get_all_entity) 108 | 109 | num_converted = 0 110 | num_ignored = 0 111 | # ff = open("new_%s.txt" % target_domain_name, "w") # TODO 112 | with tf.python_io.TFRecordWriter(FLAGS.output_tfrecord_train) as writer_train: 113 | for i, (sources, target) in enumerate(utils.yield_sources_and_targets( 114 | FLAGS.input_file, FLAGS.input_format, target_domain_name)): 115 | logging.log_every_n( 116 | logging.INFO, 117 | f'{i} examples processed, {num_converted} converted to tf.Example.', 118 | 10000) 119 | if len(sources[0]) > 35: # 忽略问题太长的样本 120 | num_ignored += 1 121 | print(curLine(), "ignore num_ignored=%d, question length=%d" % (num_ignored, len(sources[0]))) 122 | continue 123 | example1, _, info_str = builder.build_bert_example(sources, target) 124 | example =example1.to_tf_example().SerializeToString() 125 | writer_train.write(example) 126 | num_converted += 1 127 | # ff.write("%d %s\n" % (i, info_str)) 128 | logging.info(f'Done. 
{num_converted} examples converted to tf.Example, num_ignored {num_ignored} examples.') 129 | for output_file in [FLAGS.output_tfrecord_train]: 130 | count_fname = _write_example_count(num_converted, output_file=output_file) 131 | logging.info(f'Wrote:\n{output_file}\n{count_fname}') 132 | with open(FLAGS.label_map_file, "w") as f: 133 | json.dump(builder._label_map, f, ensure_ascii=False, indent=4) 134 | print(curLine(), "save %d to %s" % (len(builder._label_map), FLAGS.label_map_file)) 135 | with open(FLAGS.slot_label_map_file, "w") as f: 136 | json.dump(builder.slot_label_map, f, ensure_ascii=False, indent=4) 137 | print(curLine(), "save %d to %s" % (len(builder.slot_label_map), FLAGS.slot_label_map_file)) 138 | 139 | with open(FLAGS.entity_type_list_file, "w") as f: 140 | json.dump(domain2entity_map, f, ensure_ascii=False, indent=4) 141 | print(curLine(), "save %d to %s" % (len(domain2entity_map), FLAGS.entity_type_list_file)) 142 | 143 | 144 | if __name__ == '__main__': 145 | app.run(main) -------------------------------------------------------------------------------- /confusion_words.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 混淆词提取算法 不考虑多音字,只处理同音字的情况 3 | from pypinyin import lazy_pinyin, pinyin, Style 4 | # pinyin 的方法默认带声调,而 lazy_pinyin 方法不带声调,但是会得到不常见的多音字 5 | # pinyin(c, heteronym=True, style=0) 不考虑音调情况下的多音字 6 | from Levenshtein import distance # python-Levenshtein 7 | # 编辑距离 (Levenshtein Distance算法) 8 | import os 9 | from curLine_file import curLine 10 | from find_entity.acmation import entity_folder 11 | 12 | number_map = {"0":"零", "1":"一", "2":"二", "3":"三", "4":"四", "5":"五", "6":"六", "7":"七", "8":"八", "9":"九"} 13 | cons_table = {'n': 'l', 'l': 'n', 'f': 'h', 'h': 'f', 'zh': 'z', 'z': 'zh', 'c': 'ch', 'ch': 'c', 's': 'sh', 'sh': 's'} 14 | # def get_similar_pinyin(entity_type): # 存的时候直接把拼音的变体也保存下来 15 | # entity_info_dict = {} 16 | # entity_file = os.path.join(entity_folder, "%s.txt" % entity_type) 17 | # with open(entity_file, "r") as fr: 18 | # lines = fr.readlines() 19 | # pri = 3 20 | # if entity_type in ["song"]: 21 | # pri -= 0.5 22 | # print(curLine(), "get %d %s from %s, pri=%f" % (len(lines), entity_type, entity_file, pri)) 23 | # for line in lines: 24 | # entity = line.strip() 25 | # for k,v in number_map.items(): 26 | # entity.replace(k, v) 27 | # # for combination in all_combination: 28 | # if entity not in entity_info_dict: # 新的实体 29 | # combination = "".join(lazy_pinyin(entity)) # default:默认行为,不处理,原木原样返回 , errors="ignore" 30 | # if len(combination) < 2: 31 | # print(curLine(), "warning:", entity, "combination:", combination) 32 | # entity_info_dict[entity] = (combination, pri) 33 | 34 | def get_entityType_pinyin(entity_type): 35 | entity_info_dict = {} 36 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type) 37 | with open(entity_file, "r") as fr: 38 | lines = fr.readlines() 39 | pri = 3 40 | if entity_type in ["song"]: 41 | pri -= 0.5 42 | print(curLine(), "get %d %s from %s, pri=%f" % (len(lines), entity_type, entity_file, pri)) 43 | for line in lines: 44 | entity = line.strip() 45 | for k,v in number_map.items(): 46 | entity.replace(k, v) 47 | # for combination in all_combination: 48 | if entity not in entity_info_dict: # 新的实体 49 | combination = "".join(lazy_pinyin(entity)) # default:默认行为,不处理,原木原样返回 , errors="ignore" 50 | if len(combination) < 2: 51 | print(curLine(), "warning:", entity, "combination:", combination) 52 | entity_info_dict[entity] = (combination, pri) 53 | else: 
54 | combination, old_pri = entity_info_dict[entity] 55 | if pri > old_pri: 56 | entity_info_dict[entity] = (combination, pri) 57 | return entity_info_dict 58 | 59 | # 用编辑距离度量拼音字符串之间的相似度 不考虑相似实体的多音情况 60 | def pinyin_similar_word_noduoyin(entity_info_dict, word): 61 | if word in entity_info_dict: # 存在实体,无需纠错 62 | return 1.0, word 63 | best_similar_word = None 64 | top_similar_score = 0 65 | try: 66 | all_combination = ["".join(lazy_pinyin(word))] # get_pinyin_combination(entity=word) # 67 | for current_combination in all_combination: # 当前的各种发音 68 | if len(current_combination) == 0: 69 | print(curLine(), "word:", word) 70 | continue 71 | similar_word = None 72 | current_distance = 10000 73 | for entity,(com, pri) in entity_info_dict.items(): 74 | char_ratio = 0.0 75 | d = distance(com, current_combination)*(1.0-char_ratio) + distance(entity, word) * char_ratio 76 | if d < current_distance: 77 | current_distance = d 78 | similar_word = entity 79 | # if d<=2.5: 80 | # print(curLine(),com, current_combination, distance(com, current_combination), distance(entity, word) ) 81 | # print(curLine(), word, entity, similar_word, "current_distance=", current_distance) 82 | 83 | 84 | current_similar_score = 1.0 - float(current_distance) / len(current_combination) 85 | # print(curLine(), "current_combination:%s, %f" % (current_combination, current_similar_score), similar_word, current_distance) 86 | if current_similar_score > top_similar_score: 87 | # print(curLine(), current_similar_score, top_similar_score, best_similar_word, similar_word) 88 | best_similar_word = similar_word 89 | top_similar_score = current_similar_score 90 | except Exception as error: 91 | print(curLine(), "error:", error) 92 | return top_similar_score, best_similar_word 93 | 94 | # 自己包装的函数,返回字符的声母(可能为空,如啊呀),韵母,整体拼音 95 | def my_pinyin(char): 96 | shengmu = pinyin(char, style=Style.INITIALS, strict=True)[0][0] 97 | yunmu = pinyin(char, style=Style.FINALS, strict=True)[0][0] 98 | total_pinyin = lazy_pinyin(char, errors='default')[0] 99 | if shengmu + yunmu != total_pinyin: 100 | print(curLine(), "char:", char, ",shengmu:%s, yunmu:%s" % (shengmu, yunmu), total_pinyin) 101 | return shengmu, yunmu, total_pinyin 102 | 103 | singer_pinyin = get_entityType_pinyin(entity_type="singer") 104 | print(curLine(), len(singer_pinyin), "singer_pinyin") 105 | 106 | song_pinyin = get_entityType_pinyin(entity_type="song") 107 | print(curLine(), len(song_pinyin), "song_pinyin") 108 | 109 | 110 | if __name__ == "__main__": 111 | # pinyin('中心', heteronym=True) # 启用多音字模式 112 | 113 | 114 | import confusion_words_duoyin 115 | res = pinyin_similar_word_noduoyin(song_pinyin, "体验") # confusion_words_duoyin.song_pinyin, "体验") 116 | print(curLine(), res) 117 | 118 | s = "啊饿恩昂你为什*.\"123c单身" 119 | s = "啊饿恩昂你为什单身的呢?" 
120 | yunmu_list = pinyin(s, style=Style.FINALS) 121 | shengmu_list = pinyin(s, style=Style.INITIALS, strict=False) # 返回声母,但是严格按照标准(strict=True)有些字没有声母会返回空字符串,strict=False时会返回 122 | pinyin_list = lazy_pinyin(s, errors='default') # ""ignore") 123 | print(curLine(), len(shengmu_list), "shengmu_list:", shengmu_list) 124 | print(curLine(), len(yunmu_list), "yunmu_list:", yunmu_list) 125 | print(curLine(), len(pinyin_list), "pinyin_list:", pinyin_list) 126 | for id, c in enumerate(s): 127 | print(curLine(), id, c, my_pinyin(c)) 128 | # print(curLine(), c, pinyin(c, style=Style.INITIALS), s) 129 | # print(curLine(), c,c_pinyin) 130 | # print(curLine(), pinyin(s, heteronym=True, style=0)) 131 | # # print(pinyin(c, heteronym=True, style=0)) 132 | 133 | # all_combination = get_pinyin_combination(entity=s) 134 | # 135 | # for index, combination in enumerate(all_combination): 136 | # print(curLine(), index, combination) 137 | # for s in ["abc", "abs1", "fabc"]: 138 | # ed = distance("abs", s) 139 | # print(s, ed) 140 | for predict_singer in ["周杰伦", "前任", "朱姐", "奇隆"]: 141 | top_similar_score, best_similar_word = pinyin_similar_word_noduoyin(singer_pinyin, predict_singer) 142 | print(curLine(), predict_singer, top_similar_score, best_similar_word) 143 | -------------------------------------------------------------------------------- /find_entity/acmation.py: -------------------------------------------------------------------------------- 1 | import json, os 2 | import re 3 | from curLine_file import curLine, normal_transformer 4 | 5 | re_phoneNum = re.compile("[0-9一二三四五六七八九十拾]+") # 编译 6 | 7 | entity_files_folder = "./entity_files" 8 | entity_folder = os.path.join(entity_files_folder, "slot-dictionaries") 9 | frequentSong = {} 10 | frequentSinger = {} 11 | with open(os.path.join(entity_files_folder,"frequentSong.json"), "r") as f: 12 | frequentSong = json.load(f) 13 | 14 | with open(os.path.join(entity_files_folder,"frequentSinger.json"), "r") as f: 15 | frequentSinger = json.load(f) 16 | 17 | # AC自动机, similar to trie tree 18 | class State(object): 19 | __slots__ = ['identifier', 'symbol', 'success', 'transitions', 'parent', 20 | 'matched_keyword', 'longest_strict_suffix', 'meta_data'] 21 | 22 | def __init__(self, identifier, symbol=None, parent=None, success=False): 23 | self.symbol = symbol 24 | self.identifier = identifier 25 | self.transitions = {} 26 | self.parent = parent 27 | self.success = success 28 | self.matched_keyword = None 29 | self.longest_strict_suffix = None 30 | 31 | 32 | class Result(object): 33 | __slots__ = ['keyword', 'location', 'meta_data'] 34 | 35 | def __init__(self, **kwargs): 36 | for k, v in kwargs.items(): 37 | setattr(self, k, v) 38 | 39 | def __str__(self): 40 | return_str = '' 41 | for k in self.__slots__: 42 | return_str += '{}:{:<20}\t'.format(k, json.dumps(getattr(self, k))) 43 | return return_str 44 | 45 | class KeywordTree(object): 46 | def __init__(self, case_insensitive=True): 47 | ''' 48 | @param case_insensitive: If true, case will be ignored when searching. 49 | Setting this to true will have a positive 50 | impact on performance. 51 | Defaults to false. 52 | ''' 53 | self._zero_state = State(0) 54 | self._counter = 1 55 | self._finalized = False 56 | self._case_insensitive = case_insensitive 57 | 58 | def add(self, keywords, meta_data=None): 59 | ''' 60 | Add a keyword to the tree. 61 | Can only be used before finalize() has been called. 62 | Keyword should be str or unicode. 
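        Example (illustrative, mirroring how add_to_ac below calls this method):
            kt.add(keywords="周杰伦", meta_data=("周杰伦", 3))
        i.e. meta_data carries the (normalized form, priority) pair that
        search() later returns via Result.meta_data.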
63 | ''' 64 | if self._finalized: 65 | raise ValueError('KeywordTree has been finalized.' + 66 | ' No more keyword additions allowed') 67 | original_keyword = keywords 68 | if self._case_insensitive: 69 | if isinstance(keywords, list): 70 | keywords = map(str.lower, keywords) 71 | elif isinstance(keywords, str): 72 | keywords = keywords.lower() 73 | # if keywords != original_keyword: 74 | # print(curLine(), keywords, original_keyword) 75 | # input(curLine()) 76 | else: 77 | raise Exception('keywords error') 78 | if len(keywords) <= 0: 79 | return 80 | current_state = self._zero_state 81 | for word in keywords: 82 | try: 83 | current_state = current_state.transitions[word] 84 | except KeyError: 85 | next_state = State(self._counter, parent=current_state, 86 | symbol=word) 87 | self._counter += 1 88 | current_state.transitions[word] = next_state 89 | current_state = next_state 90 | current_state.success = True 91 | current_state.matched_keyword = original_keyword 92 | current_state.meta_data = meta_data 93 | 94 | def search(self, text, greedy=False, cut_word=False, cut_separator=' '): 95 | ''' 96 | 97 | :param text: 98 | :param greedy: 99 | :param cut_word: 100 | :param cut_separator: 101 | :return: 102 | ''' 103 | gen = self._search(text, cut_word, cut_separator) 104 | pre = None 105 | for result in gen: 106 | assert isinstance(result, Result) 107 | if not greedy: 108 | yield result 109 | continue 110 | if pre is None: 111 | pre = result 112 | 113 | if result.location > pre.location: 114 | yield pre 115 | pre = result 116 | continue 117 | 118 | if len(result.keyword) > len(pre.keyword): 119 | pre = result 120 | continue 121 | if pre is not None: 122 | yield pre 123 | 124 | def _search(self, text, cut_word=False, cut_separator=' '): 125 | ''' 126 | Search a text for all occurences of the added keywords. 127 | Can only be called after finalized() has been called. 128 | O(n) with n = len(text) 129 | @return: Generator used to iterate over the results. 130 | Or None if no keyword was found in the text. 131 | ''' 132 | # if not self._finalized: 133 | # raise ValueError('KeywordTree has not been finalized.' + 134 | # ' No search allowed. Call finalize() first.') 135 | if self._case_insensitive: 136 | if isinstance(text, list): 137 | text = map(str.lower, text) 138 | elif isinstance(text, str): 139 | text = text.lower() 140 | else: 141 | raise Exception('context type error') 142 | 143 | if cut_word: 144 | if isinstance(text, str): 145 | text = text.split(cut_separator) 146 | 147 | current_state = self._zero_state 148 | for idx, symbol in enumerate(text): 149 | current_state = current_state.transitions.get( 150 | symbol, self._zero_state.transitions.get(symbol, self._zero_state)) 151 | state = current_state 152 | while state != self._zero_state: 153 | if state.success: 154 | keyword = state.matched_keyword 155 | yield Result(**{ 156 | 'keyword': keyword, 157 | 'location': idx - len(keyword) + 1, 158 | 'meta_data': state.meta_data 159 | }) 160 | # yield (keyword, idx - len(keyword) + 1, state.meta_data) 161 | state = state.longest_strict_suffix 162 | 163 | def finalize(self): 164 | ''' 165 | Needs to be called after all keywords have been added and 166 | before any searching is performed. 
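        Illustrative lifecycle (a sketch of how exacter_acmation.py uses this
        class; the query string is a made-up example):
            kt = KeywordTree(case_insensitive=True)
            kt.add("周杰伦", meta_data=("周杰伦", 3))
            kt.finalize()
            for res in kt.search("来一首周杰伦的歌"):
                after, priority = res.meta_data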
167 | ''' 168 | if self._finalized: 169 | raise ValueError('KeywordTree has already been finalized.') 170 | self._zero_state.longest_strict_suffix = self._zero_state 171 | processed = set() 172 | to_process = [self._zero_state] 173 | while to_process: 174 | state = to_process.pop() # 删除并返回最后一个元素,所以这是深度优先搜索 175 | processed.add(state.identifier) 176 | for child in state.transitions.values(): 177 | if child.identifier not in processed: 178 | self.search_lss(child) 179 | to_process.append(child) 180 | self._finalized = True 181 | 182 | def __str__(self): 183 | return "ahocorapy KeywordTree" 184 | 185 | 186 | def search_lss(self, state): 187 | if state.longest_strict_suffix is None: 188 | parent = state.parent 189 | traversed = parent.longest_strict_suffix 190 | while True: 191 | if state.symbol in traversed.transitions and \ 192 | traversed.transitions[state.symbol] != state: 193 | state.longest_strict_suffix = \ 194 | traversed.transitions[state.symbol] 195 | break 196 | elif traversed == self._zero_state: 197 | state.longest_strict_suffix = self._zero_state 198 | break 199 | else: 200 | traversed = traversed.longest_strict_suffix 201 | suffix = state.longest_strict_suffix 202 | if suffix.longest_strict_suffix is None: 203 | self.search_lss(suffix) 204 | for symbol, next_state in suffix.transitions.items(): 205 | if (symbol not in state.transitions and 206 | suffix != self._zero_state): 207 | state.transitions[symbol] = next_state 208 | 209 | def add_to_ac(ac, entity_type, entity_before, entity_after, pri): 210 | entity_before = normal_transformer(entity_before) 211 | flag = "ignore" 212 | if entity_type == "song" and ((entity_after in frequentSong) and entity_after not in {"李白"}): 213 | return flag 214 | if entity_type == "singer" and entity_after in frequentSinger: 215 | return flag 216 | elif entity_type == "toplist" and entity_before == "首张": 217 | return flag 218 | elif entity_type == "emotion" and entity_before in {"high歌","相思","喜欢"}: # train和devs就是这么标注的 219 | return flag 220 | elif entity_type == "language" and entity_before in ["中国"]: # train和devs就是这么标注的 221 | return flag 222 | ac.add(keywords=entity_before, meta_data=(entity_after,pri)) 223 | return "add success" 224 | 225 | -------------------------------------------------------------------------------- /predict_param.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 识别槽位 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | from Levenshtein import distance 9 | from curLine_file import curLine 10 | from find_entity.exacter_acmation import get_all_entity 11 | from confusion_words_danyin import correct_song, correct_singer 12 | tongyin_yuzhi = 0.65 13 | char_distance_yuzhi = 0.6 14 | 15 | 16 | # 得到两个字符串在字符级别的相似度 越大越相似 17 | def get_char_similarScore(predict_entity, known_entity): 18 | ''' 19 | :param predict_entity: 识别得到的槽值 20 | :param known_entity: 库中已有的实体,用这个实体的长度对编辑距离进行归一化 21 | :return: 对编辑距离归一化后的分数,越大则越相似 22 | ''' 23 | d = distance(predict_entity, known_entity) 24 | score = 1.0 - float(d) / (len(known_entity)+ len(predict_entity)) 25 | if score < char_distance_yuzhi: 26 | score = 0 27 | return score 28 | 29 | 30 | def get_slot_info_str_forMusic(slot_info, raw_query, entityTypeMap): # 列表连接成字符串 31 | # 没有直接join得到字符串,因为想要剔除某些只有一个字的实体 例如phone_num 32 | slot_info_block = [] 33 | param_list = [] 34 | current_entityType = None 35 | for token in slot_info: 36 | if "<" in token and ">" in token and "/" not 
in token: # 一个槽位的开始 37 | if len(slot_info_block) > 0: 38 | print(curLine(), len(slot_info), "slot_info:", slot_info) 39 | assert len(slot_info_block) == 0, "".join(slot_info_block) 40 | slot_info_block = [] 41 | current_entityType = token[1:-1] 42 | elif "<" in token and ">" in token and "/" in token: # 一个槽位的结束 43 | slot_info_block_str = "".join(slot_info_block) 44 | # 在循环前按照没找到字符角度相似的实体的情况初始化 45 | entity_before = slot_info_block_str # 如果没找到字符角度相似的实体,就返回entity_before 46 | entity_after = slot_info_block_str 47 | priority = 0 # 优先级 48 | ignore_flag = False # 是否忽略这个槽值 49 | if slot_info_block_str in {"什么歌","一首","小花","叮当","傻逼", "给你听","现在","喜欢","ye ye","没","去你娘的蛋", "j c"}: # 黑名单 不是一个槽值 50 | ignore_flag = True # 忽略这个槽值 51 | else: 52 | # 逐个实体计算和slot_info_block_str的相似度 53 | char_similarScore = 0 # 识别的槽值与库中实体在字符上的相似度 54 | for entity_info in entityTypeMap[current_entityType]: 55 | current_entity_before = entity_info['before'] 56 | if slot_info_block_str == current_entity_before: # 在实体库中 57 | entity_before = current_entity_before 58 | entity_after = entity_info['after'] 59 | priority = entity_info['priority'] 60 | char_similarScore = 1.0 # 完全匹配,槽值在库中 61 | if slot_info_block_str != entity_after: # 已经出现过的纠错词 62 | slot_info_block_str = "%s||%s" % (slot_info_block_str, entity_after) 63 | break 64 | current_char_similarScore = get_char_similarScore(predict_entity=slot_info_block_str, known_entity=current_entity_before) 65 | # print(curLine(), slot_info_block_str, current_entity_before, current_char_similarScore) 66 | if current_char_similarScore > char_similarScore: # 在句子找到字面上更接近的实体 67 | entity_before = current_entity_before 68 | entity_after = entity_info['after'] 69 | priority = entity_info['priority'] 70 | char_similarScore = current_char_similarScore 71 | if priority == 0: 72 | if current_entityType not in {"singer", "song"}: 73 | ignore_flag = True 74 | if len(entity_before) < 2: # 忽略这个短的槽值 75 | ignore_flag = True 76 | if current_entityType == "song" and entity_before in {"鱼", "云", "逃", "退", "陶", "美", "图", "默", "哭", "雪"}: 77 | ignore_flag = False # 这个要在后面判断 78 | if ignore_flag: 79 | if entity_before not in "好点没走": 80 | print(curLine(), raw_query, "ignore entityType:%s, slot:%s, entity_before:%s" % (current_entityType, slot_info_block_str, entity_before)) 81 | else: 82 | param_list.append({'entityType':current_entityType, 'before':entity_before, 'after':entity_after, 'priority':priority}) 83 | slot_info_block = [] 84 | current_entityType = None 85 | elif current_entityType is not None: # is in a slot 86 | slot_info_block.append(token) 87 | param_list_sorted = sorted(param_list, key=lambda item: len(item['before'])*100+item['priority'], 88 | reverse=True) 89 | slot_info_str_list = [raw_query] 90 | replace_query = raw_query 91 | for param in param_list_sorted: 92 | entity_before = param['before'] 93 | replace_str = entity_before 94 | if replace_str not in replace_query: 95 | continue 96 | replace_query = replace_query.replace(replace_str, "", 1) # 只替换一次 97 | entityType = param['entityType'] 98 | 99 | if param['priority'] == 0: # 模型识别的结果不在库中,尝试用拼音纠错 100 | similar_score = 0.0 101 | best_similar_word = None 102 | if entityType == "singer": 103 | similar_score, best_similar_word = correct_singer(entity_before, jichu_distance=0.001, char_ratio=0.1, char_distance=0) 104 | elif entityType == "song": 105 | similar_score, best_similar_word = correct_song(entity_before, jichu_distance=0.001, char_ratio=0.1, char_distance=0) 106 | if similar_score > tongyin_yuzhi and best_similar_word != entity_before: # 槽值纠错 107 | 
107 | # print(curLine(), entityType, "entity_before:",entity_before, best_similar_word, similar_score)
108 | param['after'] = best_similar_word
109 | 
110 | if entity_before != param['after']:
111 | replace_str = "%s||%s" % (entity_before, param['after'])
112 | for s_index,s in enumerate(slot_info_str_list):
113 | if entity_before not in s or ("<" in s): # not the current slot value, or already tagged as a slot value so it cannot be replaced again
114 | continue
115 | insert_list = []
116 | start_index = s.find(entity_before)
117 | if start_index > 0:
118 | insert_list.append(s[:start_index])
119 | insert_list.append("<%s>%s</%s>" % (entityType, replace_str, entityType))
120 | end_index = start_index+len(entity_before)
121 | if end_index < len(s):
122 | insert_list.append(s[end_index:])
123 | slot_info_str_list = slot_info_str_list[:s_index] + insert_list + slot_info_str_list[s_index+1:]
124 | break # replace each slot value only once
125 | slot_info_str = "".join(slot_info_str_list)
126 | return slot_info_str
127 | 
128 | 
129 | def get_slot_info_str(slot_info, raw_query, entityTypeMap): # join the list into a string
130 | # The list is not joined directly, because some single-character entities (e.g. phone_num) should be dropped
131 | slot_info_str = []
132 | slot_info_block = []
133 | current_entityType = None
134 | for token in slot_info:
135 | if "<" in token and ">" in token and "/" not in token: # start of a slot
136 | if len(slot_info_block) > 0:
137 | print(curLine(), len(slot_info), "slot_info:", slot_info)
138 | assert len(slot_info_block) == 0, "".join(slot_info_block)
139 | slot_info_block = []
140 | current_entityType = token[1:-1]
141 | elif "<" in token and ">" in token and "/" in token: # end of a slot
142 | ignore_flag = False
143 | slot_info_block_str = "".join(slot_info_block)
144 | if (current_entityType != "phone_num" or len(slot_info_block_str) > 1) and \
145 | slot_info_block_str not in {"什么歌", "一首", "小花","叮当","傻逼", "给你听","现在"}: # confirmed to be a slot value # TODO why is phone_num an exception
146 | if current_entityType in ["singer", "song"]: # homophone lookup
147 | known_entity_flag = False # assume the predicted entity is not in the entity library
148 | for entity_info in entityTypeMap[current_entityType]:
149 | if slot_info_block_str == entity_info['before']: # found in the entity library
150 | entity_after = entity_info['after']
151 | if slot_info_block_str != entity_after: # a correction word that has appeared before
152 | slot_info_block_str = "%s||%s" % (slot_info_block_str, entity_after)
153 | known_entity_flag = True
154 | break
155 | 
156 | if not known_entity_flag:# the predicted entity is not in the entity library, so try pinyin-based correction
157 | # TODO the case where several corrected "after" values share the same pinyin is not handled yet; tones and characters could be used for filtering
158 | if len(slot_info_block_str) < 2 and slot_info_block_str not in {"鱼","云","逃","退"}:
159 | ignore_flag = True # ignore single-character values
160 | elif slot_info_block_str in {"什么歌", "一首"}:
161 | ignore_flag = True # ignore blacklisted values
162 | else:
163 | if current_entityType == "singer":
164 | similar_score, best_similar_word = correct_singer(slot_info_block_str)
165 | elif current_entityType == "song":
166 | similar_score, best_similar_word = correct_song(slot_info_block_str, char_ratio=0.55, char_distance=0.0)
167 | if best_similar_word != slot_info_block_str:
168 | if similar_score > tongyin_yuzhi:
169 | # print(curLine(), current_entityType, slot_info_block_str, best_similar_word, similar_score)
170 | slot_info_block_str = "%s||%s" % (slot_info_block_str, best_similar_word)
171 | elif current_entityType in ["theme", "style", "age", "toplist", "emotion", "language", "instrument", "scene"]:
172 | entity_info_list = get_all_entity(raw_query, useEntityTypeList=[current_entityType])[current_entityType]
173 | ignore_flag = True # ignore entities that are not in the library
174 | min_distance = 1000
175 | houxuan = None
176 | for entity_info in entity_info_list:
177 | entity_before = entity_info["before"]
178 | if
slot_info_block_str == entity_before: 179 | houxuan = entity_before 180 | min_distance = 0 181 | break 182 | if houxuan is not None and min_distance <= 1: 183 | if slot_info_block_str != houxuan: 184 | print(curLine(), len(entity_info_list), "entity_info_list:", entity_info_list, "slot_info_block_str:", 185 | slot_info_block_str) 186 | print(curLine(), "change %s from %s to %s" % (current_entityType, slot_info_block_str, houxuan)) 187 | slot_info_block_str = houxuan 188 | ignore_flag = False # 不忽略 189 | if ignore_flag: 190 | slot_info_str.extend([slot_info_block_str]) 191 | else: 192 | slot_info_str.extend(["<%s>" % current_entityType, slot_info_block_str, token]) 193 | 194 | else: # 忽略这个短的槽值 195 | print(curLine(), "ignore entityType:%s, slot:%s" % (current_entityType, slot_info_block_str)) 196 | slot_info_str.append(slot_info_block_str) 197 | current_entityType = None 198 | slot_info_block = [] 199 | elif current_entityType is None: # not slot 200 | slot_info_str.append(token) 201 | else: # is in a slot 202 | slot_info_block.append(token) 203 | slot_info_str = "".join(slot_info_str) 204 | return slot_info_str -------------------------------------------------------------------------------- /predict_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Utility functions for running inference with a LaserTagger model.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | 22 | from __future__ import print_function 23 | from collections import defaultdict 24 | 25 | from find_entity.acmation import re_phoneNum 26 | from curLine_file import curLine 27 | from predict_param import get_slot_info_str, get_slot_info_str_forMusic 28 | 29 | tongyin_yuzhi = 0.7 30 | cancel_keywords = ["取消", "关闭", "停止", "结束", "关掉", "不要打", "退出", "不需要", "暂停", "谢谢你的服务"] 31 | 32 | class LaserTaggerPredictor(object): 33 | """Class for computing and realizing predictions with LaserTagger.""" 34 | 35 | def __init__(self, tf_predictor, 36 | example_builder, 37 | label_map, slot_label_map, 38 | target_domain_name): 39 | """Initializes an instance of LaserTaggerPredictor. 40 | 41 | Args: 42 | tf_predictor: Loaded Tensorflow model. 43 | example_builder: BERT example builder. 44 | label_map: Mapping from tags to tag IDs. 45 | """ 46 | self._predictor = tf_predictor 47 | self._example_builder = example_builder 48 | self.intent_id_2_tag = { 49 | tag_id: tag for tag, tag_id in label_map.items() 50 | } 51 | self.slot_id_2_tag = { 52 | tag_id: tag for tag, tag_id in slot_label_map.items() 53 | } 54 | self.target_domain_name = target_domain_name 55 | 56 | def predict_batch(self, sources_batch, location_batch=None, target_domain_name=None, predict_domain_batch=[], raw_query=None): # 由predict改成 57 | """Returns realized prediction for given sources.""" 58 | # Predict tag IDs. 
59 | keys = ['input_ids', 'input_mask', 'segment_ids', 'entity_type_ids', 'sequence_lengths'] 60 | input_info = defaultdict(list) 61 | example_list = [] 62 | input_tokens_list = [] 63 | location = None 64 | for id, sources in enumerate(sources_batch): 65 | if location_batch is not None: 66 | location = location_batch[id] # 表示是否能修改 67 | example, input_tokens, _= self._example_builder.build_bert_example(sources, location=location) 68 | if example is None: 69 | raise ValueError("Example couldn't be built.") 70 | for key in keys: 71 | input_info[key].append(example.features[key]) 72 | example_list.append(example) 73 | input_tokens_list.append(input_tokens) 74 | 75 | out = self._predictor(input_info) 76 | intent_list = [] 77 | slot_list = [] 78 | 79 | for index, intent in enumerate(out['pred_intent']): 80 | predicted_intent_ids = intent.tolist() 81 | predict_intent = self.intent_id_2_tag[predicted_intent_ids] 82 | query = sources_batch[index][-1] 83 | slot_info = query 84 | 85 | 86 | predict_domain = predict_domain_batch[index] 87 | if predict_domain != target_domain_name: 88 | intent_list.append(predict_intent) 89 | slot_list.append(slot_info) 90 | continue 91 | for word in cancel_keywords: 92 | if word in query: 93 | if target_domain_name == 'navigation': 94 | predict_intent = 'cancel_navigation' 95 | elif target_domain_name == 'music': 96 | predict_intent = 'pause' 97 | elif target_domain_name == 'phone_call': 98 | predict_intent = 'cancel' 99 | else: 100 | raise "wrong target_domain_name:%s" % target_domain_name 101 | break 102 | 103 | if target_domain_name == 'navigation': 104 | if predict_intent != 'navigation': 105 | slot_info = query 106 | else: 107 | slot_info = self.get_slot_info(out['pred_slot'][index], example_list[index], input_tokens_list[index], query) 108 | # if "cancel" not in predict_intent: 109 | # slot_info = self.get_slot_info(out['pred_slot'][index], example_list[index], input_tokens_list[index], query) 110 | # if ">" in slot_info and " 0 and "[UNK]" in input_tokens[tokenizer_id-1]: 139 | if t in query[previous_id:]: 140 | previous_id = previous_id + query[previous_id:].index(t) 141 | else: # 出现连续的UNK情况,目前的做法是假设长度为1 142 | previous_id += 1 143 | token_index_map[tokenizer_id] = previous_id 144 | if "[UNK]" not in t: 145 | length_t = len(t) 146 | if t.startswith("##", 0, 2): 147 | length_t -= 2 148 | previous_id += length_t 149 | predicted_slot_ids = slot.tolist() 150 | labels_mask = example.features["labels_mask"] 151 | assert len(labels_mask) == len(predicted_slot_ids) 152 | slot_info = [] 153 | current_entityType = None 154 | index = -1 155 | # print(curLine(), len(predicted_slot_ids), "predicted_slot_ids:", predicted_slot_ids) 156 | for tokens, mask, slot_id in zip(input_tokens, labels_mask, predicted_slot_ids): 157 | index += 1 158 | if mask > 0: 159 | if tokens.startswith("##"): 160 | tokens = tokens[2:] 161 | elif "[UNK]" in tokens: # 处理UNK的情况 162 | previoud_id = token_index_map[index] #  unk对应word开始的位置 163 | next_previoud_id = previoud_id + 1 #  unk对应word结束的位置 164 | if index+1 in token_index_map: 165 | next_previoud_id = token_index_map[index+1] 166 | tokens = query[previoud_id:next_previoud_id] 167 | print(curLine(), "unk self.passage[%d,%d]=%s" % (previoud_id, next_previoud_id, tokens)) 168 | predict_slot = self.slot_id_2_tag[slot_id] 169 | # print(curLine(), tokens, mask, predict_slot) 170 | # 用规则增强对数字的识别 171 | if current_entityType == "phone_num" and "phone_num" not in predict_slot: # 正在phone_num的区间内 172 | token_numbers = re_phoneNum.findall(tokens) 173 | if 
len(token_numbers) > 0:
174 | first_span = token_numbers[0]
175 | if tokens.index(first_span) == 0: # the next word piece still starts with digits, so it should be merged into the previous span
176 | slot_info.append(first_span)
177 | if len(first_span) < len(tokens): # if there is anything besides the digits, append the rest as O
178 | slot_info.append("</phone_num>")
179 | slot_info.append(tokens[len(first_span):])
180 | current_entityType = None
181 | continue
182 | 
183 | if predict_slot == "O":
184 | if current_entityType is not None: # already inside a span
185 | slot_info.append("</%s>" % current_entityType) # close the span
186 | current_entityType = None
187 | slot_info.append(tokens)
188 | continue
189 | 
190 | token_type = predict_slot[2:] # type of the current token
191 | if current_entityType is not None and token_type != current_entityType: # TODO the previous span has not been closed yet
192 | slot_info.append("</%s>" % current_entityType) # force-close it
193 | current_entityType = None
194 | if "B-" in predict_slot:
195 | if token_type != current_entityType: # the previous span has ended, so add the opening tag
196 | slot_info.append("<%s>" % token_type)
197 | slot_info.append(tokens)
198 | current_entityType = token_type
199 | elif "E-" in predict_slot:
200 | if token_type == current_entityType: # already inside this span
201 | if token_type not in {"origin","destination","phone_num","contact_name","singer"} \
202 | or tokens not in {"的"}: # TODO some entities usually do not end with these characters
203 | slot_info.append(tokens)
204 | slot_info.append("</%s>" % token_type) # normal case: append the token, then close the span
205 | else:
206 | slot_info.append("</%s>" % token_type) # close the span first, then append the token
207 | slot_info.append(tokens)
208 | else: # type mismatch, treat as O
209 | slot_info.append(tokens)
210 | current_entityType = None
211 | else: # I
212 | if current_entityType != token_type and index+1 < len(predicted_slot_ids): # an I tag without a preceding B; decide based on the next tag
213 | next_predict_slot = self.slot_id_2_tag[predicted_slot_ids[index+1]]
214 | # print(curLine(), "next_predict_slot:%s, token_type=%s" % (next_predict_slot, token_type))
215 | if token_type in next_predict_slot:
216 | slot_info.append("<%s>" % token_type)
217 | current_entityType = token_type
218 | slot_info.append(tokens)
219 | if current_entityType is not None: # reached the end with an unclosed span
220 | slot_info.append("</%s>" % current_entityType) # force-close it
221 | if self.target_domain_name == "music":
222 | slot_info_str = get_slot_info_str_forMusic(slot_info, raw_query=query, entityTypeMap=example.entityTypeMap)
223 | else:
224 | slot_info_str = get_slot_info_str(slot_info, raw_query=query, entityTypeMap=example.entityTypeMap)
225 | return slot_info_str
--------------------------------------------------------------------------------
/sari_hook.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """SARI score for evaluating paraphrasing and other text generation models.
17 | 18 | The score is introduced in the following paper: 19 | 20 | Optimizing Statistical Machine Translation for Text Simplification 21 | Wei Xu, Courtney Napoles, Ellie Pavlick, Quanze Chen and Chris Callison-Burch 22 | In Transactions of the Association for Computational Linguistics (TACL) 2015 23 | http://cs.jhu.edu/~napoles/res/tacl2016-optimizing.pdf 24 | 25 | This implementation has two differences with the GitHub [1] implementation: 26 | (1) Define 0/0=1 instead of 0 to give higher scores for predictions that match 27 | a target exactly. 28 | (2) Fix an alleged bug [2] in the deletion score computation. 29 | 30 | [1] https://github.com/cocoxu/simplification/blob/master/SARI.py 31 | (commit 0210f15) 32 | [2] https://github.com/cocoxu/simplification/issues/6 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import collections 40 | 41 | import numpy as np 42 | import tensorflow as tf 43 | 44 | # The paper that intoduces the SARI score uses only the precision of the deleted 45 | # tokens (i.e. beta=0). To give more emphasis on recall, you may set, e.g., 46 | # beta=1. 47 | BETA_FOR_SARI_DELETION_F_MEASURE = 0 48 | 49 | 50 | def _get_ngram_counter(ids, n): 51 | """Get a Counter with the ngrams of the given ID list. 52 | 53 | Args: 54 | ids: np.array or a list corresponding to a single sentence 55 | n: n-gram size 56 | 57 | Returns: 58 | collections.Counter with ID tuples as keys and 1s as values. 59 | """ 60 | # Remove zero IDs used to pad the sequence. 61 | ids = [token_id for token_id in ids if token_id != 0] 62 | ngram_list = [tuple(ids[i:i + n]) for i in range(len(ids) + 1 - n)] 63 | ngrams = set(ngram_list) 64 | counts = collections.Counter() 65 | for ngram in ngrams: 66 | counts[ngram] = 1 67 | return counts 68 | 69 | 70 | def _get_fbeta_score(true_positives, selected, relevant, beta=1): 71 | """Compute Fbeta score. 72 | 73 | Args: 74 | true_positives: Number of true positive ngrams. 75 | selected: Number of selected ngrams. 76 | relevant: Number of relevant ngrams. 77 | beta: 0 gives precision only, 1 gives F1 score, and Inf gives recall only. 78 | 79 | Returns: 80 | Fbeta score. 81 | """ 82 | precision = 1 83 | if selected > 0: 84 | precision = true_positives / selected 85 | if beta == 0: 86 | return precision 87 | recall = 1 88 | if relevant > 0: 89 | recall = true_positives / relevant 90 | if precision > 0 and recall > 0: 91 | beta2 = beta * beta 92 | return (1 + beta2) * precision * recall / (beta2 * precision + recall) 93 | else: 94 | return 0 95 | 96 | 97 | def get_addition_score(source_counts, prediction_counts, target_counts): 98 | """Compute the addition score (Equation 4 in the paper).""" 99 | added_to_prediction_counts = prediction_counts - source_counts 100 | true_positives = sum((added_to_prediction_counts & target_counts).values()) 101 | selected = sum(added_to_prediction_counts.values()) 102 | # Note that in the paper the summation is done over all the ngrams in the 103 | # output rather than the ngrams in the following set difference. Since the 104 | # former does not make as much sense we compute the latter, which is also done 105 | # in the GitHub implementation. 
106 | relevant = sum((target_counts - source_counts).values()) 107 | return _get_fbeta_score(true_positives, selected, relevant) 108 | 109 | 110 | def get_keep_score(source_counts, prediction_counts, target_counts): 111 | """Compute the keep score (Equation 5 in the paper).""" 112 | source_and_prediction_counts = source_counts & prediction_counts 113 | source_and_target_counts = source_counts & target_counts 114 | true_positives = sum((source_and_prediction_counts & 115 | source_and_target_counts).values()) 116 | selected = sum(source_and_prediction_counts.values()) 117 | relevant = sum(source_and_target_counts.values()) 118 | return _get_fbeta_score(true_positives, selected, relevant) 119 | 120 | 121 | def get_deletion_score(source_counts, prediction_counts, target_counts, beta=0): 122 | """Compute the deletion score (Equation 6 in the paper).""" 123 | source_not_prediction_counts = source_counts - prediction_counts 124 | source_not_target_counts = source_counts - target_counts 125 | true_positives = sum((source_not_prediction_counts & 126 | source_not_target_counts).values()) 127 | selected = sum(source_not_prediction_counts.values()) 128 | relevant = sum(source_not_target_counts.values()) 129 | return _get_fbeta_score(true_positives, selected, relevant, beta=beta) 130 | 131 | 132 | def get_sari_score(source_ids, prediction_ids, list_of_targets, 133 | max_gram_size=4, beta_for_deletion=0): 134 | """Compute the SARI score for a single prediction and one or more targets. 135 | 136 | Args: 137 | source_ids: a list / np.array of SentencePiece IDs 138 | prediction_ids: a list / np.array of SentencePiece IDs 139 | list_of_targets: a list of target ID lists / np.arrays 140 | max_gram_size: int. largest n-gram size we care about (e.g. 3 for unigrams, 141 | bigrams, and trigrams) 142 | beta_for_deletion: beta for deletion F score. 143 | 144 | Returns: 145 | the SARI score and its three components: add, keep, and deletion scores 146 | """ 147 | addition_scores = [] 148 | keep_scores = [] 149 | deletion_scores = [] 150 | for n in range(1, max_gram_size + 1): 151 | source_counts = _get_ngram_counter(source_ids, n) 152 | prediction_counts = _get_ngram_counter(prediction_ids, n) 153 | # All ngrams in the targets with count 1. 154 | target_counts = collections.Counter() 155 | # All ngrams in the targets with count r/num_targets, where r is the number 156 | # of targets where the ngram occurs. 
157 | weighted_target_counts = collections.Counter() 158 | num_nonempty_targets = 0 159 | for target_ids_i in list_of_targets: 160 | target_counts_i = _get_ngram_counter(target_ids_i, n) 161 | if target_counts_i: 162 | weighted_target_counts += target_counts_i 163 | num_nonempty_targets += 1 164 | for gram in weighted_target_counts.keys(): 165 | weighted_target_counts[gram] /= num_nonempty_targets 166 | target_counts[gram] = 1 167 | keep_scores.append(get_keep_score(source_counts, prediction_counts, 168 | weighted_target_counts)) 169 | deletion_scores.append(get_deletion_score(source_counts, prediction_counts, 170 | weighted_target_counts, 171 | beta_for_deletion)) 172 | addition_scores.append(get_addition_score(source_counts, prediction_counts, 173 | target_counts)) 174 | 175 | avg_keep_score = sum(keep_scores) / max_gram_size 176 | avg_addition_score = sum(addition_scores) / max_gram_size 177 | avg_deletion_score = sum(deletion_scores) / max_gram_size 178 | sari = (avg_keep_score + avg_addition_score + avg_deletion_score) / 3.0 179 | return sari, avg_keep_score, avg_addition_score, avg_deletion_score 180 | 181 | 182 | def get_sari(source_ids, prediction_ids, target_ids, max_gram_size=4): 183 | """Computes the SARI scores from the given source, prediction and targets. 184 | 185 | Args: 186 | source_ids: A 2D tf.Tensor of size (batch_size , sequence_length) 187 | prediction_ids: A 2D tf.Tensor of size (batch_size, sequence_length) 188 | target_ids: A 3D tf.Tensor of size (batch_size, number_of_targets, 189 | sequence_length) 190 | max_gram_size: int. largest n-gram size we care about (e.g. 3 for unigrams, 191 | bigrams, and trigrams) 192 | 193 | Returns: 194 | A 4-tuple of 1D float Tensors of size (batch_size) for the SARI score and 195 | the keep, addition and deletion scores. 196 | """ 197 | 198 | def get_sari_numpy(source_ids, prediction_ids, target_ids): 199 | """Iterate over elements in the batch and call the SARI function.""" 200 | sari_scores = [] 201 | keep_scores = [] 202 | add_scores = [] 203 | deletion_scores = [] 204 | # Iterate over elements in the batch. 205 | for source_ids_i, prediction_ids_i, target_ids_i in zip( 206 | source_ids, prediction_ids, target_ids): 207 | sari, keep, add, deletion = get_sari_score( 208 | source_ids_i, prediction_ids_i, target_ids_i, max_gram_size, 209 | BETA_FOR_SARI_DELETION_F_MEASURE) 210 | sari_scores.append(sari) 211 | keep_scores.append(keep) 212 | add_scores.append(add) 213 | deletion_scores.append(deletion) 214 | return (np.asarray(sari_scores), np.asarray(keep_scores), 215 | np.asarray(add_scores), np.asarray(deletion_scores)) 216 | 217 | sari, keep, add, deletion = tf.py_func( 218 | get_sari_numpy, 219 | [source_ids, prediction_ids, target_ids], 220 | [tf.float64, tf.float64, tf.float64, tf.float64]) 221 | return sari, keep, add, deletion 222 | 223 | 224 | def sari_score(predictions, labels, features, **unused_kwargs): 225 | """Computes the SARI scores from the given source, prediction and targets. 226 | 227 | An approximate SARI scoring method since we do not glue word pieces or 228 | decode the ids and tokenize the output. By default, we use ngram order of 4. 229 | Also, this does not have beam search. 230 | 231 | Args: 232 | predictions: tensor, model predictions. 233 | labels: tensor, gold output. 234 | features: dict, containing inputs. 
235 | 236 | Returns: 237 | sari: int, approx sari score 238 | """ 239 | if "inputs" not in features: 240 | raise ValueError("sari_score requires inputs feature") 241 | 242 | # Convert the inputs and outputs to a [batch_size, sequence_length] tensor. 243 | inputs = tf.squeeze(features["inputs"], axis=[-1, -2]) 244 | outputs = tf.to_int32(tf.argmax(predictions, axis=-1)) 245 | outputs = tf.squeeze(outputs, axis=[-1, -2]) 246 | 247 | # Convert the labels to a [batch_size, 1, sequence_length] tensor. 248 | labels = tf.squeeze(labels, axis=[-1, -2]) 249 | labels = tf.expand_dims(labels, axis=1) 250 | 251 | score, _, _, _ = get_sari(inputs, outputs, labels) 252 | return score, tf.constant(1.0) 253 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /run_lasertagger.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """BERT-based LaserTagger runner.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | 22 | from __future__ import print_function 23 | import os 24 | from absl import flags 25 | from termcolor import colored 26 | import tensorflow as tf 27 | import run_lasertagger_utils 28 | import utils 29 | from curLine_file import curLine 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | ## Required parameters 34 | flags.DEFINE_string("training_file", None, 35 | "Path to the TFRecord training file.") 36 | flags.DEFINE_string("eval_file", None, "Path to the the TFRecord dev file.") 37 | flags.DEFINE_string( 38 | "label_map_file", None, 39 | "Path to the label map file. Either a JSON file ending with '.json', that " 40 | "maps each possible tag to an ID, or a text file that has one tag per line.") 41 | flags.DEFINE_string( 42 | 'slot_label_map_file', None, 43 | 'Path to the label map file. Either a JSON file ending with ".json", that ' 44 | 'maps each possible tag to an ID, or a text file that has one tag per line.') 45 | flags.DEFINE_string( 46 | "model_config_file", None, 47 | "The config json file specifying the model architecture.") 48 | flags.DEFINE_string( 49 | "output_dir", None, 50 | "The output directory where the model checkpoints will be written. If " 51 | "`init_checkpoint' is not provided when exporting, the latest checkpoint " 52 | "from this directory will be exported.") 53 | 54 | ## Other parameters 55 | 56 | flags.DEFINE_string( 57 | "init_checkpoint", None, 58 | "Initial checkpoint, usually from a pre-trained BERT model. In the case of " 59 | "exporting, one can optionally provide path to a particular checkpoint to " 60 | "be exported here.") 61 | flags.DEFINE_integer( 62 | "max_seq_length", 128, # contain CLS and SEP 63 | "The maximum total input sequence length after WordPiece tokenization. " 64 | "Sequences longer than this will be truncated, and sequences shorter than " 65 | "this will be padded.") 66 | 67 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 68 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 69 | flags.DEFINE_bool("do_export", False, "Whether to export a trained model.") 70 | flags.DEFINE_bool("eval_all_checkpoints", False, "Run through all checkpoints.") 71 | flags.DEFINE_integer( 72 | "eval_timeout", 600, 73 | "The maximum amount of time (in seconds) for eval worker to wait between " 74 | "checkpoints.") 75 | 76 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 77 | flags.DEFINE_integer("eval_batch_size", 128, "Total batch size for eval.") 78 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 79 | flags.DEFINE_float("learning_rate", 3e-5, "The initial learning rate for Adam.") 80 | flags.DEFINE_float( 81 | "warmup_proportion", 0.1, 82 | "Proportion of training to perform linear learning rate warmup for. 
" 83 | "E.g., 0.1 = 10% of training.") 84 | 85 | flags.DEFINE_float("num_train_epochs", 3.0, 86 | "Total number of training epochs to perform.") 87 | 88 | flags.DEFINE_float("drop_keep_prob", 0.9, "drop_keep_prob") 89 | flags.DEFINE_integer("kernel_size", 2, "kernel_size") 90 | 91 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 92 | "How often to save the model checkpoint.") 93 | flags.DEFINE_integer("keep_checkpoint_max", 3, 94 | "How many checkpoints to keep.") 95 | flags.DEFINE_integer("iterations_per_loop", 1000, 96 | "How many steps to make in each estimator call.") 97 | flags.DEFINE_integer( 98 | "num_train_examples", None, 99 | "Number of training examples. This is used to determine the number of " 100 | "training steps to respect the `num_train_epochs` flag.") 101 | flags.DEFINE_integer( 102 | "num_eval_examples", None, 103 | "Number of eval examples. This is used to determine the number of " 104 | "eval steps to go through the eval file once.") 105 | 106 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 107 | flags.DEFINE_string( 108 | "tpu_name", None, 109 | "The Cloud TPU to use for training. This should be either the name " 110 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 111 | "url.") 112 | flags.DEFINE_string( 113 | "tpu_zone", None, 114 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 115 | "specified, we will attempt to automatically detect the GCE project from " 116 | "metadata.") 117 | flags.DEFINE_string( 118 | "gcp_project", None, 119 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 120 | "specified, we will attempt to automatically detect the GCE project from " 121 | "metadata.") 122 | flags.DEFINE_string("master", None, 123 | "Optional address of the master for the workers.") 124 | flags.DEFINE_string("export_path", None, "Path to save the exported model.") 125 | flags.DEFINE_float("slot_ratio", 0.3, "slot_ratio") 126 | 127 | flags.DEFINE_string( 128 | 'domain_name', None, 129 | 'Whether to lower case the input text. Should be True for uncased ' 130 | 'models and False for cased models.') 131 | flags.DEFINE_string( 132 | 'entity_type_list_file', None, 'path of entity_type_list_file') 133 | 134 | def file_based_input_fn_builder(input_file, max_seq_length, 135 | is_training, drop_remainder, entity_type_num): 136 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 137 | 138 | name_to_features = { 139 | "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64), 140 | "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64), 141 | "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64), 142 | "labels": tf.FixedLenFeature([], tf.int64), 143 | "slot_labels": tf.FixedLenFeature([max_seq_length], tf.int64), 144 | "labels_mask": tf.FixedLenFeature([max_seq_length], tf.int64), 145 | "entity_type_ids": tf.FixedLenFeature([max_seq_length*entity_type_num], tf.int64), 146 | "sequence_lengths": tf.FixedLenFeature([], tf.int64), 147 | } 148 | 149 | def _decode_record(record, name_to_features): 150 | """Decodes a record to a TensorFlow example.""" 151 | example = tf.parse_single_example(record, name_to_features) 152 | 153 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 154 | # So cast all int64 to int32. 
155 | for name in list(example.keys()): 156 | t = example[name] 157 | if t.dtype == tf.int64: 158 | t = tf.to_int32(t) 159 | example[name] = t 160 | 161 | return example 162 | 163 | def input_fn(params): 164 | """The actual input function.""" 165 | d = tf.data.TFRecordDataset(input_file) 166 | # For training, we want a lot of parallel reading and shuffling. 167 | # For eval, we want no shuffling and parallel reading doesn't matter. 168 | if is_training: 169 | d = d.repeat() # 每epoch重复使用 170 | d = d.shuffle(buffer_size=100) 171 | d = d.apply( 172 | tf.contrib.data.map_and_batch( 173 | lambda record: _decode_record(record, name_to_features), 174 | batch_size=params["batch_size"], 175 | drop_remainder=drop_remainder)) 176 | return d 177 | 178 | return input_fn 179 | 180 | 181 | def _calculate_steps(num_examples, batch_size, num_epochs, warmup_proportion=0): 182 | """Calculates the number of steps. 183 | 184 | Args: 185 | num_examples: Number of examples in the dataset. 186 | batch_size: Batch size. 187 | num_epochs: How many times we should go through the dataset. 188 | warmup_proportion: Proportion of warmup steps. 189 | 190 | Returns: 191 | Tuple (number of steps, number of warmup steps). 192 | """ 193 | steps = int(num_examples / batch_size * num_epochs) 194 | warmup_steps = int(warmup_proportion * steps) 195 | return steps, warmup_steps 196 | 197 | 198 | def main(_): 199 | tf.logging.set_verbosity(tf.logging.INFO) 200 | if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_export): 201 | raise ValueError("At least one of `do_train`, `do_eval` or `do_export` must be True.") 202 | 203 | model_config = run_lasertagger_utils.LaserTaggerConfig.from_json_file( 204 | FLAGS.model_config_file) 205 | 206 | if FLAGS.max_seq_length > model_config.max_position_embeddings: 207 | raise ValueError( 208 | "Cannot use sequence length %d because the BERT model " 209 | "was only trained up to sequence length %d" % 210 | (FLAGS.max_seq_length, model_config.max_position_embeddings)) 211 | 212 | if not FLAGS.do_export and not os.path.exists(FLAGS.output_dir): 213 | os.makedirs(FLAGS.output_dir) 214 | 215 | num_tags = len(utils.read_label_map(FLAGS.label_map_file)) 216 | num_slot_tags = len(utils.read_label_map(FLAGS.slot_label_map_file)) 217 | entity_type_list = utils.read_label_map(FLAGS.entity_type_list_file)[FLAGS.domain_name] 218 | 219 | tpu_cluster_resolver = None 220 | if FLAGS.use_tpu and FLAGS.tpu_name: 221 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 222 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 223 | 224 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 225 | run_config = tf.contrib.tpu.RunConfig( 226 | cluster=tpu_cluster_resolver, 227 | master=FLAGS.master, 228 | model_dir=FLAGS.output_dir, 229 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 230 | keep_checkpoint_max=FLAGS.keep_checkpoint_max) # , 231 | # tpu_config=tf.contrib.tpu.TPUConfig( 232 | # iterations_per_loop=FLAGS.iterations_per_loop, 233 | # per_host_input_for_training=is_per_host, 234 | # eval_training_input_configuration=tf.contrib.tpu.InputPipelineConfig.SLICED)) 235 | 236 | if FLAGS.do_train: 237 | num_train_steps, num_warmup_steps = _calculate_steps( 238 | FLAGS.num_train_examples, FLAGS.train_batch_size, 239 | FLAGS.num_train_epochs, FLAGS.warmup_proportion) 240 | else: 241 | num_train_steps, num_warmup_steps = None, None 242 | entity_type_num = len(entity_type_list) 243 | print(curLine(), "num_train_steps=",num_train_steps, "learning_rate=", FLAGS.learning_rate, 
244 | ",num_warmup_steps=",num_warmup_steps, ",max_seq_length=", FLAGS.max_seq_length) 245 | print(colored("%s restore from init_checkpoint:%s" % (curLine(), FLAGS.init_checkpoint), "red")) 246 | model_fn = run_lasertagger_utils.ModelFnBuilder( 247 | config=model_config, 248 | num_tags=num_tags, 249 | num_slot_tags=num_slot_tags, 250 | init_checkpoint=FLAGS.init_checkpoint, 251 | learning_rate=FLAGS.learning_rate, 252 | num_train_steps=num_train_steps, 253 | num_warmup_steps=num_warmup_steps, 254 | use_tpu=FLAGS.use_tpu, 255 | use_one_hot_embeddings=FLAGS.use_tpu, 256 | max_seq_length=FLAGS.max_seq_length, 257 | drop_keep_prob=FLAGS.drop_keep_prob, 258 | entity_type_num=entity_type_num, 259 | slot_ratio=FLAGS.slot_ratio).build() 260 | 261 | # If TPU is not available, this will fall back to normal Estimator on CPU or GPU. 262 | estimator = tf.contrib.tpu.TPUEstimator( 263 | use_tpu=FLAGS.use_tpu, 264 | model_fn=model_fn, 265 | config=run_config, 266 | train_batch_size=FLAGS.train_batch_size, 267 | eval_batch_size=FLAGS.eval_batch_size, 268 | predict_batch_size=FLAGS.predict_batch_size 269 | ) 270 | 271 | if FLAGS.do_train: 272 | train_input_fn = file_based_input_fn_builder( 273 | input_file=FLAGS.training_file, 274 | max_seq_length=FLAGS.max_seq_length, 275 | is_training=True, 276 | drop_remainder=True, 277 | entity_type_num=entity_type_num) 278 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 279 | 280 | if FLAGS.do_export: 281 | tf.logging.info("Exporting the model to %s"% FLAGS.export_path) 282 | def serving_input_fn(): 283 | def _input_fn(): 284 | features = { 285 | "input_ids": tf.placeholder(tf.int64, [None, None]), 286 | "input_mask": tf.placeholder(tf.int64, [None, None]), 287 | "segment_ids": tf.placeholder(tf.int64, [None, None]), 288 | "entity_type_ids": tf.placeholder(tf.int64, [None, None]), 289 | "sequence_lengths": tf.placeholder(tf.int64, [None]), 290 | } 291 | return tf.estimator.export.ServingInputReceiver( 292 | features=features, receiver_tensors=features) 293 | 294 | return _input_fn 295 | if not os.path.exists(FLAGS.export_path): 296 | print(curLine(), "will make dir:%s" % FLAGS.export_path) 297 | os.makedirs(FLAGS.export_path) 298 | 299 | estimator.export_saved_model( 300 | FLAGS.export_path, 301 | serving_input_fn(), 302 | checkpoint_path=FLAGS.init_checkpoint) 303 | 304 | 305 | if __name__ == "__main__": 306 | flags.mark_flag_as_required("model_config_file") 307 | flags.mark_flag_as_required("label_map_file") 308 | flags.mark_flag_as_required("entity_type_list_file") 309 | tf.app.run() 310 | -------------------------------------------------------------------------------- /run_lasertagger_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Utilities for building a LaserTagger TF model.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | 22 | from __future__ import print_function 23 | from bert import modeling 24 | from bert import optimization 25 | import tensorflow as tf 26 | from tensorflow.nn.rnn_cell import LSTMCell 27 | from curLine_file import curLine 28 | 29 | class LaserTaggerConfig(modeling.BertConfig): 30 | """Model configuration for LaserTagger.""" 31 | 32 | def __init__(self, 33 | use_t2t_decoder=True, 34 | decoder_num_hidden_layers=1, 35 | decoder_hidden_size=768, 36 | decoder_num_attention_heads=4, 37 | decoder_filter_size=3072, 38 | use_full_attention=False, 39 | **kwargs): 40 | """Initializes an instance of LaserTagger configuration. 41 | 42 | This initializer expects both the BERT specific arguments and the 43 | Transformer decoder arguments listed below. 44 | 45 | Args: 46 | use_t2t_decoder: Whether to use the Transformer decoder (i.e. 47 | LaserTagger_AR). If False, the remaining args do not affect anything and 48 | can be set to default values. 49 | decoder_num_hidden_layers: Number of hidden decoder layers. 50 | decoder_hidden_size: Decoder hidden size. 51 | decoder_num_attention_heads: Number of decoder attention heads. 52 | decoder_filter_size: Decoder filter size. 53 | use_full_attention: Whether to use full encoder-decoder attention. 54 | **kwargs: The arguments that the modeling.BertConfig initializer expects. 55 | """ 56 | super(LaserTaggerConfig, self).__init__(**kwargs) 57 | self.use_t2t_decoder = use_t2t_decoder 58 | self.decoder_num_hidden_layers = decoder_num_hidden_layers 59 | self.decoder_hidden_size = decoder_hidden_size 60 | self.decoder_num_attention_heads = decoder_num_attention_heads 61 | self.decoder_filter_size = decoder_filter_size 62 | self.use_full_attention = use_full_attention 63 | 64 | 65 | class ModelFnBuilder(object): 66 | """Class for building `model_fn` closure for TPUEstimator.""" 67 | 68 | def __init__(self, config, num_tags, num_slot_tags, 69 | init_checkpoint, 70 | learning_rate, num_train_steps, 71 | num_warmup_steps, use_tpu, 72 | use_one_hot_embeddings, max_seq_length, drop_keep_prob, entity_type_num, slot_ratio): 73 | """Initializes an instance of a LaserTagger model. 74 | 75 | Args: 76 | config: LaserTagger model configuration. 77 | num_tags: Number of different tags to be predicted. 78 | init_checkpoint: Path to a pretrained BERT checkpoint (optional). 79 | learning_rate: Learning rate. 80 | num_train_steps: Number of training steps. 81 | num_warmup_steps: Number of warmup steps. 82 | use_tpu: Whether to use TPU. 83 | use_one_hot_embeddings: Whether to use one-hot embeddings for word 84 | embeddings. 85 | max_seq_length: Maximum sequence length. 
86 | """ 87 | self._config = config 88 | self._num_tags = num_tags 89 | self.num_slot_tags = num_slot_tags 90 | self._init_checkpoint = init_checkpoint 91 | self._learning_rate = learning_rate 92 | self._num_train_steps = num_train_steps 93 | self._num_warmup_steps = num_warmup_steps 94 | self._use_tpu = use_tpu 95 | self._use_one_hot_embeddings = use_one_hot_embeddings 96 | self._max_seq_length = max_seq_length 97 | self.drop_keep_prob = drop_keep_prob 98 | self.slot_ratio = slot_ratio 99 | self.intent_ratio = 1.0-self.slot_ratio 100 | self.entity_type_num = entity_type_num 101 | self.lstm_hidden_size = 128 102 | 103 | def _create_model(self, mode, input_ids, input_mask, segment_ids, labels, 104 | slot_labels, labels_mask, drop_keep_prob, entity_type_ids, sequence_lengths): 105 | """Creates a LaserTagger model.""" 106 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 107 | model = modeling.BertModel( 108 | config=self._config, 109 | is_training=is_training, 110 | input_ids=input_ids, 111 | input_mask=input_mask, 112 | token_type_ids=segment_ids, 113 | use_one_hot_embeddings=self._use_one_hot_embeddings) 114 | 115 | final_layer = model.get_sequence_output() 116 | # final_hidden = model.get_pooled_output() 117 | 118 | if is_training: 119 | # I.e., 0.1 dropout 120 | # final_hidden = tf.nn.dropout(final_hidden, keep_prob=drop_keep_prob) 121 | final_layer = tf.nn.dropout(final_layer, keep_prob=drop_keep_prob) 122 | 123 | # 结合实体信息 124 | batch_size, seq_length = modeling.get_shape_list(input_ids) 125 | 126 | self.entity_type_embedding = tf.get_variable(name="entity_type_embedding", 127 | shape=(self.entity_type_num, self._config.hidden_size), 128 | dtype=tf.float32, trainable=True, 129 | initializer=tf.random_uniform_initializer( 130 | -self._config.initializer_range*100, 131 | self._config.initializer_range*100, seed=20)) 132 | 133 | with tf.init_scope(): 134 | impact_weight_init = tf.constant(1.0 / self.entity_type_num, dtype=tf.float32, shape=(1, self.entity_type_num)) 135 | self.impact_weight = tf.Variable( 136 | impact_weight_init, dtype=tf.float32, name="impact_weight") # 不同类型的影响权重 137 | impact_weight_matrix = tf.tile(self.impact_weight, multiples=[batch_size * seq_length, 1]) 138 | 139 | entity_type_ids_matrix1 = tf.cast(tf.reshape(entity_type_ids, [batch_size * seq_length, self.entity_type_num]), dtype=tf.float32) 140 | entity_type_ids_matrix = tf.multiply(entity_type_ids_matrix1, impact_weight_matrix) 141 | entity_type_emb = tf.matmul(entity_type_ids_matrix, self.entity_type_embedding) 142 | final_layer = final_layer + tf.reshape(entity_type_emb, [batch_size, seq_length, self._config.hidden_size]) # TODO TODO # 0.7071067811865476是二分之根号二 143 | # final_layer = tf.concat([final_layer, tf.reshape(entity_type_emb, [batch_size, seq_length,self._config.hidden_size])], axis=-1) 144 | 145 | if is_training: 146 | final_layer = tf.nn.dropout(final_layer, keep_prob=drop_keep_prob) 147 | 148 | (output_fw_seq, output_bw_seq), ((c_fw,h_fw),(c_bw,h_bw)) = tf.nn.bidirectional_dynamic_rnn( 149 | cell_fw=LSTMCell(self.lstm_hidden_size), 150 | cell_bw=LSTMCell(self.lstm_hidden_size), 151 | inputs=final_layer, 152 | sequence_length=sequence_lengths, 153 | dtype=tf.float32) 154 | layer_matrix = tf.concat([output_fw_seq, output_bw_seq], axis=-1) 155 | final_hidden = tf.concat([c_fw, c_bw], axis=-1) 156 | 157 | layer_matrix= tf.contrib.layers.layer_norm( 158 | inputs=layer_matrix, begin_norm_axis=-1, begin_params_axis=-1) 159 | 160 | 161 | intent_logits = tf.layers.dense( 162 | final_hidden, 163 | 
self._num_tags, 164 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.02), 165 | name="output_projection") 166 | slot_logits = tf.layers.dense(layer_matrix, self.num_slot_tags, 167 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.02), name="slot_projection") 168 | 169 | 170 | with tf.variable_scope("loss"): 171 | loss = None 172 | per_example_intent_loss = None 173 | per_example_slot_loss = None 174 | if mode != tf.estimator.ModeKeys.PREDICT: 175 | per_example_intent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 176 | labels=labels, logits=intent_logits) 177 | slot_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 178 | labels=slot_labels, logits=slot_logits) 179 | per_example_slot_loss = tf.truediv( 180 | tf.reduce_sum(slot_loss, axis=1), 181 | tf.cast(tf.reduce_sum(labels_mask, axis=1), tf.float32)) 182 | 183 | # from tensorflow.contrib.crf import crf_log_likelihood 184 | # from tensorflow.contrib.crf import viterbi_decode 185 | # batch_size = tf.shape(slot_logits)[0] 186 | # print(curLine(), batch_size, tf.constant([self._max_seq_length])) 187 | # length_batch = tf.tile(tf.constant([self._max_seq_length]), [batch_size]) 188 | # print(curLine(), batch_size, "length_batch:", length_batch) 189 | # per_example_slot_loss, self.transition_params = crf_log_likelihood(inputs=slot_logits, 190 | # tag_indices=slot_labels,sequence_lengths=length_batch) 191 | # print(curLine(), "per_example_slot_loss:", per_example_slot_loss) # shape=(batch_size,) 192 | # print(curLine(), "self.transition_params:", self.transition_params) # shape=(9, 9) 193 | 194 | loss = tf.reduce_mean(self.intent_ratio*per_example_intent_loss + self.slot_ratio*per_example_slot_loss) 195 | pred_intent = tf.cast(tf.argmax(intent_logits, axis=-1), tf.int32) 196 | pred_slot = tf.cast(tf.argmax(slot_logits, axis=-1), tf.int32) 197 | return (loss, per_example_slot_loss, pred_intent, pred_slot, 198 | batch_size, entity_type_emb, impact_weight_matrix, entity_type_ids_matrix, final_layer, slot_logits) 199 | 200 | def build(self): 201 | """Returns `model_fn` closure for TPUEstimator.""" 202 | 203 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 204 | """The `model_fn` for TPUEstimator.""" 205 | 206 | tf.logging.info("*** Features ***") 207 | for name in sorted(features.keys()): 208 | tf.logging.info(" name = %s, shape = %s", name, features[name].shape) 209 | input_ids = features["input_ids"] 210 | input_mask = features["input_mask"] 211 | segment_ids = features["segment_ids"] 212 | labels = None 213 | slot_labels = None 214 | labels_mask = None 215 | if mode != tf.estimator.ModeKeys.PREDICT: 216 | labels = features["labels"] 217 | slot_labels = features["slot_labels"] 218 | labels_mask = features["labels_mask"] 219 | (total_loss, per_example_loss, pred_intent, pred_slot, 220 | batch_size, entity_type_emb, impact_weight_matrix, entity_type_ids_matrix, final_layer, slot_logits) = self._create_model( 221 | mode, input_ids, input_mask, segment_ids, labels, slot_labels, labels_mask, self.drop_keep_prob, 222 | features["entity_type_ids"], sequence_lengths=features["sequence_lengths"]) 223 | 224 | tvars = tf.trainable_variables() 225 | initialized_variable_names = {} 226 | scaffold_fn = None 227 | if self._init_checkpoint: 228 | (assignment_map, initialized_variable_names 229 | ) = modeling.get_assignment_map_from_checkpoint(tvars, 230 | self._init_checkpoint) 231 | if self._use_tpu: 232 | def tpu_scaffold(): 233 | tf.train.init_from_checkpoint(self._init_checkpoint, 
assignment_map) 234 | return tf.train.Scaffold() 235 | 236 | scaffold_fn = tpu_scaffold 237 | else: 238 | tf.train.init_from_checkpoint(self._init_checkpoint, assignment_map) 239 | 240 | tf.logging.info("**** Trainable Variables ****") 241 | # for var in tvars: 242 | # tf.logging.info("Initializing the model from: %s", 243 | # self._init_checkpoint) 244 | # init_string = "" 245 | # if var.name in initialized_variable_names: 246 | # init_string = ", *INIT_FROM_CKPT*" 247 | # tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 248 | # init_string) 249 | 250 | output_spec = None 251 | if mode == tf.estimator.ModeKeys.TRAIN: 252 | train_op = optimization.create_optimizer( 253 | total_loss, self._learning_rate, self._num_train_steps, 254 | self._num_warmup_steps, self._use_tpu) 255 | 256 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 257 | mode=mode, 258 | loss=total_loss, 259 | train_op=train_op, 260 | scaffold_fn=scaffold_fn) 261 | 262 | elif mode == tf.estimator.ModeKeys.EVAL: 263 | def metric_fn(per_example_loss, labels, labels_mask, predictions): 264 | """Compute eval metrics.""" 265 | accuracy = tf.cast( 266 | tf.reduce_all( # tf.reduce_all 相当于"逻辑AND"操作,找到输出完全正确的才算正确 267 | tf.logical_or( 268 | tf.equal(labels, predictions), 269 | ~tf.cast(labels_mask, tf.bool)), 270 | axis=1), tf.float32) 271 | return { 272 | # This is equal to the Exact score if the final realization step 273 | # doesn't introduce errors. 274 | "sentence_level_acc": tf.metrics.mean(accuracy), 275 | "eval_loss": tf.metrics.mean(per_example_loss), 276 | } 277 | 278 | eval_metrics = (metric_fn, 279 | [per_example_loss, labels, labels_mask, pred_intent]) 280 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 281 | mode=mode, 282 | loss=total_loss, 283 | eval_metrics=eval_metrics, 284 | scaffold_fn=scaffold_fn) 285 | else: 286 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 287 | mode=mode, predictions={'pred_intent': pred_intent, 'pred_slot': pred_slot, 288 | 'batch_size':batch_size, 'entity_type_emb':entity_type_emb, 'impact_weight_matrix':impact_weight_matrix, 'entity_type_ids_matrix':entity_type_ids_matrix, 289 | 'final_layer':final_layer, 'slot_logits':slot_logits}, 290 | # 'intent_logits':intent_logits, 'entity_type_ids_matrix':entity_type_ids_matrix}, 291 | scaffold_fn=scaffold_fn) 292 | return output_spec 293 | 294 | return model_fn 295 | -------------------------------------------------------------------------------- /predict_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
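# A minimal sketch of the entity-feature fusion implemented in _create_model() of
# run_lasertagger_utils.py above: dictionary-matched entity-type indicators are scaled
# by the learned impact_weight, mixed through the entity_type_embedding table, and added
# token-wise to the BERT sequence output before the BiLSTM slot tagger. The helper below
# is a hypothetical NumPy restatement of that step (names, shapes and the example weights
# are assumptions); it is illustrative only and is never called in this repository.
def _entity_fusion_sketch(bert_out, entity_type_ids, entity_type_emb, impact_weight):
    """bert_out: [B, T, H]; entity_type_ids: [B, T, K] multi-hot; entity_type_emb: [K, H]; impact_weight: [1, K]."""
    import numpy as np  # local import so this sketch has no effect unless it is called
    batch, seq_len, hidden = bert_out.shape
    k = entity_type_emb.shape[0]
    # Scale each matched entity type by its impact weight, then mix the type embeddings.
    weighted_ids = entity_type_ids.reshape(-1, k).astype(np.float64) * impact_weight
    entity_emb = weighted_ids @ entity_type_emb                      # [B*T, H]
    return bert_out + entity_emb.reshape(batch, seq_len, hidden)     # fused encoder output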
15 |
16 | # Lint as: python3
17 | """Compute realized predictions for a dataset."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | from absl import app
24 | from absl import flags
25 | from absl import logging
26 | import math, time
27 | from termcolor import colored
28 | import tensorflow as tf
29 |
30 | import bert_example
31 | import predict_utils
32 | import csv
33 | import utils
34 | from find_entity import exacter_acmation
35 | from curLine_file import curLine, normal_transformer, other_tag
36 |
37 | FLAGS = flags.FLAGS
38 | flags.DEFINE_string(
39 | 'input_file', None,
40 | 'Path to the input file containing examples for which to compute '
41 | 'predictions.')
42 | flags.DEFINE_enum(
43 | 'input_format', None, ['wikisplit', 'discofuse','qa', 'nlu'],
44 | 'Format which indicates how to parse the input_file.')
45 | flags.DEFINE_string(
46 | 'output_file', None,
47 | 'Path to the CSV file where the predictions are written to.')
48 | flags.DEFINE_string(
49 | 'submit_file', None,
50 | 'Path to the CSV file where the submission results are written to.')
51 | flags.DEFINE_string(
52 | 'label_map_file', None,
53 | 'Path to the label map file. Either a JSON file ending with ".json", that '
54 | 'maps each possible tag to an ID, or a text file that has one tag per line.')
55 | flags.DEFINE_string(
56 | 'slot_label_map_file', None,
57 | 'Path to the slot label map file. Either a JSON file ending with ".json", that '
58 | 'maps each possible slot tag to an ID, or a text file that has one tag per line.')
59 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
60 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
61 | flags.DEFINE_bool(
62 | 'do_lower_case', False,
63 | 'Whether to lower case the input text. Should be True for uncased '
64 | 'models and False for cased models.')
65 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.')
66 | flags.DEFINE_string(
67 | 'domain_name', None,
68 | 'Name of the target domain handled by this model; one of '
69 | '"navigation", "phone_call" or "music".')
70 | flags.DEFINE_string(
71 | 'entity_type_list_file', None, 'Path to the entity type list file.')
72 |
73 | def main(argv):
74 | if len(argv) > 1:
75 | raise app.UsageError('Too many command-line arguments.')
76 | flags.mark_flag_as_required('input_file')
77 | flags.mark_flag_as_required('input_format')
78 | flags.mark_flag_as_required('output_file')
79 | flags.mark_flag_as_required('label_map_file')
80 | flags.mark_flag_as_required('vocab_file')
81 | flags.mark_flag_as_required('saved_model')
82 | label_map = utils.read_label_map(FLAGS.label_map_file)
83 | slot_label_map = utils.read_label_map(FLAGS.slot_label_map_file)
84 | target_domain_name = FLAGS.domain_name
85 | print(curLine(), "target_domain_name:", target_domain_name)
86 | assert target_domain_name in ["navigation", "phone_call", "music"]
87 | entity_type_list = utils.read_label_map(FLAGS.entity_type_list_file)[FLAGS.domain_name]
88 |
89 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
90 | FLAGS.max_seq_length,
91 | FLAGS.do_lower_case, slot_label_map=slot_label_map,
92 | entity_type_list=entity_type_list, get_entity_func=exacter_acmation.get_all_entity)
93 | predictor = predict_utils.LaserTaggerPredictor(
94 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
95 | label_map, slot_label_map, target_domain_name=target_domain_name)
96 | print(colored("%s saved_model:%s" % (curLine(), FLAGS.saved_model), "red"))
97 |
98 | ##### test
99 | print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
100 |
101 |
102 | domain_list = []
103 | slot_info_list = []
104 | intent_list = []
105 |
106 | predict_domain_list = []
107 | previous_pred_slot_list = []
108 | previous_pred_intent_list = []
109 | sources_list = []
110 | predict_batch_size = 64
111 | limit = predict_batch_size * 1500 # 5184 # 10001 #
112 | with tf.gfile.GFile(FLAGS.input_file) as f:
113 | reader = csv.reader(f)
114 | session_list = []
115 | for row_id, line in enumerate(reader):
116 | if len(line) == 1:
117 | line = line[0].strip().split("\t")
118 | if len(line) > 4: # gold annotations are present
119 | (sessionId, raw_query, predDomain, predIntent, predSlot, domain, intent, slot) = line
120 | domain_list.append(domain)
121 | intent_list.append(intent)
122 | slot_info_list.append(slot)
123 | else:
124 | (sessionId, raw_query, predDomainIntent, predSlot) = line
125 | if "."
in predDomainIntent: 126 | predDomain,predIntent = predDomainIntent.split(".") 127 | else: 128 | predDomain,predIntent = predDomainIntent, predDomainIntent 129 | if "忘记电话" in raw_query: 130 | predDomain = "phone_call" # rule 131 | if "专用道" in raw_query: 132 | predDomain = "navigation" # rule 133 | predict_domain_list.append(predDomain) 134 | previous_pred_slot_list.append(predSlot) 135 | previous_pred_intent_list.append(predIntent) 136 | query = normal_transformer(raw_query) 137 | if query != raw_query: 138 | print(curLine(), len(query), "query: ", query) 139 | print(curLine(), len(raw_query), "raw_query:", raw_query) 140 | 141 | sources = [] 142 | if row_id > 0 and sessionId == session_list[row_id - 1][0]: 143 | sources.append(session_list[row_id - 1][1]) # last query 144 | sources.append(query) 145 | session_list.append((sessionId, raw_query)) 146 | sources_list.append(sources) 147 | 148 | if len(sources_list) >= limit: 149 | print(colored("%s stop reading at %d to save time" %(curLine(), limit), "red")) 150 | break 151 | 152 | number = len(sources_list) # 总样本数 153 | 154 | predict_intent_list = [] 155 | predict_slot_list = [] 156 | predict_batch_size = min(predict_batch_size, number) 157 | batch_num = math.ceil(float(number) / predict_batch_size) 158 | start_time = time.time() 159 | num_predicted = 0 160 | modemode = 'a' 161 | if len(domain_list) > 0: # 有标注 162 | modemode = 'w' 163 | with tf.gfile.Open(FLAGS.output_file, modemode) as writer: 164 | # if len(domain_list) > 0: # 有标注 165 | # writer.write("\t".join(["sessionId", "query", "predDomain", "predIntent", "predSlot", "domain", "intent", "Slot"]) + "\n") 166 | for batch_id in range(batch_num): 167 | # if batch_id <= 48: 168 | # continue 169 | sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size] 170 | predict_domain_batch = predict_domain_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size] 171 | predict_intent_batch, predict_slot_batch = predictor.predict_batch(sources_batch=sources_batch, target_domain_name=target_domain_name, predict_domain_batch=predict_domain_batch) 172 | assert len(predict_intent_batch) == len(sources_batch) 173 | num_predicted += len(predict_intent_batch) 174 | for id, [predict_intent, predict_slot_info, sources] in enumerate(zip(predict_intent_batch, predict_slot_batch, sources_batch)): 175 | sessionId, raw_query = session_list[batch_id * predict_batch_size + id] 176 | predict_domain = predict_domain_list[batch_id * predict_batch_size + id] 177 | # if predict_domain == "music": 178 | # predict_slot_info = raw_query 179 | # if predict_intent == "play": # 模型分类到播放意图,但没有找到槽位,这时用ac自动机提高召回 180 | # predict_intent_rule, predict_slot_info = rules(raw_query, predict_domain, target_domain_name) 181 | # # if predict_intent_rule in {"pause", "next"}: 182 | # # predict_intent = predict_intent_rule 183 | # if "<" in predict_slot_info_rule : # and "<" not in predict_slot_info: 184 | # predict_slot_info = predict_slot_info_rule 185 | # print(curLine(), "predict_slot_info_rule:", predict_slot_info_rule) 186 | # print(curLine()) 187 | 188 | if predict_domain != target_domain_name: # 不是当前模型的domain,用规则识别 189 | predict_intent = previous_pred_intent_list[batch_id * predict_batch_size + id] 190 | predict_slot_info = previous_pred_slot_list[batch_id * predict_batch_size + id] 191 | # else: 192 | # print(curLine(), predict_intent, "predict_slot_info:", predict_slot_info) 193 | predict_intent_list.append(predict_intent) 194 | predict_slot_list.append(predict_slot_info) 195 
| if len(domain_list) > 0: # 有标注 196 | domain = domain_list[batch_id * predict_batch_size + id] 197 | intent = intent_list[batch_id * predict_batch_size + id] 198 | slot = slot_info_list[batch_id * predict_batch_size + id] 199 | domain_flag = "right" 200 | if domain != predict_domain: 201 | domain_flag = "wrong" 202 | writer.write("\t".join([sessionId, raw_query, predict_domain, predict_intent, predict_slot_info, domain, intent, slot]) + "\n") # , domain_flag 203 | if batch_id % 5 == 0: 204 | cost_time = (time.time() - start_time) / 60.0 205 | print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." % 206 | (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time)) 207 | cost_time = (time.time() - start_time) / 60.0 208 | print( 209 | f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.') 210 | 211 | 212 | if FLAGS.submit_file is not None: 213 | import collections, os 214 | domain_counter = collections.Counter() 215 | if os.path.exists(path=FLAGS.submit_file): 216 | os.remove(FLAGS.submit_file) 217 | with open(FLAGS.submit_file, 'w',encoding='UTF-8') as f: 218 | writer = csv.writer(f, dialect='excel') 219 | # writer.writerow(["session_id", "query", "intent", "slot_annotation"]) # TODO 220 | for example_id, sources in enumerate(sources_list): 221 | sessionId, raw_query = session_list[example_id] 222 | predict_domain = predict_domain_list[example_id] 223 | predict_intent = predict_intent_list[example_id] 224 | predict_domain_intent = other_tag 225 | domain_counter.update([predict_domain]) 226 | slot = raw_query 227 | if predict_domain != other_tag: 228 | predict_domain_intent = "%s.%s" % (predict_domain, predict_intent) 229 | slot = predict_slot_list[example_id] 230 | # if predict_domain == "navigation": # TODO TODO 231 | # predict_domain_intent = other_tag 232 | # slot = raw_query 233 | line = [sessionId, raw_query, predict_domain_intent, slot] 234 | writer.writerow(line) 235 | print(curLine(), "example_id=", example_id) 236 | print(curLine(), "domain_counter:", domain_counter) 237 | cost_time = (time.time() - start_time) / 60.0 238 | num_predicted = example_id+1 239 | print(curLine(), "%s cost %f s" % (target_domain_name, cost_time)) 240 | print( 241 | f'{curLine()} {num_predicted} predictions saved to:{FLAGS.submit_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.') 242 | 243 | 244 | 245 | def rules(raw_query, predict_domain, target_domain_name): 246 | predict_intent = predict_domain # OTHERS 247 | slot_info = raw_query 248 | if predict_domain == "navigation": 249 | predict_intent = 'navigation' 250 | if "打开" in raw_query: 251 | predict_intent = "open" 252 | elif "开始" in raw_query: 253 | predict_intent = "start_navigation" 254 | for word in predict_utils.cancel_keywords: 255 | if word in raw_query: 256 | predict_intent = 'cancel_navigation' 257 | break 258 | # slot_info = raw_query 259 | # if predict_intent == 'navigation': TODO 260 | slot_info = exacter_acmation.get_slot_info(raw_query, domain=predict_domain) 261 | # if predict_intent != 'navigation': # TODO 262 | # print(curLine(), "slot_info:", slot_info) 263 | elif predict_domain == 'music': 264 | predict_intent = 'play' 265 | for word in predict_utils.cancel_keywords: 266 | if word in raw_query: 267 | predict_intent = 'pause' 268 | break 269 | for word in ["下一", "换一首", "换一曲", "切歌", "其他歌"]: 270 | if word in raw_query: 271 | predict_intent = 'next' 272 | break 273 | slot_info = exacter_acmation.get_slot_info(raw_query, 
domain=predict_domain) 274 | if predict_intent not in ['play','pause'] and slot_info != raw_query: # 根据槽位修改意图  换一首高安红尘情歌 275 | print(curLine(), predict_intent, slot_info) 276 | predict_intent = 'play' 277 | # if predict_intent != 'play': # 换一首高安红尘情歌 278 | # print(curLine(), predict_intent, slot_info) 279 | elif predict_domain == 'phone_call': 280 | predict_intent = 'make_a_phone_call' 281 | for word in predict_utils.cancel_keywords: 282 | if word in raw_query: 283 | predict_intent = 'cancel' 284 | break 285 | slot_info = exacter_acmation.get_slot_info(raw_query, domain=predict_domain) 286 | return predict_intent, slot_info 287 | 288 | if __name__ == '__main__': 289 | app.run(main) 290 | -------------------------------------------------------------------------------- /confusion_words_danyin.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 混淆词提取算法 不考虑多音字,只考虑单音,故只处理同音字的情况 3 | from pypinyin import lazy_pinyin, pinyin, Style 4 | from Levenshtein import distance # python-Levenshtein 5 | # 编辑距离 (Levenshtein Distance算法) 6 | import copy 7 | import sys 8 | import os, json 9 | from find_entity.acmation import entity_files_folder, entity_folder 10 | from curLine_file import curLine 11 | 12 | number_map = {"0":"十", "1":"一", "2":"二", "3":"三", "4":"四", "5":"五", "6":"六", "7":"七", "8":"八", "9":"九"} 13 | cons_table = {'n': 'l', 'l': 'n', 'f': 'h', 'h': 'f', 'zh': 'z', 'z': 'zh', 'c': 'ch', 'ch': 'c', 's': 'sh', 'sh': 's'} 14 | vowe_table = {'ai': 'ei', 'an': 'ang', 'en': 'eng', 'in': 'ing', 15 | 'ei': 'ai', 'ang': 'an', 'eng': 'en', 'ing': 'in'} 16 | stop_chars = ["的", "了", "啦", "着","们", "呀", "啊", "地", "呢", ".", ",", ",", "。", "-", "_,", "(", ")", "(", "在", "儿"] # TODO 儿 17 | 18 | # 非标准拼音 http://www.pinyin.info/rules/initials_finals.html 19 | w_pinyin_map = {"wu":"u", "wa":"ua", "wo":"uo", "wai":"uai", "wei":"ui", "wan":"uan", "wang":"uang", "wen":"un", "weng":"ueng"} 20 | y_pinyin_map = {"yi":"i", "ya":"ia", "ye":"ie", "yao":"iao", "you":"iu", "yan":"ian", "yang":"iang", "yin":"in","ying":"ing", "yong":"iong"} 21 | def my_lazy_pinyin(word, style=Style.NORMAL, errors='default', strict=True): 22 | rtn_list = lazy_pinyin(word, style=style, errors=errors, strict=strict) 23 | # 非标准拼音 24 | if style == Style.NORMAL: # 返回整体拼音的列表 25 | for id, total_pinyin in enumerate(rtn_list): 26 | if total_pinyin[0]=="w" and total_pinyin in w_pinyin_map: 27 | rtn_list[id] = w_pinyin_map[total_pinyin] 28 | elif total_pinyin[0]=="y" and total_pinyin in y_pinyin_map: 29 | rtn_list[id] = y_pinyin_map[total_pinyin] 30 | elif style == Style.INITIALS: # 只返回声母的列表 31 | total_pinyin_list = lazy_pinyin(word, style=Style.NORMAL, errors=errors, strict=strict) 32 | for id, (shengmu, total_pinyin) in enumerate(zip(rtn_list, total_pinyin_list)): 33 | if total_pinyin[0]=="w" and total_pinyin in w_pinyin_map: 34 | rtn_list[id] = "" 35 | elif total_pinyin[0]=="y" and total_pinyin in y_pinyin_map: 36 | rtn_list[id] = "" 37 | else: 38 | raise Exception("wrong style") 39 | return rtn_list 40 | 41 | 42 | def add_func(entity_info_dict, entity_type, raw_entity, entity_variation, char_distance): 43 | shengmu_list, yunmu_list, pinyin_list = my_pinyin(word=entity_variation) 44 | assert len(shengmu_list) == len(yunmu_list) == len(pinyin_list) 45 | combination = "".join(pinyin_list) 46 | # if len(combination) < 3: 47 | # print(curLine(), "warning:", raw_entity, "combination:", combination) 48 | entity_info_dict[raw_entity]["combination_list"].append([combination, 0, char_distance, 
entity_variation]) 49 | base_distance = 1.0 # 修改一个声母或韵母导致的距离,小于1 50 | for shengmu_index, shengmu in enumerate(shengmu_list): 51 | if shengmu not in cons_table: # 没有混淆对 52 | continue 53 | shengmu_list_variation = copy.copy(shengmu_list) 54 | shengmu_list_variation[shengmu_index] = cons_table[shengmu] # 修改一个声母 55 | combination_variation = "".join( 56 | ["%s%s" % (shengmu, yunmu) for shengmu, yunmu in zip(shengmu_list_variation, yunmu_list)]) 57 | entity_info_dict[raw_entity]["combination_list"].append([combination_variation, base_distance, char_distance, entity_variation]) 58 | if entity_type in ["song"]: 59 | for yunmu_index, yunmu in enumerate(yunmu_list): 60 | if yunmu not in vowe_table: # 没有混淆对 61 | continue 62 | yunmu_list_variation = copy.copy(yunmu_list) 63 | yunmu_list_variation[yunmu_index] = vowe_table[yunmu] # 修改一个韵母 64 | combination_variation = "".join( 65 | ["%s%s" % (shengmu, yunmu) for shengmu, yunmu in zip(shengmu_list, yunmu_list_variation)]) 66 | entity_info_dict[raw_entity]["combination_list"].append([combination_variation, base_distance, char_distance, entity_variation]) 67 | return 68 | 69 | def add_pinyin(raw_entity, entity_info_dict, pri, entity_type): 70 | entity = raw_entity 71 | for k, v in number_map.items(): 72 | entity = entity.replace(k, v) 73 | if raw_entity not in entity_info_dict: # 新的实体 74 | entity_info_dict[raw_entity] = {"priority": pri, "combination_list": []} 75 | add_func(entity_info_dict, entity_type, raw_entity, entity_variation=entity, char_distance=0) 76 | for char in stop_chars: 77 | if char in entity: 78 | entity_variation = entity.replace(char, "") 79 | char_distance = len(entity) - len(entity_variation) 80 | add_func(entity_info_dict, entity_type, raw_entity, entity_variation, char_distance=char_distance) 81 | else: 82 | old_pri = entity_info_dict[raw_entity]["priority"] 83 | if pri > old_pri: 84 | entity_info_dict[raw_entity]["priority"] = pri 85 | return 86 | 87 | def get_entityType_pinyin(entity_type): 88 | entity_info_dict = {} 89 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type) 90 | with open(entity_file, "r") as fr: 91 | lines = fr.readlines() 92 | priority = 3 93 | if entity_type in ["song"]: 94 | priority -= 0.5 95 | print(curLine(), "get %d %s from %s, priority=%f" % (len(lines), entity_type, entity_file, priority)) 96 | for line in lines: 97 | raw_entity = line.strip() 98 | add_pinyin(raw_entity, entity_info_dict, priority, entity_type) 99 | 100 | ### TODO 从标注语料中挖掘得到 101 | entity_file = os.path.join(entity_files_folder, "%s.json" % entity_type) 102 | with open(entity_file, "r") as fr: 103 | current_entity_dict = json.load(fr) 104 | print(curLine(), "get %d %s from %s, priority=%f" % (len(current_entity_dict), entity_type, entity_file, priority)) 105 | for entity_before, entity_after_times in current_entity_dict.items(): 106 | entity_after = entity_after_times[0] 107 | priority = 4 108 | if entity_type in ["song"]: 109 | priority -= 0.5 110 | add_pinyin(entity_after, entity_info_dict, priority, entity_type) 111 | return entity_info_dict 112 | 113 | # 用编辑距离度量拼音字符串之间的相似度 不考虑相似实体的多音情况 114 | def pinyin_similar_word_danyin(entity_info_dict, word, jichu_distance=0.05, char_ratio=0.55, char_distance=1.0): 115 | if word in entity_info_dict: # 存在实体,无需纠错 116 | return 1.0, word 117 | best_similar_word = None 118 | top_similar_score = 0 119 | if "0" in word: # 这里先尝试替换成零,后面会尝试替换成十的情况 120 | word_ling = word.replace("0", "零") 121 | top_similar_score_ling, best_similar_word_ling = pinyin_similar_word_danyin( 122 | entity_info_dict, 
word_ling, jichu_distance, char_ratio, char_distance) 123 | if top_similar_score_ling >= 0.99999: 124 | return top_similar_score_ling, best_similar_word_ling 125 | elif top_similar_score_ling >= top_similar_score: 126 | top_similar_score = top_similar_score_ling 127 | best_similar_word = best_similar_word_ling 128 | for k,v in number_map.items(): 129 | word = word.replace(k, v) 130 | if word in entity_info_dict: # 存在实体,无需纠错 131 | return 1.0, word 132 | all_combination = [["".join(my_lazy_pinyin(word)), 0, 0, word]] 133 | for char in stop_chars: 134 | if char in word: 135 | entity_variation = word.replace(char, "") 136 | c = (len(word) - len(entity_variation)) *char_distance # TODO 137 | all_combination.append( ["".join(my_lazy_pinyin(entity_variation)), 0, c, entity_variation] ) 138 | 139 | for current_combination, basebase, c1, entity_variation in all_combination: # 当前的各种发音 140 | if len(current_combination) == 0: 141 | continue 142 | similar_word = None 143 | current_distance = sys.maxsize 144 | for entity, priority_comList in entity_info_dict.items(): 145 | priority = priority_comList["priority"] 146 | com_list = priority_comList["combination_list"] 147 | for com, a, c, entity_variation_ku in com_list: 148 | d = basebase + jichu_distance*a + char_distance*c+c1 - priority*0.01 + \ 149 | distance(com, current_combination)*(1.0-char_ratio) + distance(entity_variation_ku,entity_variation) * char_ratio 150 | if d < current_distance: 151 | current_distance = d 152 | similar_word = entity 153 | current_similar_score = 1.0 - float(current_distance) / len(current_combination) 154 | if current_similar_score > top_similar_score: 155 | best_similar_word = similar_word 156 | top_similar_score = current_similar_score 157 | return top_similar_score, best_similar_word 158 | 159 | # 自己包装的函数,返回字符的声母(可能为空,如啊呀),韵母,整体拼音 160 | def my_pinyin(word): 161 | shengmu_list = my_lazy_pinyin(word, style=Style.INITIALS, strict=False) 162 | total_pinyin_list = my_lazy_pinyin(word, errors='default') 163 | yunmu_list = [] 164 | for shengmu, total_pinyin in zip(shengmu_list, total_pinyin_list): 165 | yunmu = total_pinyin[total_pinyin.index(shengmu) + len(shengmu):] 166 | yunmu_list.append(yunmu) 167 | return shengmu_list, yunmu_list, total_pinyin_list 168 | 169 | 170 | singer_pinyin = get_entityType_pinyin(entity_type="singer") 171 | print(curLine(), len(singer_pinyin), "singer_pinyin") 172 | 173 | song_pinyin = get_entityType_pinyin(entity_type="song") 174 | print(curLine(), len(song_pinyin), "song_pinyin") 175 | 176 | def correct_song(entity_before, jichu_distance=0.001, char_ratio=0.53, char_distance=0.1): # char_ratio=0.48): 177 | top_similar_score, best_similar_word = pinyin_similar_word_danyin( 178 | song_pinyin, entity_before, jichu_distance, char_ratio, char_distance) 179 | return top_similar_score, best_similar_word 180 | 181 | def correct_singer(entity_before, jichu_distance=0.001, char_ratio=0.5, char_distance=0.1): # char_ratio=0.1, char_distance=1.0): 182 | top_similar_score, best_similar_word = pinyin_similar_word_danyin( 183 | singer_pinyin, entity_before, jichu_distance, char_ratio, char_distance) 184 | return top_similar_score, best_similar_word 185 | 186 | 187 | 188 | def test(entity_type, pinyin_ku, score_threshold, jichu_distance, char_ratio=0.0, char_distance=0.1): # 测试纠错的准确率 189 | # 把给的实体库加到集合中 190 | entity_set = set() 191 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type) 192 | with open(entity_file, "r") as fr: 193 | lines = fr.readlines() 194 | for line in lines: 195 | raw_entity = 
line.strip() 196 | entity_set.add(raw_entity) 197 | # 抽取的实体作为测试集合 198 | entity_file = os.path.join(entity_files_folder, "%s.json" % entity_type) 199 | with open(entity_file, "r") as fr: 200 | current_entity_dict = json.load(fr) 201 | test_num = 0.0 202 | wrong_num = 0.0 203 | right_num = 0.0 204 | for entity_before, entity_after_times in current_entity_dict.items(): 205 | entity_after = entity_after_times[0] 206 | test_num += 1 207 | top_similar_score, best_similar_word = pinyin_similar_word_danyin(pinyin_ku, entity_before, jichu_distance, char_ratio, char_distance) 208 | predict_after = entity_before 209 | if top_similar_score > score_threshold: 210 | predict_after = best_similar_word 211 | if predict_after != entity_after: 212 | wrong_num += 1 213 | elif entity_before != predict_after: # 纠正对了 214 | right_num += 1 215 | return wrong_num, test_num, right_num 216 | 217 | 218 | if __name__ == "__main__": 219 | for word in ["霍元甲", "梁梁"]: 220 | a = correct_song(entity_before=word, jichu_distance=1.0, char_ratio=0.48) 221 | print(curLine(), word, a) 222 | for word in ["霍元甲", "m c天佑的歌曲", "v.a"]: 223 | a = correct_singer(entity_before=word, jichu_distance=1.0, char_ratio=0.1) 224 | print(curLine(), word, a) 225 | 226 | s = "啊饿恩昂你为什*.\"123c单身" 227 | s = "啊饿恩昂你为什单身的万瓦呢?" 228 | s = "打扰一样呆大衣虽四有挖" 229 | strict = False 230 | # yunmu_list = my_lazy_pinyin(s, style=Style.FINALS, strict=strict) # 根据声母(拼音的前半部分)得到韵母 231 | shengmu_list = my_lazy_pinyin(s, style=Style.INITIALS, strict=strict) # 返回声母,但是严格按照标准(strict=True)有些字没有声母会返回空字符串,strict=False时会返回 232 | pinyin_list = my_lazy_pinyin(s, errors='default') # ""ignore") 233 | print(curLine(), len(shengmu_list), "shengmu_list:", shengmu_list) 234 | # print(curLine(), len(yunmu_list), "yunmu_list:", yunmu_list) 235 | print(curLine(), len(pinyin_list), "pinyin_list:", pinyin_list) 236 | print(curLine(), my_pinyin(s)) 237 | for id, c in enumerate(s): 238 | print(curLine(), id, c, my_pinyin(c)) 239 | 240 | 241 | 242 | best_score = -sys.maxsize 243 | pinyin_ku = None 244 | 245 | 246 | 247 | # # # # song 248 | entity_type = "song" 249 | if entity_type == "song": 250 | pinyin_ku = song_pinyin 251 | elif entity_type == "singer": 252 | pinyin_ku = singer_pinyin 253 | for char_distance in [0.1]: 254 | for char_ratio in [0.4, 0.53, 0.55, 0.68]: # 255 | for score_threshold in [0.45, 0.65]: # 256 | for jichu_distance in [0.001]: # [0.005, 0.01, 0.03, 0.05]:# 257 | wrong_num, test_num, right_num = test(entity_type, pinyin_ku, score_threshold, jichu_distance, char_ratio, char_distance=char_distance) 258 | score = 0.01*right_num - wrong_num 259 | print("char_distance=%f, threshold=%f, jichu_distance=%f, char_ratio=%f, wrong_num=%d, right_num=%d, score=%f" % 260 | (char_distance, score_threshold, jichu_distance, char_ratio, wrong_num, right_num, score)) 261 | if score > best_score: 262 | print(curLine(), "score=%f, best_score=%f" % (score, best_score)) 263 | best_score = score 264 | 265 | # # singer 266 | # entity_type = "singer" #  267 | # if entity_type == "song": 268 | # pinyin_ku = song_pinyin 269 | # elif entity_type == "singer": 270 | # pinyin_ku = singer_pinyin 271 | # for char_distance in [0.05, 0.1]: 272 | # for char_ratio in [0.5]: 273 | # for score_threshold in [0.45, 0.65]: 274 | # for jichu_distance in [0.001]: # [0.005, 0.01, 0.03, 0.05]:# 275 | # wrong_num, test_num, right_num = test(entity_type, pinyin_ku, score_threshold, jichu_distance, char_ratio, char_distance) 276 | # score = 0.01*right_num - wrong_num 277 | # print("char_distance=%f, threshold=%f, 
jichu_distance=%f, char_ratio=%f, wrong_num=%d, right_num=%d, score=%f" % 278 | # (char_distance, score_threshold, jichu_distance, char_ratio, wrong_num, right_num, score)) 279 | # if score > best_score: 280 | # print(curLine(), "score=%f, best_score=%f" % (score, best_score)) 281 | # best_score = score 282 | 283 | --------------------------------------------------------------------------------
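The pinyin-based corrector in confusion_words_danyin.py above scores each candidate entity by blending an edit distance over toneless pinyin strings with an edit distance over the raw characters, subtracting a small bonus for higher-priority entities, and normalising by the length of the query's pinyin; the real function additionally tries variants with stop characters removed and with swapped confusable initials/finals. A minimal sketch of that scoring idea, using the same libraries the file already imports (the weights below are illustrative, not the tuned char_ratio / jichu_distance / char_distance values used above):

from Levenshtein import distance   # python-Levenshtein, as imported in confusion_words_danyin.py
from pypinyin import lazy_pinyin

def similarity(query, candidate, char_ratio=0.5, priority=3.0):
    """Higher is more similar; mirrors the shape of pinyin_similar_word_danyin()."""
    query_pinyin = "".join(lazy_pinyin(query))
    candidate_pinyin = "".join(lazy_pinyin(candidate))
    d = (1.0 - char_ratio) * distance(query_pinyin, candidate_pinyin) \
        + char_ratio * distance(query, candidate) \
        - 0.01 * priority  # entities from higher-priority sources get a small head start
    return 1.0 - d / max(len(query_pinyin), 1)

# Homophone confusions score highly because their pinyin strings are identical even
# though the characters differ; the caller accepts the best candidate only when the
# score clears a threshold (score_threshold in the grid search above).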