├── find_entity
│   ├── __init__.py
│   ├── probable_acmation.py
│   ├── exacter_acmation.py
│   └── acmation.py
├── __init__.py
├── official_transformer
│   ├── __init__.py
│   ├── README
│   ├── model_params.py
│   ├── ffn_layer.py
│   ├── embedding_layer.py
│   ├── model_utils.py
│   ├── tpu.py
│   └── attention_layer.py
├── entity_files
│   ├── slot-dictionaries
│   │   ├── custom_destination.txt
│   │   ├── instrument.txt
│   │   ├── language.txt
│   │   ├── toplist.txt
│   │   ├── style.txt
│   │   ├── emotion.txt
│   │   ├── scene.txt
│   │   ├── theme.txt
│   │   └── age.txt
│   ├── custom_destination.json
│   ├── scene.json
│   ├── age.json
│   ├── instrument.json
│   ├── language.json
│   ├── toplist.json
│   ├── origin.json
│   ├── emotion.json
│   ├── theme.json
│   ├── style.json
│   └── frequentSinger.json
├── configs
│   ├── lasertagger_config.json
│   └── lasertagger_config-tiny.json
├── requirements.txt
├── README.md
├── curLine_file.py
├── .gitignore
├── score_main.py
├── utils.py
├── confusion_words_duoyin.py
├── split_corpus.py
├── start.sh
├── score_lib.py
├── preprocess_main.py
├── confusion_words.py
├── predict_param.py
├── predict_utils.py
├── sari_hook.py
├── LICENSE
├── run_lasertagger.py
├── run_lasertagger_utils.py
├── predict_main.py
└── confusion_words_danyin.py
/find_entity/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import compute_lcs
--------------------------------------------------------------------------------
/official_transformer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/entity_files/slot-dictionaries/custom_destination.txt:
--------------------------------------------------------------------------------
1 | 你家
2 | 公司
3 | 家
4 |
--------------------------------------------------------------------------------
/entity_files/custom_destination.json:
--------------------------------------------------------------------------------
1 | {
2 | "公司": [
3 | "公司",
4 | 28
5 | ],
6 | "家": [
7 | "家",
8 | 81
9 | ]
10 | }
--------------------------------------------------------------------------------
/entity_files/scene.json:
--------------------------------------------------------------------------------
1 | {
2 | "酒吧": [
3 | "酒吧",
4 | 3
5 | ],
6 | "咖啡厅": [
7 | "咖啡厅",
8 | 3
9 | ],
10 | "k歌": [
11 | "k歌",
12 | 1
13 | ]
14 | }
--------------------------------------------------------------------------------
/entity_files/age.json:
--------------------------------------------------------------------------------
1 | {
2 | "儿歌": [
3 | "儿歌",
4 | 36
5 | ],
6 | "二零一一": [
7 | "二零一一",
8 | 1
9 | ],
10 | "零零后": [
11 | "零零后",
12 | 1
13 | ]
14 | }
--------------------------------------------------------------------------------
/entity_files/instrument.json:
--------------------------------------------------------------------------------
1 | {
2 | "古筝": [
3 | "古筝",
4 | 5
5 | ],
6 | "锁那": [
7 | "唢呐",
8 | 2
9 | ],
10 | "唢呐": [
11 | "唢呐",
12 | 3
13 | ],
14 | "萨克斯": [
15 | "萨克斯",
16 | 4
17 | ]
18 | }
--------------------------------------------------------------------------------
/official_transformer/README:
--------------------------------------------------------------------------------
1 | This directory contains Transformer related code files. These are copied from:
2 | https://github.com/tensorflow/models/tree/master/official/transformer
3 | to make it possible to install all LaserTagger dependencies with pip (the
4 | official Transformer implementation doesn't support pip installation).
--------------------------------------------------------------------------------
/entity_files/slot-dictionaries/instrument.txt:
--------------------------------------------------------------------------------
1 | 钢琴曲
2 | 琵琶曲
3 | 吉他
4 | 管弦
5 | 吉他曲
6 | 管弦乐
7 | 缸琴
8 | 交响
9 | harmonica
10 | 葫芦丝
11 | 唢呐
12 | 竖琴
13 | piano
14 | 小提琴
15 | 笛子
16 | 二胡
17 | 钢琴
18 | 大提琴
19 | 吉它
20 | 风琴
21 | 手风琴
22 | 口琴
23 | 笛音
24 | guitar
25 | 琵琶
26 | 交响乐
27 | 古筝
28 | 萨克斯
29 | 刚琴
30 | 港琴
31 |
--------------------------------------------------------------------------------
/entity_files/slot-dictionaries/language.txt:
--------------------------------------------------------------------------------
1 | 日语
2 | 粤语
3 | 韩国
4 | 华语
5 | 那美克星语
6 | 德语
7 | 印度语
8 | 韩语
9 | 英文
10 | 巫族语
11 | 泰国
12 | 外语
13 | 普通话
14 | 葡萄牙
15 | 阿拉伯语
16 | 汉语
17 | 香港话
18 | 中国
19 | 欧美
20 | 法文
21 | 外国
22 | 泰语
23 | 法国
24 | 日本
25 | 中文
26 | 闽南语
27 | 因语
28 | 法语
29 | 小语种
30 | 国语
31 | 藏族
32 | 闽南
33 | 蒙语
34 | 欧美英语
35 | 日文
36 | 葡萄牙语
37 | 阴语
38 | 印度
39 | 英语
40 | 广东话
41 | 韩文
42 |
--------------------------------------------------------------------------------
/configs/lasertagger_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "hidden_dropout_prob": 0.1,
5 | "hidden_size": 768,
6 | "initializer_range": 0.02,
7 | "intermediate_size": 3072,
8 | "max_position_embeddings": 512,
9 | "num_attention_heads": 12,
10 | "num_hidden_layers": 11,
11 | "type_vocab_size": 2,
12 | "vocab_size": 21128,
13 | "decoder_num_hidden_layers": 1,
14 | "decoder_hidden_size": 768,
15 | "decoder_num_attention_heads": 4,
16 | "decoder_filter_size": 3072,
17 | "use_full_attention": false
18 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.15.2 # CPU Version of TensorFlow.
2 | # tensorflow-gpu==1.15.0 # GPU version of TensorFlow.
3 | absl-py==0.8.1
4 | astor==0.8.0
5 | bert-tensorflow==1.0.1
6 | gast==0.2.2
7 | google-pasta==0.1.7
8 | grpcio==1.24.3
9 | h5py==2.10.0
10 | Keras-Applications==1.0.8
11 | Keras-Preprocessing==1.1.0
12 | Markdown==3.1.1
13 | numpy==1.17.3
14 | pkg-resources==0.0.0
15 | protobuf==3.10.0
16 | scipy==1.3.1
17 | six==1.12.0
18 | tensorboard==1.15.0
19 | tensorflow-estimator==1.15.1
20 | termcolor==1.1.0
21 | Werkzeug==0.16.0
22 | wrapt==1.11.2
23 | flask
24 |
--------------------------------------------------------------------------------
/entity_files/language.json:
--------------------------------------------------------------------------------
1 | {
2 | "藏族": [
3 | "藏族",
4 | 1
5 | ],
6 | "华语": [
7 | "华语",
8 | 1
9 | ],
10 | "中文": [
11 | "中文",
12 | 7
13 | ],
14 | "国语": [
15 | "国语",
16 | 3
17 | ],
18 | "闽南语": [
19 | "闽南语",
20 | 4
21 | ],
22 | "粤语": [
23 | "粤语",
24 | 7
25 | ],
26 | "欧美": [
27 | "欧美",
28 | 1
29 | ],
30 | "闽南": [
31 | "闽南",
32 | 2
33 | ],
34 | "韩国": [
35 | "韩国",
36 | 1
37 | ]
38 | }
--------------------------------------------------------------------------------
/configs/lasertagger_config-tiny.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "hidden_dropout_prob": 0.1,
5 | "hidden_size": 312,
6 | "initializer_range": 0.02,
7 | "intermediate_size": 1248,
8 | "max_position_embeddings": 512,
9 | "num_attention_heads": 12,
10 | "num_hidden_layers": 4,
11 | "type_vocab_size": 2,
12 | "vocab_size": 8021,
13 | "use_t2t_decoder": true,
14 | "decoder_num_hidden_layers": 1,
15 | "decoder_hidden_size": 768,
16 | "decoder_num_attention_heads": 4,
17 | "decoder_filter_size": 3072,
18 | "use_full_attention": false
19 | }
20 |
--------------------------------------------------------------------------------
/entity_files/toplist.json:
--------------------------------------------------------------------------------
1 | {
2 | "流行": [
3 | "流行",
4 | 8
5 | ],
6 | "最新": [
7 | "最新",
8 | 7
9 | ],
10 | "好歌": [
11 | "好歌",
12 | 2
13 | ],
14 | "好听": [
15 | "好听",
16 | 5
17 | ],
18 | "网络": [
19 | "网络",
20 | 8
21 | ],
22 | "最近": [
23 | "最近",
24 | 1
25 | ],
26 | "最流行": [
27 | "最流行",
28 | 1
29 | ],
30 | "热门": [
31 | "热门",
32 | 2
33 | ],
34 | "热热热": [
35 | "热热热",
36 | 1
37 | ],
38 | "新歌": [
39 | "新歌",
40 | 2
41 | ],
42 | "港台歌曲": [
43 | "港台歌曲",
44 | 1
45 | ]
46 | }
--------------------------------------------------------------------------------
/entity_files/slot-dictionaries/toplist.txt:
--------------------------------------------------------------------------------
1 | 摩登
2 | 榜单
3 | 日韩
4 | 时髦
5 | 港台歌曲
6 | 风靡
7 | 第一首
8 | 出道时
9 | pop
10 | 热门
11 | 高品质音乐榜
12 | 最好听
13 | 最老
14 | 时兴
15 | 美国itunes榜
16 | vivo手机高品质音乐榜
17 | billboard
18 | 好听
19 | 怀旧华语
20 | 抢手
21 | 盛行
22 | qq音乐榜单
23 | 巅峰榜
24 | 风行
25 | 第一个
26 | 最老一个
27 | 最老一张
28 | 网络流行
29 | 网易音乐
30 | 第一张
31 | 新
32 | 最新
33 | 香港商台榜
34 | 大陆
35 | 受欢迎
36 | 韩国mnet榜
37 | 最受欢迎
38 | 榜首
39 | 日本公信榜
40 | 美国公告牌榜
41 | 热歌
42 | 最流行
43 | 流行
44 | 热热热
45 | 台湾幽浮榜
46 | 最近
47 | 动听
48 | 怀旧粤语
49 | 英国uk榜
50 | 新歌
51 | 播放最多
52 | 首张
53 | 时新
54 | 公告牌
55 | 香港电台榜
56 | 最火
57 | qq音乐
58 | 最热门
59 | 网络
60 | 火热
61 | 最热
62 | 网络热歌
63 | 云音乐
64 | 网易云音乐
65 | 怀旧英语
66 | 好歌
67 | 港台
68 | 好
69 | 内地
70 |
--------------------------------------------------------------------------------
/entity_files/origin.json:
--------------------------------------------------------------------------------
1 | {
2 | "江苏": [
3 | "江苏",
4 | 1
5 | ],
6 | "袁浦中学": [
7 | "袁浦中学",
8 | 1
9 | ],
10 | "购物公园": [
11 | "购物公园",
12 | 1
13 | ],
14 | "红安": [
15 | "红安",
16 | 1
17 | ],
18 | "海思厂": [
19 | "海思厂",
20 | 1
21 | ],
22 | "凯仕厂": [
23 | "凯仕厂",
24 | 1
25 | ],
26 | "万达花园": [
27 | "万达花园",
28 | 1
29 | ],
30 | "遵义": [
31 | "遵义",
32 | 1
33 | ],
34 | "石马巷": [
35 | "石马巷",
36 | 1
37 | ],
38 | "宾馆路": [
39 | "宾馆路",
40 | 1
41 | ],
42 | "海椒市": [
43 | "海椒市",
44 | 1
45 | ],
46 | "205国道": [
47 | "205国道",
48 | 1
49 | ]
50 | }
--------------------------------------------------------------------------------
/entity_files/emotion.json:
--------------------------------------------------------------------------------
1 | {
2 | "励志": [
3 | "励志",
4 | 1
5 | ],
6 | "伤感": [
7 | "伤感",
8 | 2
9 | ],
10 | "嗨曲": [
11 | "嗨曲",
12 | 1
13 | ],
14 | "怀旧": [
15 | "怀旧",
16 | 1
17 | ],
18 | "快歌": [
19 | "快歌",
20 | 3
21 | ],
22 | "寂寞": [
23 | "寂寞",
24 | 1
25 | ],
26 | "劲爆": [
27 | "劲爆",
28 | 6
29 | ],
30 | "柔情": [
31 | "柔情",
32 | 1
33 | ],
34 | "火爆": [
35 | "火爆",
36 | 1
37 | ],
38 | "温柔": [
39 | "温柔",
40 | 1
41 | ],
42 | "欢快": [
43 | "欢快",
44 | 4
45 | ],
46 | "抒情": [
47 | "抒情",
48 | 3
49 | ],
50 | "慢歌": [
51 | "慢歌",
52 | 1
53 | ],
54 | "猛烈": [
55 | "猛烈",
56 | 1
57 | ]
58 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Intent classification, slot filling, and slot-value correction for the music domain
2 | First place in the AI研习社 (AI Yanxishe) competition "Spoken Language Understanding in Dialogue Systems"
3 |
4 | 1. Intent classification and slot filling
5 | The model is RoBERTa + BiLSTM.
6 | An Aho-Corasick automaton finds candidate entities in the input sentence, and information about these candidate entities is merged into the input features, which improves the model's generalization ability.
7 | 2. Slot-value correction
8 | Some requests contain misspelled characters. When a misspelling falls inside a slot value, the backend service cannot find the corresponding song, even though the slot-filling model has recognized the value from the sentence pattern and semantics.
9 | The recognized slot values therefore need to be corrected using the entity dictionaries together with pinyin features.
10 |
11 | After analysis, the erroneous slot values in the corpus fall into the following cases (the text before and after "||" is the slot value before and after correction):
12 | 1) Homophones, e.g. 妈妈巴巴巴||妈妈爸爸; 江湖大盗||江湖大道; 摩托谣||摩托摇
13 | 2) Wrong characters: 苍天一声笑||沧海一声笑; 火锅调料||火锅底料
14 | 3) Wrong initial (声母): 体验||体面; 道别的爱给特别的你||特别的爱给特别的你
15 | 4) Wrong final (韵母): 唱之歌||蝉之歌; 说三就剩||说散就散; 没有你陪伴真的很孤单||没有你陪伴真的好孤单
16 | 5) Missing characters: 爱你千百回||一生爱你千百回; 九百九十朵玫瑰||九百九十九朵玫瑰
17 | 6) Inconsistent representation: 1234||一二三四
18 | 7) Stop characters or punctuation: 兄弟想你啦||兄弟想你了
19 | 8) Large differences: 我想带你去土耳其||带你去旅行
20 | Case 1) is corrected with the edit distance between pinyin strings; case 3) by replacing easily confused initial pairs after converting to pinyin;
21 | case 4) by replacing easily confused final pairs after converting to pinyin; case 6) by converting digits in entities or slot values to the corresponding Chinese characters;
22 | case 7) by removing stop characters from entities or slot values. The remaining cases 2), 5) and 8) are hard to correct; whether they get fixed depends on the pinyin edit distance and the chosen threshold (a minimal sketch of the pinyin-based correction is shown after this file).
23 |
24 | How to run the experiments:
25 | bash start.sh host_name
26 | Note: the experiments were run on Linux, and host_name is the user name.
27 | Please adjust the directory paths in start.sh to match your own environment.
28 |
29 | ## License
30 |
31 | Apache 2.0; see [LICENSE](LICENSE) for details.
32 |
--------------------------------------------------------------------------------
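
The pinyin-based correction described above can be condensed into a few lines. The sketch below only illustrates case 1), homophones; it is not the project's own pipeline (which lives in the confusion_words*.py scripts). It assumes the pypinyin and python-Levenshtein packages that those scripts import, and the helper names, threshold value, and song list are made up for the example.

```python
# Minimal sketch of pinyin edit-distance correction (case 1 above).
# Assumes pypinyin and python-Levenshtein are installed, as in confusion_words_duoyin.py;
# correct_slot, the threshold, and the song list are illustrative only.
from pypinyin import lazy_pinyin
from Levenshtein import distance


def to_pinyin(text):
    """Concatenate the toneless pinyin of each character."""
    return "".join(lazy_pinyin(text, errors="ignore"))


def correct_slot(slot_value, entity_list, threshold=0.7):
    """Return the entity whose pinyin is closest to the slot value, if similar enough."""
    if slot_value in entity_list:              # exact match: no correction needed
        return slot_value
    query_py = to_pinyin(slot_value)
    best_entity, best_score = None, 0.0
    for entity in entity_list:
        d = distance(query_py, to_pinyin(entity))
        score = 1.0 - float(d) / max(len(query_py), 1)
        if score > best_score:
            best_entity, best_score = entity, score
    return best_entity if best_score >= threshold else slot_value


if __name__ == "__main__":
    songs = ["江湖大道", "摩托摇", "沧海一声笑"]
    print(correct_slot("江湖大盗", songs))  # -> 江湖大道 (same pronunciation)
```
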
/entity_files/slot-dictionaries/style.txt:
--------------------------------------------------------------------------------
1 | b box
2 | 动感炫舞
3 | 含蓄
4 | 狂野
5 | 电子
6 | 文雅
7 | 乡间
8 | 拉丁
9 | 典雅
10 | 蓝草
11 | hip-hop
12 | 迪斯可
13 | 嘻哈风
14 | 嘻哈说唱
15 | 舞曲
16 | 风雅
17 | 斯卡
18 | 山歌
19 | 节奏布鲁斯
20 | 嘻哈
21 | 俚歌
22 | 大合唱
23 | 跳舞
24 | 广场舞
25 | 高雅
26 | 街舞
27 | 古风
28 | 男女合唱
29 | 大雅
30 | hip hop
31 | 民族风
32 | 高贵
33 | 精雅
34 | 摇滚乐
35 | 轻音乐
36 | 农村
37 | 舞蹈
38 | 英伦风
39 | 对唱
40 | 黑暗
41 | 电音
42 | 古老
43 | 乡村
44 | r&b
45 | 爵士
46 | 草原风
47 | 男女对唱
48 | 歌谣
49 | rapping
50 | 风谣
51 | 拉锯电音
52 | 摇滚
53 | 民谣
54 | 拉丁舞
55 | 劲歌
56 | 对歌
57 | 饶舌
58 | 后摇
59 | 朋克
60 | 串烧
61 | 雅致
62 | 狂放
63 | 民间
64 | 草原
65 | beatboxing
66 | 嘻嘻哈哈
67 | 巴萨诺瓦
68 | bossa nova
69 | 民俗
70 | 狂热
71 | 合唱
72 | 比波普
73 | 民族
74 | 爵士乐
75 | 蓝调
76 | 炫舞
77 | 新世纪
78 | 大草原
79 | 暗潮
80 | 金属
81 | 豪放
82 | 优雅
83 | rap
84 | 热舞
85 | 打碟
86 | 草地
87 | breaking
88 | 鬼步舞
89 | 陕北民歌
90 | 英伦
91 | 民歌
92 | 电子乐
93 | bebop
94 | 慢摇舞曲
95 | 现代蓝调
96 | 古典
97 | 淡雅
98 | 慢摇
99 | 乡下
100 | 说唱
101 | 布鲁斯
102 | 嘻嘻嘻哈哈哈
103 |
--------------------------------------------------------------------------------
/entity_files/theme.json:
--------------------------------------------------------------------------------
1 | {
2 | "dj": [
3 | "dj",
4 | 83
5 | ],
6 | "经典老歌": [
7 | "经典老歌",
8 | 6
9 | ],
10 | "车载": [
11 | "车载",
12 | 2
13 | ],
14 | "老歌": [
15 | "老歌",
16 | 3
17 | ],
18 | "情歌": [
19 | "情歌",
20 | 10
21 | ],
22 | "我是歌手": [
23 | "我是歌手",
24 | 3
25 | ],
26 | "经典": [
27 | "经典",
28 | 7
29 | ],
30 | "红歌": [
31 | "红歌",
32 | 4
33 | ],
34 | "网络热歌": [
35 | "网络热歌",
36 | 1
37 | ],
38 | "中国好声音": [
39 | "中国好声音",
40 | 2
41 | ],
42 | "纯音乐": [
43 | "纯音乐",
44 | 5
45 | ],
46 | "神曲": [
47 | "神曲",
48 | 2
49 | ],
50 | "基督教": [
51 | "基督教",
52 | 1
53 | ],
54 | "成名曲": [
55 | "成名曲",
56 | 2
57 | ],
58 | "喊麦": [
59 | "喊麦",
60 | 1
61 | ],
62 | "新年": [
63 | "新年",
64 | 2
65 | ],
66 | "过年": [
67 | "过年",
68 | 1
69 | ],
70 | "串烧": [
71 | "串烧",
72 | 3
73 | ],
74 | "伤感情歌": [
75 | "伤感情歌",
76 | 3
77 | ]
78 | }
--------------------------------------------------------------------------------
/entity_files/style.json:
--------------------------------------------------------------------------------
1 | {
2 | "串烧": [
3 | "串烧",
4 | 2
5 | ],
6 | "舞曲": [
7 | "舞曲",
8 | 11
9 | ],
10 | "轻音乐": [
11 | "轻音乐",
12 | 6
13 | ],
14 | "蓝调": [
15 | "蓝调",
16 | 2
17 | ],
18 | "乡村": [
19 | "乡村",
20 | 1
21 | ],
22 | "民歌": [
23 | "民歌",
24 | 3
25 | ],
26 | "摇滚": [
27 | "摇滚",
28 | 12
29 | ],
30 | "合唱": [
31 | "合唱",
32 | 2
33 | ],
34 | "慢摇": [
35 | "慢摇",
36 | 2
37 | ],
38 | "草原": [
39 | "草原",
40 | 3
41 | ],
42 | "鬼步舞": [
43 | "鬼步舞",
44 | 1
45 | ],
46 | "嘻哈": [
47 | "嘻哈",
48 | 2
49 | ],
50 | "电音": [
51 | "电音",
52 | 2
53 | ],
54 | "拉锯电音": [
55 | "拉锯电音",
56 | 1
57 | ],
58 | "古典": [
59 | "古典",
60 | 2
61 | ],
62 | "b box": [
63 | "b box",
64 | 1
65 | ],
66 | "嘻哈说唱": [
67 | "嘻哈说唱",
68 | 1
69 | ],
70 | "陕北民歌": [
71 | "陕北民歌",
72 | 3
73 | ],
74 | "民族": [
75 | "民族",
76 | 1
77 | ],
78 | "民谣": [
79 | "民谣",
80 | 1
81 | ]
82 | }
--------------------------------------------------------------------------------
/curLine_file.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import sys
3 | def curLine():
4 |     file_path = sys._getframe().f_back.f_code.co_filename  # path of the caller's source file
5 |     file_name = file_path[file_path.rfind("/") + 1:]  # name of the caller's source file
6 |     lineno = sys._getframe().f_back.f_lineno  # line number of the call site
7 |     location = "[%s:%s] " % (file_name, lineno)
8 |     return location
9 |
10 |
11 | stop_words = set(('\n', '\r', '\t', ' '))
12 | # Convert full-width characters to half-width
13 | def quanjiao2banjiao(query_lower):
14 |     rstring = list()
15 |     for uchar in query_lower:
16 |         # if uchar in stop_words:  # WordPiece drops spaces, which causes problems, so the current version ignores spaces; this can introduce errors
17 |         #     continue
18 |         inside_code = ord(uchar)
19 |         if inside_code == 12288:  # a full-width space maps directly to a half-width space
20 |             inside_code = 32
21 |         elif (inside_code >= 65281 and inside_code <= 65374):  # other full-width characters map by a fixed offset
22 |             inside_code -= 65248
23 |         uchar = chr(inside_code)
24 |         # if uchar not in stop_mark:
25 |         rstring.append(uchar)
26 |     return_str = "".join(rstring)
27 |     return return_str
28 | # Lower-case, full-width to half-width, strip punctuation; note that the training corpus and templates must be adjusted accordingly
29 | def normal_transformer(query):
30 |     query_lower = query  # .strip("’,?!,。?!�") #.lower()
31 |     rstring = quanjiao2banjiao(query_lower).strip()  # .replace("您","你")
32 |     return rstring
33 |
34 | other_tag = "OTHERS"
--------------------------------------------------------------------------------
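
A quick, hypothetical usage sketch of the helpers above; the printed file/line prefix and the example strings are illustrative, and the only assumption is that the repository root is importable.

```python
# Illustrative use of curLine() and normal_transformer(); example strings are arbitrary.
from curLine_file import curLine, normal_transformer

print(curLine(), "loading data")               # prints a "[file:line] " prefix, e.g. "[demo.py:4]  loading data"
print(normal_transformer("　你好，世界！　"))    # full-width punctuation/spaces -> "你好,世界!"
print(normal_transformer("ＡＢＣ１２３"))        # full-width letters/digits     -> "ABC123"
```
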
/entity_files/slot-dictionaries/emotion.txt:
--------------------------------------------------------------------------------
1 | 酷一点
2 | 休闲
3 | 励志
4 | 嚣张
5 | 幸福快乐
6 | 兴奋
7 | 清幽
8 | 疗伤
9 | 减压
10 | 治愈系
11 | 愉快
12 | 凄恻
13 | 忆旧
14 | 快活
15 | 欢乐
16 | 激动人心
17 | 欢悦
18 | 凄怆
19 | 打动人
20 | 快歌
21 | 心潮难平
22 | 哀戚
23 | 孤单
24 | 清新
25 | 安谧
26 | 激动不已
27 | 气盛
28 | 最high
29 | 清静
30 | 悲愁
31 | 浪漫
32 | 热血
33 | 喜悦
34 | 舒缓
35 | 性感
36 | 眷恋
37 | 悲伤
38 | 罗曼蒂克
39 | 孤孤单单
40 | 凄惶
41 | 舒服
42 | 柔情
43 | 开心
44 | 恋旧
45 | 慢一点
46 | 昂奋
47 | 凄然
48 | 忧伤
49 | 欢娱
50 | 分手
51 | 伤悲
52 | 猛歌
53 | 令人感动
54 | 幸福
55 | 慢歌
56 | 最嗨
57 | 温柔
58 | 静谧
59 | 宣泄
60 | 下雨天
61 | 喜欢
62 | high歌
63 | 嗨歌
64 | 相思
65 | 难过
66 | 孤单单
67 | 催人泪下
68 | 安静
69 | 静寂
70 | 欣悦
71 | 欢快
72 | 宁静
73 | 寂寞
74 | 抑郁自杀
75 | 抚平
76 | 动人
77 | 哀愁
78 | 激情澎湃
79 | 放松
80 | 表白
81 | 嗨曲
82 | 治愈心灵
83 | 甜蜜
84 | 欣喜
85 | 心痛
86 | 快乐
87 | 悲哀
88 | 称快
89 | 轻快
90 | 猛烈
91 | 动感
92 | 哀伤
93 | 悲怆
94 | 难受
95 | 正能量
96 | 伤感
97 | 治愈
98 | 欢喜
99 | 念旧
100 | 欢欢喜喜
101 | 愉悦
102 | 火爆
103 | 激情
104 | 心潮起伏
105 | 悲戚
106 | 劲爆
107 | 快快乐乐
108 | 感动
109 | 伤心
110 | 抑郁
111 | 奋斗
112 | 孤独
113 | 熬心
114 | 欢欣
115 | 美满
116 | 怀念
117 | 催人奋进
118 | 轻松
119 | 甜美
120 | 妖媚
121 | 思念
122 | 喜庆
123 | 肉麻
124 | 怀旧
125 | 怀古
126 | 嗨
127 | 激动
128 | 欢愉
129 | 清爽
130 | 抒情
131 | 激励
132 | 发泄
133 | 心潮澎湃
134 | 令人鼓舞
135 | 高兴
136 |
--------------------------------------------------------------------------------
/entity_files/slot-dictionaries/scene.txt:
--------------------------------------------------------------------------------
1 | 宅
2 | 做事
3 | 出车
4 | 夜晚
5 | 咖啡吧
6 | 午觉
7 | 早饭
8 | 睡懒觉
9 | 求婚
10 | 公车
11 | 喝咖啡
12 | 晚饭
13 | 清早
14 | 催眠
15 | 写作业
16 | 转转
17 | 小资
18 | 幽会
19 | 旅游
20 | 赖床
21 | 馆子
22 | 旅行
23 | 健身
24 | 校车
25 | 聚集
26 | 午饭
27 | 喜宴
28 | 跑步
29 | 发车
30 | 码代码
31 | 开车
32 | 婚礼
33 | 早餐
34 | 来一杯咖啡
35 | 夜总会
36 | 商场
37 | 团聚
38 | 饭庄
39 | 喜酒
40 | 晚上
41 | 深夜
42 | 下午茶
43 | 饭店
44 | 坐车
45 | 网吧
46 | 驾车
47 | 喜筵
48 | 好好工作
49 | 地铁
50 | 商店
51 | 睡前
52 | 回家
53 | 相聚
54 | 早上
55 | 餐馆
56 | 阴天
57 | 迪厅
58 | 午睡
59 | 咖啡店
60 | 夜店
61 | 酒家
62 | 公交车
63 | 小憩
64 | 喝酒
65 | 午餐
66 | 午休
67 | 上网
68 | k歌
69 | 学习
70 | 写书法
71 | 食堂
72 | 吃饭
73 | 下雨天
74 | 酒馆
75 | 午眠
76 | 店铺
77 | 饭铺
78 | 家务
79 | 游玩
80 | 看书
81 | 绕弯儿
82 | 散步
83 | 菜馆
84 | 早晨
85 | 做生日
86 | 迪吧
87 | 走走
88 | 远足
89 | 洗澡
90 | 睡觉
91 | 家事
92 | 游历
93 | 起床
94 | 工作
95 | 婚宴
96 | 瑜伽
97 | 写代码
98 | 宅在家里
99 | 大巴
100 | 寿诞
101 | 超市
102 | 酒吧
103 | 读书
104 | 牵手
105 | 活动
106 | 自习
107 | 冲凉
108 | 做清洁
109 | 溜达
110 | 约会
111 | 做寿
112 | 驱车
113 | 生日
114 | 饭馆
115 | 家务事
116 | 在路上
117 | 观光
118 | 雨天
119 | 聚会
120 | 家务时间
121 | 做作业
122 | 晚餐
123 | 作事
124 | 休息
125 | 夜间
126 | 清晨
127 | 家务活
128 | 咖啡馆
129 | 咖啡厅
130 | 睡觉前
131 | 胎教
132 | 劳动
133 | 泡澡
134 | 睡眠
135 | 忙碌
136 | 遛弯儿
137 | 集会
138 | 运动
139 | 结婚
140 | 公交
141 | 班车
142 | 转悠
143 | 酒店
144 | 做卫生
145 | 过生日
146 |
--------------------------------------------------------------------------------
/entity_files/slot-dictionaries/theme.txt:
--------------------------------------------------------------------------------
1 | 格莱美奖
2 | 翻唱
3 | 最经典
4 | 成名歌曲
5 | 基督
6 | 翻唱歌曲
7 | 招牌歌曲
8 | 红歌
9 | 超强音浪
10 | 电影
11 | 沙发歌曲
12 | 天籁之战
13 | 基督教
14 | 情歌
15 | dj
16 | 广场舞
17 | 车载音乐
18 | 幼儿园
19 | 纯人声
20 | 沙发
21 | 男生歌曲
22 | 中国风
23 | 车载低音炮
24 | 歌厅必点歌
25 | 神曲
26 | 单纯
27 | 电视剧
28 | 奥斯卡奖
29 | 圣诞
30 | 铃音
31 | 清新
32 | 音乐剧
33 | 革命
34 | 网游
35 | 冷门
36 | 学校
37 | 影视原声
38 | 梦想的声音
39 | 电视
40 | 佛经
41 | 浪漫韩剧
42 | 游戏
43 | 经典情歌
44 | 成名作
45 | 情侣
46 | 春节
47 | 现场版
48 | 中国新歌声
49 | 喊麦
50 | 奥奖
51 | 动画
52 | 格莱美
53 | 世界名曲
54 | 我是歌手
55 | 男人听的歌
56 | dj碟
57 | 动漫
58 | 影视
59 | 劲歌
60 | 老情歌
61 | 个性化
62 | 格奖
63 | 沙发音乐
64 | 新歌声
65 | 网络热歌
66 | 纯音乐
67 | 网络伤感情歌
68 | 经典老歌
69 | 串烧
70 | 老一点的歌
71 | mv喊麦
72 | 漫画
73 | 男人情歌
74 | 演唱会版
75 | 成名曲
76 | 佛教
77 | 武侠
78 | dj舞曲
79 | 电视连续剧
80 | 韩剧
81 | dj热碟
82 | 蒙面唱将
83 | 二次元
84 | 平安夜
85 | live
86 | 新春
87 | 武侠风
88 | 歌手2018
89 | 年会
90 | 动画片
91 | 男生听的歌
92 | 经典现场
93 | 次元
94 | 新年
95 | 中国好歌曲
96 | ktv
97 | 中国好声音
98 | live版
99 | 一人一首招牌歌
100 | 代表歌曲
101 | 铃声
102 | 演唱会
103 | ktv必点歌
104 | 小清新
105 | 老歌
106 | 个性电台
107 | ktv歌曲
108 | 爱情
109 | 车载歌曲
110 | 洗脑
111 | 纯旋律
112 | 代表作
113 | 武侠情节
114 | 奥斯卡
115 | 大清新
116 | 圣诞节
117 | 冷僻
118 | 爱国
119 | 伤情中国风
120 | 传统
121 | 伤感的情歌
122 | 好歌曲
123 | 校园
124 | 车载
125 | mc喊麦
126 | 伴奏
127 | 个性歌曲
128 | 军旅
129 | 经典
130 | 个性
131 | 伤感情歌
132 | 没有歌词
133 | 过年
134 | 好声音
135 | 男性歌曲
136 | 元旦
137 | 流金岁月
138 | 经典翻唱
139 | 蒙面唱将猜猜猜
140 | 超赞纯人声
141 |
--------------------------------------------------------------------------------
/entity_files/slot-dictionaries/age.txt:
--------------------------------------------------------------------------------
1 | 小孩
2 | 小学
3 | 女神
4 | 2016年
5 | 1994年
6 | 1978年
7 | 儿童
8 | 1986年
9 | 女人
10 | 中老年
11 | 2004年
12 | 六十后老年人
13 | 1989年
14 | 1965年
15 | 七零年代
16 | 1977年
17 | 两千后
18 | 1949年
19 | 娃娃
20 | 2017年
21 | 1954年
22 | 1971年
23 | 1995年
24 | 1979年
25 | 1981年
26 | 美女
27 | 男人
28 | 2007年
29 | 10年代
30 | 1964年
31 | 八十年代
32 | 1951年
33 | 女孩子
34 | 2222年
35 | 一岁半
36 | 1968年
37 | 1996年
38 | 2010年
39 | 80后
40 | 九零年代
41 | 2002年
42 | 九零后
43 | 90年代
44 | 70后
45 | 1956年
46 | 老年
47 | 1967年
48 | 1984年
49 | 九十年代
50 | 2020年
51 | 80年代
52 | 八零年代
53 | 1997年
54 | 2011年
55 | 90后
56 | 1987年
57 | 2019年
58 | 2005年
59 | 1970年
60 | 帅哥
61 | 1957年
62 | 20年代
63 | 男孩子
64 | 1966年
65 | 2012年
66 | 父母辈
67 | 1953年
68 | 2008年
69 | 1962年
70 | 1982年
71 | 2000年
72 | 一岁
73 | 七十后
74 | 女生
75 | 1990年
76 | 一零年代
77 | 八零后
78 | 1999年
79 | 1960年
80 | 1958年
81 | 1975年
82 | 1950年
83 | 1985年
84 | 2010年之前
85 | 幼儿
86 | 1961年
87 | 八十后
88 | 童谣
89 | 亲子儿歌
90 | 1969年
91 | 00年代
92 | 1991年
93 | 2013年
94 | 91世纪
95 | 2003年
96 | 六零后
97 | 1976年
98 | 暖男
99 | 1959年
100 | 小朋友
101 | 22世纪
102 | 1992年
103 | 2014年
104 | 21世纪
105 | 1972年
106 | 1955年
107 | 男孩
108 | 文艺青年
109 | 1980年
110 | 2006年
111 | 七十年代
112 | 2010之前
113 | 1963年
114 | 零零后
115 | 儿歌
116 | 男神
117 | 60后
118 | 零零年代
119 | 1993年
120 | 2015年
121 | 70年代
122 | 1952年
123 | 1973年
124 | 2009年
125 | 2018年
126 | 少儿
127 | 1983年
128 | 2001年
129 | 上班族
130 | 九十后
131 | 二零一一
132 | 00后
133 | 两岁
134 | 七零后
135 | 女孩
136 | 1988年
137 | 1998年
138 | 男生
139 | 1974年
140 |
--------------------------------------------------------------------------------
/entity_files/frequentSinger.json:
--------------------------------------------------------------------------------
1 | {
2 | "1": 0.1622604899345137,
3 | "0": 0.06766917293233082,
4 | "10": 0.04640633842671194,
5 | "9": 0.041959738054814455,
6 | "N": 0.03128789716226049,
7 | "M": 0.025305198480071147,
8 | "未来": 0.024173336567224514,
9 | "P": 0.023930794728757376,
10 | "左右": 0.01447166302853909,
11 | "安": 0.01228878648233487,
12 | "M2": 0.011399466407955373,
13 | "风": 0.009782520818174469,
14 | "15": 0.008650658905327836,
15 | "19": 0.007437949712992158,
16 | "21": 0.007114560595035977,
17 | "张航": 0.00606354596167839,
18 | "凌晨": 0.004608294930875576,
19 | "31": 0.004284905812919395,
20 | "信": 0.003718974856496079,
21 | "气象预报": 0.0031530439000727628,
22 | "美": 0.0031530439000727628,
23 | "城市": 0.002991349341094672,
24 | "C": 0.002991349341094672,
25 | "-": 0.0028296547821165816,
26 | "吉": 0.0028296547821165816,
27 | "泉": 0.0026679602231384912,
28 | "流行歌曲": 0.0025871129436494463,
29 | "in": 0.0025871129436494463,
30 | "夏": 0.002425418384671356,
31 | "S": 0.0023445711051823105,
32 | "B": 0.00218287654620422,
33 | "A": 0.00218287654620422,
34 | "乌兰": 0.0020211819872261298,
35 | "R": 0.0019403347077370846,
36 | "D": 0.0017786401487589942,
37 | "H": 0.0017786401487589942,
38 | "W": 0.001697792869269949,
39 | "红": 0.001697792869269949,
40 | "晴天": 0.0015360983102918587,
41 | "T": 0.0015360983102918587,
42 | "星": 0.0015360983102918587,
43 | "十一": 0.0015360983102918587,
44 | "ch": 0.0012935564718247231,
45 | "李白": 0.001212709192335678,
46 | "G": 0.001212709192335678,
47 | "杨": 0.001212709192335678,
48 | "零": 0.0011318619128466328,
49 | "I": 0.0011318619128466328,
50 | "J": 0.0011318619128466328,
51 | "小雨": 0.0010510146333575876,
52 | "L": 0.0010510146333575876,
53 | "O": 0.0009701673538685423,
54 |
55 | "m2":0.8,
56 | "什么": 0.8,
57 | "倩女幽魂": 0.8,
58 | "两只老虎": 0.8,
59 | "中华": 0.8,
60 | "文歌": 0.8,
61 | "岳云鹏": 0.8,
62 | "郭德纲": 0.8,
63 | "大美": 0.8,
64 | "大王叫我来巡山": 0.8,
65 | "大海": 0.8,
66 | "大小姐": 0.8,
67 | "白天": 0.8,
68 | "江南": 0.8,
69 | "凌晨": 0.8,
70 | "金刚": 0.8,
71 | "苍蝇": 0.8,
72 | "海洋": 0.8,
73 | "相思": 0.8,
74 | "李白": 0.8,
75 | "dj": 0.8,
76 | "you": 0.8,
77 | "曹操": 0.8,
78 | "曲歌": 0.5,
79 | "云朵": 0.5,
80 | "小女子": 0.8
81 | }
82 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | *-notice.txt
3 | *log.txt
4 | *score.txt
5 | *.jsonl
6 | *.pickle
7 | *results.txt
8 | hunxiao_words.txt
9 | *.csv
10 |
11 | # Distribution / packaging
12 | .Python
13 | models/
14 | .idea/
15 | __pycache__/
16 | *.pyc
17 | *.py[cod]
18 | *$py.class
19 | *.so
20 |
21 |
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | pip-wheel-metadata/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
--------------------------------------------------------------------------------
/score_main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Calculates evaluation scores for a prediction TSV file.
18 |
19 | The prediction file is produced by predict_main.py and should contain 3 or more
20 | columns:
21 | 1: sources (concatenated)
22 | 2: prediction
23 | 3-n: targets (1 or more)
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 |
29 | from __future__ import print_function
30 |
31 | from absl import app, flags, logging
32 | import score_lib
33 | from bert import tokenization
34 | from curLine_file import curLine
35 |
36 | FLAGS = flags.FLAGS
37 |
38 | flags.DEFINE_string(
39 | 'prediction_file', None,
40 | 'TSV file containing source, prediction, and target columns.')
41 | flags.DEFINE_bool(
42 | 'do_lower_case', True,
43 | 'Whether score computation should be case insensitive (in the LaserTagger '
44 | 'paper this was set to True).')
45 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
46 | flags.DEFINE_string(
47 | 'domain_name', None,
48 |     'If set, only examples whose domain equals this name are scored; '
49 |     'otherwise examples from all domains are included.')
50 | def main(argv):
51 | if len(argv) > 1:
52 | raise app.UsageError('Too many command-line arguments.')
53 | flags.mark_flag_as_required('prediction_file')
54 | target_domain_name = FLAGS.domain_name
55 | print(curLine(), "target_domain_name:", target_domain_name)
56 |
57 | predDomain_list, predIntent_list, domain_list, right_intent_num, right_slot_num, exact_num = score_lib.read_data(
58 | FLAGS.prediction_file, FLAGS.do_lower_case, target_domain_name=target_domain_name)
59 | logging.info(f'Read file: {FLAGS.prediction_file}')
60 | all_num = len(domain_list)
61 | domain_acc = score_lib.compute_exact_score(predDomain_list, domain_list)
62 | intent_acc = float(right_intent_num) / all_num
63 | slot_acc = float(right_slot_num) / all_num
64 | exact_score = float(exact_num) / all_num
65 | print('Num=%d, domain_acc=%.4f, intent_acc=%.4f, slot_acc=%.5f, exact_score=%.4f'
66 | % (all_num, domain_acc, intent_acc, slot_acc, exact_score))
67 |
68 |
69 | if __name__ == '__main__':
70 | app.run(main)
71 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Utility functions for LaserTagger."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | from __future__ import print_function
23 |
24 | import json
25 | import csv
26 |
27 | import tensorflow as tf
28 | from curLine_file import curLine, normal_transformer, other_tag
29 |
30 |
31 |
32 | def yield_sources_and_targets(input_file, input_format, domain_name):
33 | """Reads and yields source lists and targets from the input file.
34 |
35 | Args:
36 | input_file: Path to the input file.
37 | input_format: Format of the input file.
38 |
39 | Yields:
40 | Tuple with (list of source texts, target text).
41 | """
42 | if input_format == 'nlu':
43 | yield_example_fn = _nlu_examples
44 | else:
45 | raise ValueError('Unsupported input_format: {}'.format(input_format))
46 | for sources, target in yield_example_fn(input_file, domain_name):
47 | yield sources, target
48 |
49 | def _nlu_examples(input_file, target_domain_name=None):
50 | with tf.gfile.GFile(input_file) as f:
51 | reader = csv.reader(f)
52 | session_list = []
53 | for row_id, (sessionId, raw_query, domain_intent, param) in enumerate(reader):
54 | query = normal_transformer(raw_query)
55 | param = normal_transformer(param)
56 | sources = []
57 | if row_id > 0 and sessionId == session_list[row_id - 1][0]:
58 | sources.append(session_list[row_id-1][1]) # last query
59 | sources.append(query)
60 | if domain_intent == other_tag:
61 |         domain = intent = other_tag  # OTHERS examples carry no separate intent label
62 | else:
63 | domain, intent = domain_intent.split(".")
64 | session_list.append((sessionId, query))
65 | if target_domain_name is not None and target_domain_name != domain:
66 | continue
67 | yield sources, (intent,param)
68 |
69 |
70 | def read_label_map(path):
71 | """Returns label map read from the given path."""
72 | with tf.gfile.GFile(path) as f:
73 | if path.endswith('.json'):
74 | return json.load(f)
75 | else:
76 | label_map = {}
77 | empty_line_encountered = False
78 | for tag in f:
79 | tag = tag.strip()
80 | if tag:
81 | label_map[tag] = len(label_map)
82 | else:
83 | if empty_line_encountered:
84 | raise ValueError(
85 | 'There should be no empty lines in the middle of the label map '
86 | 'file.'
87 | )
88 | empty_line_encountered = True
89 | return label_map
90 |
--------------------------------------------------------------------------------
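
As a small sketch of how read_label_map behaves on a plain-text label file: it needs the TensorFlow 1.15 environment from requirements.txt (because of tf.gfile), and the temporary path and tag names below are made up.

```python
# Hypothetical example for read_label_map(); the path and tags are illustrative only.
import utils

with open("/tmp/label_map.txt", "w") as f:
    f.write("KEEP\nDELETE\nSWAP\n")             # one tag per line, no blank lines in the middle

label_map = utils.read_label_map("/tmp/label_map.txt")
print(label_map)                                 # {'KEEP': 0, 'DELETE': 1, 'SWAP': 2}
```
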
/official_transformer/model_params.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Defines Transformer model parameters."""
17 |
18 | from collections import defaultdict
19 |
20 |
21 | BASE_PARAMS = defaultdict(
22 | lambda: None, # Set default value to None.
23 |
24 | # Input params
25 | default_batch_size=2048, # Maximum number of tokens per batch of examples.
26 | default_batch_size_tpu=32768,
27 | max_length=256, # Maximum number of tokens per example.
28 |
29 | # Model params
30 | initializer_gain=1.0, # Used in trainable variable initialization.
31 | vocab_size=33708, # Number of tokens defined in the vocabulary file.
32 | hidden_size=512, # Model dimension in the hidden layers.
33 | num_hidden_layers=6, # Number of layers in the encoder and decoder stacks.
34 | num_heads=8, # Number of heads to use in multi-headed attention.
35 | filter_size=2048, # Inner layer dimension in the feedforward network.
36 |
37 | # Dropout values (only used when training)
38 | layer_postprocess_dropout=0.1,
39 | attention_dropout=0.1,
40 | relu_dropout=0.1,
41 |
42 | # Training params
43 | label_smoothing=0.1,
44 | learning_rate=2.0,
45 | learning_rate_decay_rate=1.0,
46 | learning_rate_warmup_steps=16000,
47 |
48 | # Optimizer params
49 | optimizer_adam_beta1=0.9,
50 | optimizer_adam_beta2=0.997,
51 | optimizer_adam_epsilon=1e-09,
52 |
53 | # Default prediction params
54 | extra_decode_length=50,
55 | beam_size=4,
56 | alpha=0.6, # used to calculate length normalization in beam search
57 |
58 | # TPU specific parameters
59 | use_tpu=False,
60 | static_batch=False,
61 | allow_ffn_pad=True,
62 | )
63 |
64 | BIG_PARAMS = BASE_PARAMS.copy()
65 | BIG_PARAMS.update(
66 | default_batch_size=4096,
67 |
68 | # default batch size is smaller than for BASE_PARAMS due to memory limits.
69 | default_batch_size_tpu=16384,
70 |
71 | hidden_size=1024,
72 | filter_size=4096,
73 | num_heads=16,
74 | )
75 |
76 | # Parameters for running the model in multi gpu. These should not change the
77 | # params that modify the model shape (such as the hidden_size or num_heads).
78 | BASE_MULTI_GPU_PARAMS = BASE_PARAMS.copy()
79 | BASE_MULTI_GPU_PARAMS.update(
80 | learning_rate_warmup_steps=8000
81 | )
82 |
83 | BIG_MULTI_GPU_PARAMS = BIG_PARAMS.copy()
84 | BIG_MULTI_GPU_PARAMS.update(
85 | layer_postprocess_dropout=0.3,
86 | learning_rate_warmup_steps=8000
87 | )
88 |
89 | # Parameters for testing the model
90 | TINY_PARAMS = BASE_PARAMS.copy()
91 | TINY_PARAMS.update(
92 | default_batch_size=1024,
93 | default_batch_size_tpu=1024,
94 | hidden_size=32,
95 | num_heads=4,
96 | filter_size=256,
97 | )
98 |
--------------------------------------------------------------------------------
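
Because the parameter sets are built on collections.defaultdict with a None factory, looking up a key that was never defined returns None instead of raising KeyError. A quick illustrative check, assuming the repository root is on PYTHONPATH:

```python
# The defaultdict(lambda: None, ...) construction means unknown keys resolve to None.
from official_transformer.model_params import BASE_PARAMS, TINY_PARAMS

print(BASE_PARAMS["hidden_size"])      # 512
print(TINY_PARAMS["hidden_size"])      # 32, overridden for testing
print(BASE_PARAMS["no_such_param"])    # None rather than a KeyError
```
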
/official_transformer/ffn_layer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implementation of fully connected network."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 |
25 | class FeedFowardNetwork(tf.layers.Layer):
26 | """Fully connected feedforward network."""
27 |
28 | def __init__(self, hidden_size, filter_size, relu_dropout, train, allow_pad):
29 | super(FeedFowardNetwork, self).__init__()
30 | self.hidden_size = hidden_size
31 | self.filter_size = filter_size
32 | self.relu_dropout = relu_dropout
33 | self.train = train
34 | self.allow_pad = allow_pad
35 |
36 | self.filter_dense_layer = tf.layers.Dense(
37 | filter_size, use_bias=True, activation=tf.nn.relu, name="filter_layer")
38 | self.output_dense_layer = tf.layers.Dense(
39 | hidden_size, use_bias=True, name="output_layer")
40 |
41 | def call(self, x, padding=None):
42 | """Return outputs of the feedforward network.
43 |
44 | Args:
45 | x: tensor with shape [batch_size, length, hidden_size]
46 | padding: (optional) If set, the padding values are temporarily removed
47 | from x (provided self.allow_pad is set). The padding values are placed
48 | back in the output tensor in the same locations.
49 | shape [batch_size, length]
50 |
51 | Returns:
52 | Output of the feedforward network.
53 | tensor with shape [batch_size, length, hidden_size]
54 | """
55 | padding = None if not self.allow_pad else padding
56 |
57 | # Retrieve dynamically known shapes
58 | batch_size = tf.shape(x)[0]
59 | length = tf.shape(x)[1]
60 |
61 | if padding is not None:
62 | with tf.name_scope("remove_padding"):
63 | # Flatten padding to [batch_size*length]
64 | pad_mask = tf.reshape(padding, [-1])
65 |
66 | nonpad_ids = tf.to_int32(tf.where(pad_mask < 1e-9))
67 |
68 | # Reshape x to [batch_size*length, hidden_size] to remove padding
69 | x = tf.reshape(x, [-1, self.hidden_size])
70 | x = tf.gather_nd(x, indices=nonpad_ids)
71 |
72 | # Reshape x from 2 dimensions to 3 dimensions.
73 | x.set_shape([None, self.hidden_size])
74 | x = tf.expand_dims(x, axis=0)
75 |
76 | output = self.filter_dense_layer(x)
77 | if self.train:
78 | output = tf.nn.dropout(output, 1.0 - self.relu_dropout)
79 | output = self.output_dense_layer(output)
80 |
81 | if padding is not None:
82 | with tf.name_scope("re_add_padding"):
83 | output = tf.squeeze(output, axis=0)
84 | output = tf.scatter_nd(
85 | indices=nonpad_ids,
86 | updates=output,
87 | shape=[batch_size * length, self.hidden_size]
88 | )
89 | output = tf.reshape(output, [batch_size, length, self.hidden_size])
90 | return output
91 |
--------------------------------------------------------------------------------
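
A minimal graph-mode sketch of instantiating the layer above, assuming the TensorFlow 1.15 environment from requirements.txt; the shapes and hyperparameters are illustrative (they follow the BASE_PARAMS defaults) and are not taken from the original file.

```python
# Graph-mode sketch (TF 1.x); shapes and hyperparameters are illustrative.
import tensorflow as tf
from official_transformer.ffn_layer import FeedFowardNetwork

x = tf.placeholder(tf.float32, shape=[None, 10, 512])        # [batch_size, length, hidden_size]
ffn = FeedFowardNetwork(hidden_size=512, filter_size=2048,
                        relu_dropout=0.1, train=False, allow_pad=False)
y = ffn(x)                                                    # same shape as x: [batch_size, length, 512]
print(y)
```
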
/find_entity/probable_acmation.py:
--------------------------------------------------------------------------------
1 | # Find candidate ("probable") entities to assist training
2 | # Build the candidate-entity finder with an Aho-Corasick (AC) automaton
3 | import os
4 | from collections import defaultdict
5 | import json
6 | import re
7 | from .acmation import KeywordTree, add_to_ac, entity_files_folder, entity_folder
8 | from curLine_file import curLine, normal_transformer
9 |
10 | domain2entity_map = {}
11 | domain2entity_map["music"] = ["age", "singer", "song", "toplist", "theme", "style", "scene", "language", "emotion", "instrument"]
12 | domain2entity_map["navigation"] = ["custom_destination", "city"] # place city
13 | domain2entity_map["phone_call"] = ["phone_num", "contact_name"]
14 |
15 | re_phoneNum = re.compile("[0-9一二三四五六七八九十拾]+") # compiled pattern for phone numbers (digits and Chinese numerals)
16 |
17 |
18 | # Reading the downloaded xls file directly might be more convenient, but that would require the xlrd module
19 | self_entity_trie_tree = {} # overall entity dictionary: tries we build ourselves for certain entity types
20 | for domain, entity_type_list in domain2entity_map.items():
21 | print(curLine(), domain, entity_type_list)
22 | for entity_type in entity_type_list:
23 | if entity_type not in self_entity_trie_tree:
24 | ac = KeywordTree(case_insensitive=True)
25 | else:
26 | ac = self_entity_trie_tree[entity_type]
27 |
28 | # TODO
29 | if entity_type == "city":
30 | # for current_entity_type in ["city", "province"]:
31 | # entity_file = waibu_folder + "%s.json" % current_entity_type
32 | # with open(entity_file, "r") as f:
33 | # current_entity_dict = json.load(f)
34 | # print(curLine(), "get %d %s from %s" %
35 | # (len(current_entity_dict), current_entity_type, entity_file))
36 | # for entity_before, entity_times in current_entity_dict.items():
37 | # entity_after = entity_before
38 | # add_to_ac(ac, entity_type, entity_before, entity_after, pri=1)
39 |
40 |             ## place names mined from the annotated corpus
41 | for current_entity_type in ["destination", "origin"]:
42 | entity_file = os.path.join(entity_files_folder, "%s.json" % current_entity_type)
43 | with open(entity_file, "r") as f:
44 | current_entity_dict = json.load(f)
45 | print(curLine(), "get %d %s from %s" %
46 | (len(current_entity_dict), current_entity_type, entity_file))
47 | for entity_before, entity_after_times in current_entity_dict.items():
48 | entity_after = entity_after_times[0]
49 | add_to_ac(ac, entity_type, entity_before, entity_after, pri=2)
50 | input(curLine())
51 |
52 |         # the provided entity dictionaries: highest priority
53 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type)
54 | if os.path.exists(entity_file):
55 | with open(entity_file, "r") as fr:
56 | lines = fr.readlines()
57 | print(curLine(), "get %d %s from %s" % (len(lines), entity_type, entity_file))
58 | for line in lines:
59 | entity_after = line.strip()
60 | entity_before = entity_after # TODO
61 | pri = 3
62 | if entity_type in ["song"]:
63 | pri -= 0.5
64 | add_to_ac(ac, entity_type, entity_before, entity_after, pri=pri)
65 | ac.finalize()
66 | self_entity_trie_tree[entity_type] = ac
67 |
68 |
69 | def get_all_entity(corpus, useEntityTypeList):
70 | self_entityTypeMap = defaultdict(list)
71 | for entity_type in useEntityTypeList:
72 | result = self_entity_trie_tree[entity_type].search(corpus)
73 | for res in result:
74 | after, priority = res.meta_data
75 | self_entityTypeMap[entity_type].append({'before': res.keyword, 'after': after, "priority":priority})
76 | if "phone_num" in useEntityTypeList:
77 | token_numbers = re_phoneNum.findall(corpus)
78 | for number in token_numbers:
79 | self_entityTypeMap["phone_num"].append({'before':number, 'after':number, 'priority': 2})
80 | return self_entityTypeMap
--------------------------------------------------------------------------------
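
A possible way to call get_all_entity is sketched below. This is only an illustration: the module builds its AC tries at import time, so the entity files it expects must be present, and the loading loop above pauses at an input() call while reading the city dictionaries; the query string is just an example.

```python
# Illustrative call only; importing the module loads the entity dictionaries it expects
# and pauses at the input() call in the city-loading branch above.
from find_entity.probable_acmation import get_all_entity, domain2entity_map

query = "来一首周杰伦的七里香"
matches = get_all_entity(query, useEntityTypeList=domain2entity_map["music"])
for entity_type, candidates in matches.items():
    # each candidate is {'before': matched text, 'after': normalized entity, 'priority': score}
    print(entity_type, candidates)
```
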
/confusion_words_duoyin.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Confusable-word extraction based on pinyin
3 | from pypinyin import lazy_pinyin, pinyin, Style
4 | # pinyin() keeps tones by default, while lazy_pinyin() drops tones but may return uncommon heteronym readings
5 | # pinyin(c, heteronym=True, style=0) lists a character's heteronym readings without considering tones
6 |
7 | from Levenshtein import distance # python-Levenshtein
8 | # edit distance (the Levenshtein distance algorithm)
9 | import os
10 | import copy
11 | from curLine_file import curLine
12 | from find_entity.acmation import entity_folder
13 |
14 | number_map = {"0":"零", "1":"一", "2":"二", "3":"三", "4":"四", "5":"五", "6":"六", "7":"七", "8":"八", "9":"九"}
15 | def get_pinyin_combination(entity):
16 | all_combination = [""]
17 | for zi in entity:
18 | if zi in number_map:
19 | zi = number_map[zi]
20 | zi_pinyin_list = pinyin(zi, heteronym=True, style=0)[0]
21 | pre_combination = copy.deepcopy(all_combination)
22 | all_combination = []
23 |         for index, zi_pinyin in enumerate(zi_pinyin_list): # append every possible pronunciation in turn
24 |             for com in pre_combination: # com is one existing combination
25 |                 # com.append(zi_pinyin)
26 |                 # new = com + [zi_pinyin]
27 |                 # new = "".join(new)
28 |                 all_combination.append("%s%s" % (com, zi_pinyin)) # TODO: decide whether to insert a separator
29 | return all_combination
30 |
31 | def get_entityType_pinyin(entity_type):
32 | entity_info_dict = {}
33 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type)
34 | with open(entity_file, "r") as fr:
35 | lines = fr.readlines()
36 | pri = 3
37 | if entity_type in ["song"]:
38 | pri -= 0.5
39 | print(curLine(), "get %d %s from %s, pri=%f" % (len(lines), entity_type, entity_file, pri))
40 | for line in lines:
41 | entity_after = line.strip()
42 | all_combination = get_pinyin_combination(entity=entity_after)
43 | # for combination in all_combination:
44 |         if entity_after not in entity_info_dict: # a new pronunciation
45 |             entity_info_dict[entity_after] = (all_combination, pri)
46 |         else: # duplicates are rare, at most one repetition
47 | _, old_pri = entity_info_dict[entity_after]
48 | if pri > old_pri:
49 | entity_info_dict[entity_after] = (all_combination, pri)
50 | return entity_info_dict
51 |
52 | # Measure the similarity between pinyin strings with edit distance
53 | def pinyin_similar_word(entity_info_dict, word):
54 |     similar_word = None
55 |     if word in entity_info_dict: # the entity already exists, no correction needed
56 |         return 0, word
57 |     all_combination = get_pinyin_combination(entity=word)
58 |     top_similar_score = 0
59 |     for current_combination in all_combination: # every possible pronunciation of the word
60 | current_distance = 10000
61 |
62 | for entity_after,(com_list, pri) in entity_info_dict.items():
63 | for com in com_list:
64 | d = distance(com, current_combination)
65 | if d < current_distance:
66 | current_distance = d
67 | similar_word = entity_after
68 | current_similar_score = 1.0 - float(current_distance) / len(current_combination)
69 | print(curLine(), "current_combination:%s, %f" % (current_combination, current_similar_score), similar_word, current_distance)
70 |
71 |
72 |
73 | singer_pinyin = get_entityType_pinyin(entity_type="singer")
74 | print(curLine(), len(singer_pinyin), "singer_pinyin")
75 |
76 | song_pinyin = get_entityType_pinyin(entity_type="song")
77 | print(curLine(), len(song_pinyin), "song_pinyin")
78 |
79 |
80 |
81 | if __name__ == "__main__":
82 | s = "你为什123c单身"
83 | for c in s:
84 | c_pinyin = lazy_pinyin(s, errors="ignore")
85 | print(curLine(), c,c_pinyin)
86 | print(curLine(), pinyin(s, heteronym=True, style=0))
87 | # # print(pinyin(c, heteronym=True, style=0))
88 |
89 | # all_combination = get_pinyin_combination(entity=s)
90 | #
91 | # for index, combination in enumerate(all_combination):
92 | # print(curLine(), index, combination)
93 | # for s in ["abc", "abs1", "fabc"]:
94 | # ed = distance("abs", s)
95 | # print(s, ed)
96 |
97 | pinyin_similar_word(singer_pinyin, "周杰")
98 | pinyin_similar_word(singer_pinyin, "前任")
99 |
--------------------------------------------------------------------------------
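
To make the heteronym expansion concrete, here is a standalone illustration of the same idea as get_pinyin_combination. Importing the module itself also tries to load external singer/song dictionaries, so this sketch redoes only the combination step with pypinyin; the example word and expected output are indicative, since the exact reading list depends on pypinyin's data.

```python
# Standalone illustration of the heteronym expansion performed by get_pinyin_combination.
# Only pypinyin is needed; the example word and expected output are indicative.
from pypinyin import pinyin

def pinyin_combinations(word):
    combinations = [""]
    for zi in word:
        readings = pinyin(zi, heteronym=True, style=0)[0]   # all toneless readings of this character
        combinations = [prefix + reading for prefix in combinations for reading in readings]
    return combinations

print(pinyin_combinations("音乐"))   # includes combinations such as 'yinyue' and 'yinle'
```
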
/split_corpus.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Split the corpus into training and dev sets
3 | import os
4 | import csv, json
5 | from collections import defaultdict
6 | import re
7 | from curLine_file import curLine, other_tag
8 |
9 | all_entity_dict = defaultdict(dict)
10 | before2after = {"父亲":"父亲", "赵磊":"赵磊", "甜蜜":"甜蜜", "大王叫我来":"大王叫我来巡山"}
11 | def get_slot(param):
12 | slot = []
13 | if "<" not in param:
14 | return slot
15 | if ">" not in param:
16 | print(curLine(), "param:", param)
17 | return slot
18 | if "" not in param:
19 | return slot
20 | start_segment = re.findall("<[\w_]*>", param)
21 | end_segment = re.findall("[\w_]*>", param)
22 | if len(start_segment) != len(end_segment):
23 | print(curLine(), "start_segment:", start_segment)
24 | print(curLine(), "end_segment:", end_segment)
25 | search_location = 0
26 | for s,e in zip(start_segment, end_segment):
27 | entityType = s[1:-1]
28 | assert "%s>" % entityType == e
29 | start_index = param[search_location:].index(s) + len(s)
30 | end_index = param[search_location:].index(e)
31 | entity_info = param[search_location:][start_index:end_index]
32 | search_location += end_index + len(e)
33 | before,after = entity_info, entity_info
34 | if "||" in entity_info:
35 | before, after = entity_info.split("||")
36 | if before in before2after:
37 | after = before2after[before]
38 | if before not in all_entity_dict[entityType]:
39 | all_entity_dict[entityType][before] = [after, 1]
40 | else:
41 | if after != all_entity_dict[entityType][before][0]:
42 | print(curLine(), entityType, before, after, all_entity_dict[entityType][before])
43 | assert after == all_entity_dict[entityType][before][0]
44 | all_entity_dict[entityType][before][1] += 1
45 | if before != after:
46 | before = after
47 | if before not in all_entity_dict[entityType]:
48 | all_entity_dict[entityType][before] = [after, 1]
49 | else:
50 | assert after == all_entity_dict[entityType][before][0]
51 | all_entity_dict[entityType][before][1] += 1
52 |
53 |
54 | # Preprocess the corpus and split it
55 | def process(source_file, train_file, dev_file):
56 | dev_lines = []
57 | train_num = 0
58 | intent_distribution = defaultdict(dict)
59 | with open(source_file, "r") as f, open(train_file, "w") as f_train:
60 | reader = csv.reader(f)
61 | train_write = csv.writer(f_train, dialect='excel')
62 | for row_id, line in enumerate(reader):
63 | if row_id==0:
64 | continue
65 | (sessionId, raw_query, domain_intent, param) = line
66 | get_slot(param)
67 |
68 | if domain_intent == other_tag:
69 | domain = other_tag
70 | intent = other_tag
71 | else:
72 | domain, intent = domain_intent.split(".")
73 | if intent in intent_distribution[domain]:
74 | intent_distribution[domain][intent] += 1
75 | else:
76 | intent_distribution[domain][intent] = 0
77 | if row_id == 0:
78 | continue
79 | sessionId = int(sessionId)
80 | if sessionId % 10>0:
81 | train_write.writerow(line)
82 | train_num += 1
83 | else:
84 | dev_lines.append(line)
85 | with open(dev_file, "w") as f_dev:
86 | write = csv.writer(f_dev, dialect='excel')
87 | for line in dev_lines:
88 | write.writerow(line)
89 | print(curLine(), "dev=%d, train=%d" % (len(dev_lines), train_num))
90 | for domain, intent_num in intent_distribution.items():
91 | print(curLine(), domain, intent_num)
92 |
93 |
94 | if __name__ == "__main__":
95 | corpus_folder = "/home/wzk/Mywork/corpus/compe/69"
96 | source_file = os.path.join(corpus_folder, "train.csv")
97 | train_file = os.path.join(corpus_folder, "train.txt")
98 | dev_file = os.path.join(corpus_folder, "dev.txt")
99 | process(source_file, train_file, dev_file)
100 |
101 | for entityType, entityDict in all_entity_dict.items():
102 | json_file = os.path.join(corpus_folder, "%s.json" % entityType)
103 | with open(json_file, "w") as f:
104 | json.dump(entityDict, f, ensure_ascii=False, indent=4)
105 | print(curLine(), "save %d %s to %s" % (len(entityDict), entityType, json_file))
106 |
107 |
--------------------------------------------------------------------------------
/official_transformer/embedding_layer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implementation of embedding layer with shared weights."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf # pylint: disable=g-bad-import-order
23 |
24 | from official_transformer import tpu as tpu_utils
25 |
26 |
27 | class EmbeddingSharedWeights(tf.layers.Layer):
28 | """Calculates input embeddings and pre-softmax linear with shared weights."""
29 |
30 | def __init__(self, vocab_size, hidden_size, method="gather"):
31 | """Specify characteristic parameters of embedding layer.
32 |
33 | Args:
34 | vocab_size: Number of tokens in the embedding. (Typically ~32,000)
35 | hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
36 | method: Strategy for performing embedding lookup. "gather" uses tf.gather
37 | which performs well on CPUs and GPUs, but very poorly on TPUs. "matmul"
38 |       one-hot encodes the indices and formulates the embedding as a sparse
39 | matrix multiplication. The matmul formulation is wasteful as it does
40 | extra work, however matrix multiplication is very fast on TPUs which
41 | makes "matmul" considerably faster than "gather" on TPUs.
42 | """
43 | super(EmbeddingSharedWeights, self).__init__()
44 | self.vocab_size = vocab_size
45 | self.hidden_size = hidden_size
46 | if method not in ("gather", "matmul"):
47 | raise ValueError("method {} must be 'gather' or 'matmul'".format(method))
48 | self.method = method
49 |
50 | def build(self, _):
51 | with tf.variable_scope("embedding_and_softmax", reuse=tf.AUTO_REUSE):
52 | # Create and initialize weights. The random normal initializer was chosen
53 | # randomly, and works well.
54 | self.shared_weights = tf.get_variable(
55 | "weights", [self.vocab_size, self.hidden_size],
56 | initializer=tf.random_normal_initializer(
57 | 0., self.hidden_size ** -0.5))
58 |
59 | self.built = True
60 |
61 | def call(self, x):
62 | """Get token embeddings of x.
63 |
64 | Args:
65 | x: An int64 tensor with shape [batch_size, length]
66 | Returns:
67 | embeddings: float32 tensor with shape [batch_size, length, embedding_size]
68 | padding: float32 tensor with shape [batch_size, length] indicating the
69 | locations of the padding tokens in x.
70 | """
71 | with tf.name_scope("embedding"):
72 | # Create binary mask of size [batch_size, length]
73 | mask = tf.to_float(tf.not_equal(x, 0))
74 |
75 | if self.method == "gather":
76 | embeddings = tf.gather(self.shared_weights, x)
77 | embeddings *= tf.expand_dims(mask, -1)
78 | else: # matmul
79 | embeddings = tpu_utils.embedding_matmul(
80 | embedding_table=self.shared_weights,
81 | values=tf.cast(x, dtype=tf.int32),
82 | mask=mask
83 | )
84 | # embedding_matmul already zeros out masked positions, so
85 | # `embeddings *= tf.expand_dims(mask, -1)` is unnecessary.
86 |
87 |
88 | # Scale embedding by the sqrt of the hidden size
89 | embeddings *= self.hidden_size ** 0.5
90 |
91 | return embeddings
92 |
93 |
94 | def linear(self, x):
95 | """Computes logits by running x through a linear layer.
96 |
97 | Args:
98 | x: A float32 tensor with shape [batch_size, length, hidden_size]
99 | Returns:
100 | float32 tensor with shape [batch_size, length, vocab_size].
101 | """
102 | with tf.name_scope("presoftmax_linear"):
103 | batch_size = tf.shape(x)[0]
104 | length = tf.shape(x)[1]
105 |
106 | x = tf.reshape(x, [-1, self.hidden_size])
107 | logits = tf.matmul(x, self.shared_weights, transpose_b=True)
108 |
109 | return tf.reshape(logits, [batch_size, length, self.vocab_size])
110 |
--------------------------------------------------------------------------------
/official_transformer/model_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Transformer model helper methods."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import math
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 |
27 | # Very low numbers to represent -infinity. We do not actually use -Inf, since we
28 | # want to be able to multiply these values by zero to get zero. (-Inf * 0 = NaN)
29 | _NEG_INF_FP32 = -1e9
30 | _NEG_INF_FP16 = np.finfo(np.float16).min
31 |
32 |
33 | def get_position_encoding(
34 | length, hidden_size, min_timescale=1.0, max_timescale=1.0e2): # TODO previous max_timescale=1.0e4
35 | """Return positional encoding.
36 |
37 | Calculates the position encoding as a mix of sine and cosine functions with
38 | geometrically increasing wavelengths.
39 | Defined and formulated in "Attention Is All You Need", section 3.5.
40 |
41 | Args:
42 | length: Sequence length.
43 | hidden_size: Size of the hidden dimension (the encoding size).
44 | min_timescale: Minimum scale that will be applied at each position
45 | max_timescale: Maximum scale that will be applied at each position
46 |
47 | Returns:
48 | Tensor with shape [length, hidden_size]
49 | """
50 | # We compute the positional encoding in float32 even if the model uses
51 | # float16, as many of the ops used, like log and exp, are numerically unstable
52 | # in float16.
53 | position = tf.cast(tf.range(length), tf.float32)
54 | num_timescales = hidden_size // 2
55 | log_timescale_increment = (
56 | math.log(float(max_timescale) / float(min_timescale)) /
57 | (tf.cast(num_timescales, tf.float32) - 1))
58 | inv_timescales = min_timescale * tf.exp(
59 | tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
60 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
61 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
62 | return signal
63 |
64 |
65 | def get_decoder_self_attention_bias(length, dtype=tf.float32):
66 | """Calculate bias for decoder that maintains model's autoregressive property.
67 |
68 | Creates a tensor that masks out locations that correspond to illegal
69 | connections, so prediction at position i cannot draw information from future
70 | positions.
71 |
72 | Args:
73 | length: int length of sequences in batch.
74 | dtype: The dtype of the return value.
75 |
76 | Returns:
77 | float tensor of shape [1, 1, length, length]
78 | """
79 | neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32
80 | with tf.name_scope("decoder_self_attention_bias"):
81 | valid_locs = tf.linalg.band_part(input=tf.ones([length, length], dtype=dtype), # the two lengths are the max input and output lengths
82 | num_lower=-1, num_upper=0) # lower triangular part is 1, everything else is 0
83 | valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
84 | decoder_bias = neg_inf * (1.0 - valid_locs)
85 | return decoder_bias
86 |
87 |
88 | def get_padding(x, padding_value=0, dtype=tf.float32):
89 | """Return float tensor representing the padding values in x.
90 |
91 | Args:
92 | x: int tensor with any shape
93 | padding_value: int value that denotes padding in x.
94 | dtype: The dtype of the return value.
95 |
96 | Returns:
97 | float tensor with same shape as x containing values 0 or 1.
98 | 0 -> non-padding, 1 -> padding
99 | """
100 | with tf.name_scope("padding"):
101 | return tf.cast(tf.equal(x, padding_value), dtype)
102 |
103 |
104 | def get_padding_bias(x):
105 | """Calculate bias tensor from padding values in tensor.
106 |
107 | Bias tensor that is added to the pre-softmax multi-headed attention logits,
108 | which has shape [batch_size, num_heads, length, length]. The tensor is zero at
109 | non-padding locations, and -1e9 (negative infinity) at padding locations.
110 |
111 | Args:
112 | x: int tensor with shape [batch_size, length]
113 |
114 | Returns:
115 | Attention bias tensor of shape [batch_size, 1, 1, length].
116 | """
117 | with tf.name_scope("attention_bias"):
118 | padding = get_padding(x)
119 | attention_bias = padding * _NEG_INF_FP32
120 | attention_bias = tf.expand_dims(
121 | tf.expand_dims(attention_bias, axis=1), axis=1)
122 | return attention_bias
123 |
--------------------------------------------------------------------------------
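
Illustrative note (not part of the repository): a small NumPy re-derivation of get_position_encoding above, mirroring the TF ops one-to-one, including the reduced max_timescale=1.0e2 used in this fork:

import math
import numpy as np

def position_encoding(length, hidden_size, min_timescale=1.0, max_timescale=1.0e2):
    position = np.arange(length, dtype=np.float32)
    num_timescales = hidden_size // 2
    log_increment = math.log(max_timescale / min_timescale) / (num_timescales - 1)
    inv_timescales = min_timescale * np.exp(np.arange(num_timescales) * -log_increment)
    scaled_time = position[:, None] * inv_timescales[None, :]
    # First half of each vector is sine, second half is cosine.
    return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)

print(position_encoding(4, 8).shape)  # (4, 8)
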
/official_transformer/tpu.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Functions specific to running TensorFlow on TPUs."""
17 |
18 | import tensorflow as tf
19 |
20 |
21 | # "local" is a magic word in the TPU cluster resolver; it informs the resolver
22 | # to use the local CPU as the compute device. This is useful for testing and
23 | # debugging; the code flow is ostensibly identical, but without the need to
24 | # actually have a TPU on the other end.
25 | LOCAL = "local"
26 |
27 |
28 | def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
29 | """Construct a host call to log scalars when training on TPU.
30 |
31 | Args:
32 | metric_dict: A dict of the tensors to be logged.
33 | model_dir: The location to write the summary.
34 | prefix: The prefix (if any) to prepend to the metric names.
35 |
36 | Returns:
37 | A tuple of (function, args_to_be_passed_to_said_function)
38 | """
39 | # type: (dict, str) -> (function, list)
40 | metric_names = list(metric_dict.keys())
41 |
42 | def host_call_fn(global_step, *args):
43 | """Training host call. Creates scalar summaries for training metrics.
44 |
45 | This function is executed on the CPU and should not directly reference
46 | any Tensors in the rest of the `model_fn`. To pass Tensors from the
47 | model to the `metric_fn`, provide as part of the `host_call`. See
48 | https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
49 | for more information.
50 |
51 | Arguments should match the list of `Tensor` objects passed as the second
52 | element in the tuple passed to `host_call`.
53 |
54 | Args:
55 | global_step: `Tensor` with shape `[batch]` for the global_step.
56 | *args: Remaining tensors to log.
57 |
58 | Returns:
59 | List of summary ops to run on the CPU host.
60 | """
61 | step = global_step[0]
62 | with tf.contrib.summary.create_file_writer(
63 | logdir=model_dir, filename_suffix=".host_call").as_default():
64 | with tf.contrib.summary.always_record_summaries():
65 | for i, name in enumerate(metric_names):
66 | tf.contrib.summary.scalar(prefix + name, args[i][0], step=step)
67 |
68 | return tf.contrib.summary.all_summary_ops()
69 |
70 | # To log the current learning rate, and gradient norm for Tensorboard, the
71 | # summary op needs to be run on the host CPU via host_call. host_call
72 | # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
73 | # dimension. These Tensors are implicitly concatenated to
74 | # [params['batch_size']].
75 | global_step_tensor = tf.reshape(
76 | tf.compat.v1.train.get_or_create_global_step(), [1])
77 | other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]
78 |
79 | return host_call_fn, [global_step_tensor] + other_tensors
80 |
81 |
82 | def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"):
83 | """Performs embedding lookup via a matmul.
84 |
85 | The matrix to be multiplied by the embedding table Tensor is constructed
86 | via an implementation of scatter based on broadcasting embedding indices
87 | and performing an equality comparison against a broadcasted
88 | range(num_embedding_table_rows). All masked positions will produce an
89 | embedding vector of zeros.
90 |
91 | Args:
92 | embedding_table: Tensor of embedding table.
93 | Rank 2 (table_size x embedding dim)
94 | values: Tensor of embedding indices. Rank 2 (batch x n_indices)
95 | mask: Tensor of mask / weights. Rank 2 (batch x n_indices)
96 | name: Optional name scope for created ops
97 |
98 | Returns:
99 | Rank 3 tensor of embedding vectors.
100 | """
101 |
102 | with tf.name_scope(name):
103 | n_embeddings = embedding_table.get_shape().as_list()[0]
104 | batch_size, padded_size = values.shape.as_list()
105 |
106 | emb_idcs = tf.tile(
107 | tf.reshape(values, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
108 | emb_weights = tf.tile(
109 | tf.reshape(mask, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
110 | col_idcs = tf.tile(
111 | tf.reshape(tf.range(n_embeddings), (1, 1, n_embeddings)),
112 | (batch_size, padded_size, 1))
113 | one_hot = tf.where(
114 | tf.equal(emb_idcs, col_idcs), emb_weights,
115 | tf.zeros((batch_size, padded_size, n_embeddings)))
116 |
117 | return tf.tensordot(one_hot, embedding_table, 1)
118 |
--------------------------------------------------------------------------------
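
Illustrative note (not part of the repository): a NumPy sketch of the broadcast-and-compare trick embedding_matmul uses to build its one-hot matrix before multiplying by the embedding table:

import numpy as np

table = np.arange(12, dtype=np.float32).reshape(4, 3)   # 4 embeddings of size 3
values = np.array([[2, 0, 3]])                          # [batch=1, n_indices=3]
mask = np.array([[1.0, 0.0, 1.0]])                      # second position is masked

n = table.shape[0]
emb_idcs = np.tile(values[:, :, None], (1, 1, n))                  # broadcast the indices
col_idcs = np.tile(np.arange(n)[None, None, :], values.shape + (1,))
one_hot = np.where(emb_idcs == col_idcs, mask[:, :, None], 0.0)    # masked rows become all zeros

print(np.tensordot(one_hot, table, 1))   # masked position yields a zero embedding vector
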
/start.sh:
--------------------------------------------------------------------------------
1 | # Intent recognition and slot-filling model
2 | start_tm=`date +%s%N`;
3 |
4 | export HOST_NAME=$1
5 | if [[ "wzk" == "$HOST_NAME" ]]
6 | then
7 | # set gpu id to use
8 | export CUDA_VISIBLE_DEVICES=0
9 | else
10 | # not use gpu
11 | export CUDA_VISIBLE_DEVICES=""
12 | fi
13 |
14 | ### Optional parameters ###
15 | # To quickly test that model training works, set the number of epochs to a
16 | # smaller value (e.g. 0.01).
17 | export DOMAIN_NAME="music"
18 | export input_format="nlu"
19 | export num_train_epochs=5
20 | export TRAIN_BATCH_SIZE=100
21 | export learning_rate=1e-4
22 | export warmup_proportion=0.1
23 | export max_seq_length=45
24 | export drop_keep_prob=0.9
25 | export MAX_INPUT_EXAMPLES=1000000
26 | export SAVE_CHECKPOINT_STEPS=1000
27 | export CORPUS_DIR="/home/${HOST_NAME}/Mywork/corpus/compe/69"
28 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/chinese_L-12_H-768_A-12"
29 | export CONFIG_FILE=configs/lasertagger_config.json
30 | export OUTPUT_DIR="${CORPUS_DIR}/${DOMAIN_NAME}_output"
31 | export MODEL_DIR="${OUTPUT_DIR}/${DOMAIN_NAME}_models"
32 | export do_lower_case=true
33 | export kernel_size=3
34 | export label_map_file=${OUTPUT_DIR}/label_map.json
35 | export slot_label_map_file=${OUTPUT_DIR}/slot_label_map.json
36 | export SUBMIT_FILE=${MODEL_DIR}/submit.csv
37 | export entity_type_list_file=${OUTPUT_DIR}/entity_type_list.json
38 |
39 | # Check these numbers from the "*.num_examples" files created in step 2.
40 | export NUM_TRAIN_EXAMPLES=300000
41 | export NUM_EVAL_EXAMPLES=5000
42 |
43 | python preprocess_main.py \
44 | --input_file=${CORPUS_DIR}/train.txt \
45 | --input_format=${input_format} \
46 | --output_tfrecord_train=${OUTPUT_DIR}/train.tf_record \
47 | --output_tfrecord_dev=${OUTPUT_DIR}/dev.tf_record \
48 | --label_map_file=${label_map_file} \
49 | --slot_label_map_file=${slot_label_map_file} \
50 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \
51 | --max_seq_length=${max_seq_length} \
52 | --do_lower_case=${do_lower_case} \
53 | --domain_name=${DOMAIN_NAME} \
54 | --entity_type_list_file=${entity_type_list_file}
55 |
56 |
57 |
58 | echo "Train the model."
59 | python run_lasertagger.py \
60 | --training_file=${OUTPUT_DIR}/train.tf_record \
61 | --eval_file=${OUTPUT_DIR}/dev.tf_record \
62 | --label_map_file=${label_map_file} \
63 | --slot_label_map_file=${slot_label_map_file} \
64 | --model_config_file=${CONFIG_FILE} \
65 | --output_dir=${MODEL_DIR} \
66 | --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \
67 | --do_train=true \
68 | --do_eval=true \
69 | --num_train_epochs=${num_train_epochs} \
70 | --train_batch_size=${TRAIN_BATCH_SIZE} \
71 | --learning_rate=${learning_rate} \
72 | --warmup_proportion=${warmup_proportion} \
73 | --drop_keep_prob=${drop_keep_prob} \
74 | --kernel_size=${kernel_size} \
75 | --save_checkpoints_steps=${SAVE_CHECKPOINT_STEPS} \
76 | --max_seq_length=${max_seq_length} \
77 | --num_train_examples=${NUM_TRAIN_EXAMPLES} \
78 | --num_eval_examples=${NUM_EVAL_EXAMPLES} \
79 | --domain_name=${DOMAIN_NAME} \
80 | --entity_type_list_file=${entity_type_list_file}
81 |
82 |
83 | ## 4. Prediction
84 |
85 | ### Export the model.
86 | echo "Export the model."
87 | python run_lasertagger.py \
88 | --label_map_file=${label_map_file} \
89 | --slot_label_map_file=${slot_label_map_file} \
90 | --model_config_file=${CONFIG_FILE} \
91 | --max_seq_length=${max_seq_length} \
92 | --kernel_size=${kernel_size} \
93 | --output_dir=${MODEL_DIR} \
94 | --do_export=true \
95 | --export_path="${MODEL_DIR}/export" \
96 | --domain_name=${DOMAIN_NAME} \
97 | --entity_type_list_file=${entity_type_list_file}
98 |
99 | ######### Get the most recently exported model directory.
100 | TIMESTAMP=$(ls "${MODEL_DIR}/export/" | \
101 | grep -v "temp-" | sort -r | head -1)
102 | SAVED_MODEL_DIR=${MODEL_DIR}/export/${TIMESTAMP}
103 | PREDICTION_FILE=${MODEL_DIR}/pred.tsv
104 |
105 |
106 | echo "predict_main.py for eval"
107 | python predict_main.py \
108 | --input_file=${OUTPUT_DIR}/pred.tsv \
109 | --input_format=${input_format} \
110 | --output_file=${PREDICTION_FILE} \
111 | --label_map_file=${label_map_file} \
112 | --slot_label_map_file=${slot_label_map_file} \
113 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \
114 | --max_seq_length=${max_seq_length} \
115 | --do_lower_case=${do_lower_case} \
116 | --saved_model=${SAVED_MODEL_DIR} \
117 | --domain_name=${DOMAIN_NAME} \
118 | --entity_type_list_file=${entity_type_list_file}
119 |
120 | ####### 5. Evaluation
121 | echo "python score_main.py --prediction_file=" ${PREDICTION_FILE}
122 | python score_main.py --prediction_file=${PREDICTION_FILE} --vocab_file=${BERT_BASE_DIR}/vocab.txt --do_lower_case=true --domain_name=${DOMAIN_NAME}
123 |
124 |
125 | #echo "predict_main.py for test"
126 | #python predict_main.py \
127 | # --input_file=${OUTPUT_DIR}/submit.csv \
128 | # --input_format=${input_format} \
129 | # --output_file=${PREDICTION_FILE} \
130 | # --submit_file=${SUBMIT_FILE} \
131 | # --label_map_file=${label_map_file} \
132 | # --slot_label_map_file=${slot_label_map_file} \
133 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \
134 | # --max_seq_length=${max_seq_length} \
135 | # --do_lower_case=${do_lower_case} \
136 | # --saved_model=${SAVED_MODEL_DIR} \
137 | # --domain_name=${DOMAIN_NAME} \
138 | # --entity_type_list_file=${entity_type_list_file}
139 |
140 | end_tm=`date +%s%N`;
141 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'`
142 | echo "cost time" $use_tm "h"
143 |
--------------------------------------------------------------------------------
/score_lib.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Utility functions for computing evaluation metrics."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | from __future__ import print_function
23 |
24 | import re
25 | from typing import List, Text, Tuple
26 |
27 | import sari_hook
28 | import utils
29 |
30 | import tensorflow as tf
31 | from curLine_file import normal_transformer, curLine
32 |
33 | def read_data(
34 | path,
35 | lowercase,
36 | target_domain_name):
37 | """Reads data from prediction TSV file.
38 |
39 | The prediction file is a TSV with 8 columns:
40 | sessionId, query, predicted domain, predicted intent, predicted slots,
41 | gold domain, gold intent, gold slots.
42 |
43 | Args:
44 | path: Path to the prediction file.
45 | lowercase: Whether to lowercase the data (to compute case insensitive
46 | scores).
47 | target_domain_name: Domain whose slot errors are printed for inspection.
48 |
49 | Returns:
50 | Tuple (predicted domains, predicted intents, gold domains, right_intent_num, right_slot_num, exact_num).
51 | """
52 | sources = []
53 | predDomain_list = []
54 | predIntent_list = []
55 | domain_list = []
56 | right_intent_num = 0
57 | right_slot_num = 0
58 | exact_num = 0
59 | with tf.gfile.GFile(path) as f:
60 | for lineId, line in enumerate(f):
61 | if "sessionId" in line and "pred" in line:
62 | continue
63 | sessionId, query, predDomain,predIntent,predSlot, domain,intent,Slot = line.rstrip('\n').split('\t')
64 | # if target_domain_name == predDomain and target_domain_name != domain: # the domain is wrong
65 | # print(curLine(), lineId, predSlot, "Slot:", Slot, "predDomain:%s, domain:%s" % (predDomain, domain), "predIntent:",
66 | # predIntent)
67 | if predIntent == intent:
68 | right_intent_num += 1
69 | if predSlot == Slot:
70 | exact_num += 1
71 |
72 | if predSlot == Slot:
73 | right_slot_num += 1
74 | else:
75 | if target_domain_name == predDomain and target_domain_name == domain: #
76 | print(curLine(), predSlot, "Slot:", Slot, "predDomain:%s, domain:%s" % (predDomain, domain), "predIntent:", predIntent)
77 | predDomain_list.append(predDomain)
78 | predIntent_list.append(predIntent)
79 | domain_list.append(domain)
80 | return predDomain_list, predIntent_list, domain_list, right_intent_num, right_slot_num, exact_num
81 |
82 |
83 | def compute_exact_score(predictions, domain_list):
84 | """Computes the Exact score (accuracy) of the predictions.
85 |
86 | Exact score is defined as the fraction of predicted domains that match the
87 | gold domain.
88 |
89 | Args:
90 | predictions: List of predictions.
91 | domain_list: List of gold domains (one per prediction).
92 |
93 | Returns:
94 | Exact score between [0, 1].
95 | """
96 |
97 | correct_domain = 0
98 | for p, d in zip(predictions, domain_list):
99 | if p==d:
100 | correct_domain += 1
101 | return correct_domain / max(len(predictions), 0.1) # Avoids 0/0.
102 |
103 |
104 | def compute_sari_scores(
105 | sources,
106 | predictions,
107 | target_lists,
108 | ignore_wikisplit_separators = True,
109 | tokenizer=None):
110 | """Computes SARI scores.
111 |
112 | Wraps the t2t implementation of SARI computation.
113 |
114 | Args:
115 | sources: List of sources.
116 | predictions: List of predictions.
117 | target_lists: List of targets (1 or more per prediction).
118 | ignore_wikisplit_separators: Whether to ignore "<::::>" tokens, used as
119 | sentence separators in Wikisplit, when evaluating. For the numbers
120 | reported in the paper, we accidentally ignored those tokens. Ignoring them
121 | does not affect the Exact score (since there's almost always a period
122 | before the separator to indicate sentence break), but it decreases the
123 | SARI score (since the Addition score goes down as the model doesn't get
124 | points for correctly adding <::::> anymore).
125 |
126 | Returns:
127 | Tuple (SARI score, keep score, addition score, deletion score, average prediction length, max prediction length).
128 | """
129 | sari_sum = 0
130 | keep_sum = 0
131 | add_sum = 0
132 | del_sum = 0
133 | length_sum = 0
134 | length_max = 0
135 | for source, pred, targets in zip(sources, predictions, target_lists):
136 | if ignore_wikisplit_separators:
137 | source = re.sub(' <::::> ', ' ', source)
138 | pred = re.sub(' <::::> ', ' ', pred)
139 | targets = [re.sub(' <::::> ', ' ', t) for t in targets]
140 | source_ids = tokenizer.tokenize(source) # utils.get_token_list(source)
141 | pred_ids = tokenizer.tokenize(pred) # utils.get_token_list(pred)
142 | list_of_targets = [tokenizer.tokenize(t) for t in targets]
143 | length_sum += len(pred_ids)
144 | length_max = max(length_max, len(pred_ids))
145 | sari, keep, addition, deletion = sari_hook.get_sari_score(
146 | source_ids, pred_ids, list_of_targets, beta_for_deletion=1)
147 | sari_sum += sari
148 | keep_sum += keep
149 | add_sum += addition
150 | del_sum += deletion
151 | n = max(len(sources), 0.1) # Avoids 0/0.
152 | return (sari_sum / n, keep_sum / n, add_sum / n, del_sum / n, length_sum/n, length_max)
--------------------------------------------------------------------------------
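
Illustrative note (not part of the repository): compute_exact_score above is a plain domain accuracy. A toy example of the same arithmetic:

pred_domains = ["music", "music", "navigation"]
gold_domains = ["music", "navigation", "navigation"]

# Fraction of positions where the predicted domain equals the gold domain (2 of 3 here),
# which is what compute_exact_score(pred_domains, gold_domains) would return.
acc = sum(p == d for p, d in zip(pred_domains, gold_domains)) / len(pred_domains)
print(acc)  # 0.666...
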
/find_entity/exacter_acmation.py:
--------------------------------------------------------------------------------
1 | # Find candidate entities to assist prediction
2 | import os
3 | from collections import defaultdict
4 | import re
5 | import json
6 | from .acmation import KeywordTree, add_to_ac, entity_files_folder, entity_folder
7 | from curLine_file import curLine, normal_transformer
8 |
9 | re_phoneNum = re.compile("[0-9一二三四五六七八九十拾]+") # compiled regex for phone numbers
10 |
11 | # Aho-Corasick automaton, similar to a trie tree
12 | # Reading the downloaded xls files directly might be more convenient, but that would require the xlrd module
13 |
14 | domain2entity_map = {}
15 | domain2entity_map["music"] = ["age", "singer", "song", "toplist", "theme", "style", "scene", "language", "emotion", "instrument"]
16 | # domain2entity_map["navigation"] = ["custom_destination", "city"] # "destination", "origin"]
17 | # domain2entity_map["phone_call"] = ["phone_num", "contact_name"]
18 | self_entity_trie_tree = {} # overall entity dictionary: self-built entity trees for certain entity types
19 | for domain, entity_type_list in domain2entity_map.items():
20 | print(curLine(), domain, entity_type_list)
21 | for entity_type in entity_type_list:
22 | if entity_type not in self_entity_trie_tree:
23 | ac = KeywordTree(case_insensitive=True)
24 | else:
25 | ac = self_entity_trie_tree[entity_type]
26 |
27 | ### Entities mined from the annotated corpus
28 | entity_file = os.path.join(entity_files_folder, "%s.json" % entity_type)
29 | with open(entity_file, "r") as fr:
30 | current_entity_dict = json.load(fr)
31 | print(curLine(), "get %d %s from %s" % (len(current_entity_dict), entity_type, entity_file))
32 | for entity_before, entity_after_times in current_entity_dict.items():
33 | entity_after = entity_after_times[0]
34 | pri = 2
35 | if entity_type in ["song"]:
36 | pri -= 0.5
37 | add_to_ac(ac, entity_type, entity_before, entity_after, pri=pri)
38 | if entity_type == "song":
39 | add_to_ac(ac, entity_type, "花的画", "花的话", 1.5)
40 |
41 | # Entities from the provided entity library get the highest priority
42 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type)
43 | if os.path.exists(entity_file):
44 | with open(entity_file, "r") as fr:
45 | lines = fr.readlines()
46 | print(curLine(), "get %d %s from %s" % (len(lines), entity_type, entity_file))
47 | for line in lines:
48 | entity_after = line.strip()
49 | entity_before = entity_after # TODO
50 | pri = 3
51 | if entity_type in ["song"]:
52 | pri -= 0.5
53 | add_to_ac(ac, entity_type, entity_before, entity_after, pri=pri)
54 | ac.finalize()
55 | self_entity_trie_tree[entity_type] = ac
56 |
57 |
58 | def get_all_entity(corpus, useEntityTypeList):
59 | self_entityTypeMap = defaultdict(list)
60 | for entity_type in useEntityTypeList:
61 | result = self_entity_trie_tree[entity_type].search(corpus)
62 | for res in result:
63 | after, priority = res.meta_data
64 | self_entityTypeMap[entity_type].append({'before': res.keyword, 'after': after, "priority":priority})
65 | if "phone_num" in useEntityTypeList:
66 | token_numbers = re_phoneNum.findall(corpus)
67 | for number in token_numbers:
68 | self_entityTypeMap["phone_num"].append({'before':number, 'after':number, 'priority': 2})
69 | return self_entityTypeMap
70 |
71 |
72 | def get_slot_info(query, domain):
73 | useEntityTypeList = domain2entity_map[domain]
74 | entityTypeMap = get_all_entity(query, useEntityTypeList=useEntityTypeList)
75 | entity_list_all = [] # collect all entities
76 | for entity_type, entity_list in entityTypeMap.items():
77 | for entity in entity_list:
78 | entity_before = entity['before']
79 | ignore_flag = False
80 | if entity_type != "song" and len(entity_before) < 2 and entity_before not in ["家","妈"]:
81 | ignore_flag = True
82 | if entity_type == "song" and len(entity_before) < 2 and \
83 | entity_before not in {"鱼", "云", "逃", "退", "陶", "美", "图", "默"}:
84 | ignore_flag = True
85 | if entity_before in {"什么歌", "一首", "小花","叮当","傻逼", "给你听", "现在", "当我"}:
86 | ignore_flag = True
87 | if ignore_flag:
88 | if entity_before not in "好点没走伤":
89 | print(curLine(), "ignore entity_type:%s, entity:%s, query:%s"
90 | % (entity_type, entity_before, query))
91 | else:
92 | entity_list_all.append((entity_type, entity_before, entity['after'], entity['priority']))
93 | entity_list_all = sorted(entity_list_all, key=lambda item: len(item[1])*100+item[3],
94 | reverse=True) # sort by entity length first, then priority, longest first
95 | slot_info = query
96 | exist_entityType_set = set()
97 | replace_mask = [0] * len(query)
98 | for entity_type, entity_before, entity_after, priority in entity_list_all:
99 | if entity_before not in query:
100 | continue
101 | if entity_type in exist_entityType_set:
102 | continue # this entity type is already filled, ignore # TODO
103 | start_location = slot_info.find(entity_before)
104 | if start_location > -1:
105 | exist_entityType_set.add(entity_type)
106 | if entity_after == entity_before:
107 | entity_info_str = "<%s>%s</%s>" % (entity_type, entity_after, entity_type)
108 | else:
109 | entity_info_str = "<%s>%s||%s</%s>" % (entity_type, entity_before, entity_after, entity_type)
110 | slot_info = slot_info.replace(entity_before, entity_info_str)
111 | query = query.replace(entity_before, "")
112 | else:
113 | print(curLine(), replace_mask, slot_info, "entity_type:", entity_type, entity_before)
114 | return slot_info
115 |
116 | if __name__ == '__main__':
117 | for query in ["拨打10086", "打电话给100十五", "打电话给一二三拾"]:
118 | res = get_slot_info(query, domain="phone_call")
119 | print(curLine(), query, res)
120 |
121 | for query in ["节奏来一首一朵鲜花送给心上人", "陈瑞的歌曲", "放一首销愁"]:
122 | res = get_slot_info(query, domain="music")
123 | print(curLine(), query, res)
--------------------------------------------------------------------------------
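
Usage sketch (hypothetical outputs, since they depend on the entity dictionaries loaded at import time): get_slot_info wraps each matched span in <type>...</type> tags and appends the canonical form after "||" when it differs from the surface form, e.g.:

from find_entity.exacter_acmation import get_slot_info

# Expected shape of the results (illustrative only):
#   "放一首销愁"  -> "放一首<song>销愁||消愁</song>"   (surface form corrected via the mined dictionary)
#   "陈瑞的歌曲"  -> "<singer>陈瑞</singer>的歌曲"     (already canonical, so no "||" part)
for query in ["放一首销愁", "陈瑞的歌曲"]:
    print(get_slot_info(query, domain="music"))
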
/official_transformer/attention_layer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implementation of multiheaded attention and self-attention layers."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 |
25 | class Attention(tf.layers.Layer):
26 | """Multi-headed attention layer."""
27 |
28 | def __init__(self, hidden_size, num_heads, attention_dropout, train):
29 | if hidden_size % num_heads != 0:
30 | raise ValueError("Hidden size must be evenly divisible by the number of "
31 | "heads.")
32 |
33 | super(Attention, self).__init__()
34 | self.hidden_size = hidden_size
35 | self.num_heads = num_heads
36 | self.attention_dropout = attention_dropout
37 | self.train = train
38 |
39 | # Layers for linearly projecting the queries, keys, and values.
40 | self.q_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="q")
41 | self.k_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="k")
42 | self.v_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="v")
43 |
44 | self.output_dense_layer = tf.layers.Dense(hidden_size, use_bias=False,
45 | name="output_transform")
46 |
47 | def split_heads(self, x):
48 | """Split x into different heads, and transpose the resulting value.
49 |
50 | The tensor is transposed to ensure the inner dimensions hold the correct
51 | values during the matrix multiplication.
52 |
53 | Args:
54 | x: A tensor with shape [batch_size, length, hidden_size]
55 |
56 | Returns:
57 | A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
58 | """
59 | with tf.name_scope("split_heads"):
60 | batch_size = tf.shape(x)[0]
61 | length = tf.shape(x)[1]
62 |
63 | # Calculate depth of last dimension after it has been split.
64 | depth = (self.hidden_size // self.num_heads)
65 |
66 | # Split the last dimension
67 | x = tf.reshape(x, [batch_size, length, self.num_heads, depth])
68 |
69 | # Transpose the result
70 | return tf.transpose(x, [0, 2, 1, 3])
71 |
72 | def combine_heads(self, x):
73 | """Combine tensor that has been split.
74 |
75 | Args:
76 | x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]
77 |
78 | Returns:
79 | A tensor with shape [batch_size, length, hidden_size]
80 | """
81 | with tf.name_scope("combine_heads"):
82 | batch_size = tf.shape(x)[0]
83 | length = tf.shape(x)[2]
84 | x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth]
85 | return tf.reshape(x, [batch_size, length, self.hidden_size])
86 |
87 | def call(self, x, y, bias, cache=None):
88 | """Apply attention mechanism to x and y.
89 |
90 | Args:
91 | x: a tensor with shape [batch_size, length_x, hidden_size]
92 | y: a tensor with shape [batch_size, length_y, hidden_size]
93 | bias: attention bias that will be added to the result of the dot product.
94 | cache: (Used during prediction) dictionary with tensors containing results
95 | of previous attentions. The dictionary must have the items:
96 | {"k": tensor with shape [batch_size, i, key_channels],
97 | "v": tensor with shape [batch_size, i, value_channels]}
98 | where i is the current decoded length.
99 |
100 | Returns:
101 | Attention layer output with shape [batch_size, length_x, hidden_size]
102 | """
103 | # Linearly project the query (q), key (k) and value (v) using different
104 | # learned projections. This is in preparation of splitting them into
105 | # multiple heads. Multi-head attention uses multiple queries, keys, and
106 | # values rather than regular attention (which uses a single q, k, v).
107 | q = self.q_dense_layer(x)
108 | k = self.k_dense_layer(y)
109 | v = self.v_dense_layer(y)
110 |
111 | if cache is not None:
112 | # Combine cached keys and values with new keys and values.
113 | k = tf.concat([cache["k"], k], axis=1)
114 | v = tf.concat([cache["v"], v], axis=1)
115 |
116 | # Update cache
117 | cache["k"] = k
118 | cache["v"] = v
119 |
120 | # Split q, k, v into heads.
121 | q = self.split_heads(q)
122 | k = self.split_heads(k)
123 | v = self.split_heads(v)
124 |
125 | # Scale q to prevent the dot product between q and k from growing too large.
126 | depth = (self.hidden_size // self.num_heads)
127 | q *= depth ** -0.5
128 |
129 | # Calculate dot product attention
130 | logits = tf.matmul(q, k, transpose_b=True)
131 | logits += bias
132 | weights = tf.nn.softmax(logits, name="attention_weights")
133 | if self.train:
134 | weights = tf.nn.dropout(weights, 1.0 - self.attention_dropout)
135 | attention_output = tf.matmul(weights, v)
136 |
137 | # Recombine heads --> [batch_size, length, hidden_size]
138 | attention_output = self.combine_heads(attention_output)
139 |
140 | # Run the combined outputs through another linear projection layer.
141 | attention_output = self.output_dense_layer(attention_output)
142 | return attention_output
143 |
144 |
145 | class SelfAttention(Attention):
146 | """Multiheaded self-attention layer."""
147 |
148 | def call(self, x, bias, cache=None):
149 | return super(SelfAttention, self).call(x, x, bias, cache)
150 |
--------------------------------------------------------------------------------
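
Illustrative note (not part of the repository): a NumPy sketch of the reshape/transpose round trip performed by split_heads and combine_heads above:

import numpy as np

batch, length, hidden, heads = 2, 5, 8, 4
depth = hidden // heads
x = np.random.randn(batch, length, hidden)

# split_heads: [batch, length, hidden] -> [batch, heads, length, depth]
split = x.reshape(batch, length, heads, depth).transpose(0, 2, 1, 3)

# combine_heads: inverse transpose + reshape restores the original tensor
combined = split.transpose(0, 2, 1, 3).reshape(batch, length, hidden)
assert np.array_equal(x, combined)
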
/preprocess_main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Convert a dataset into the TFRecord format.
18 |
19 | The resulting TFRecord file will be used when training a LaserTagger model.
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 |
25 | from __future__ import print_function
26 |
27 | from typing import Text
28 |
29 | from absl import app
30 | from absl import flags
31 | from absl import logging
32 | from bert import tokenization
33 | import bert_example
34 | import json
35 | import utils
36 |
37 | import tensorflow as tf
38 | from find_entity.probable_acmation import domain2entity_map, get_all_entity
39 | from curLine_file import curLine
40 |
41 |
42 | FLAGS = flags.FLAGS
43 |
44 | flags.DEFINE_string(
45 | 'input_file', None,
46 | 'Path to the input file containing examples to be converted to '
47 | 'tf.Examples.')
48 | flags.DEFINE_enum(
49 | 'input_format', None, ['nlu'],
50 | 'Format which indicates how to parse the input_file.')
51 | flags.DEFINE_string('output_tfrecord_train', None,
52 | 'Path to the resulting TFRecord file.')
53 | flags.DEFINE_string('output_tfrecord_dev', None,
54 | 'Path to the resulting TFRecord file.')
55 | flags.DEFINE_string(
56 | 'label_map_file', None,
57 | 'Path to the label map file. Either a JSON file ending with ".json", that '
58 | 'maps each possible tag to an ID, or a text file that has one tag per line.')
59 | flags.DEFINE_string(
60 | 'slot_label_map_file', None,
61 | 'Path to the slot label map file. Either a JSON file ending with ".json", that '
62 | 'maps each possible tag to an ID, or a text file that has one tag per line.')
63 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
64 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
65 | flags.DEFINE_bool(
66 | 'do_lower_case', False,
67 | 'Whether to lower case the input text. Should be True for uncased '
68 | 'models and False for cased models.')
69 | flags.DEFINE_string(
70 | 'domain_name', None, 'domain_name')
71 | flags.DEFINE_string(
72 | 'entity_type_list_file', None, 'path of entity_type_list_file')
73 |
74 | def _write_example_count(count: int, output_file: str) -> Text:
75 | """Saves the number of converted examples to a file.
76 |
77 | This count is used when determining the number of training steps.
78 |
79 | Args:
80 | count: The number of converted examples.
81 | output_file: Path of the output file the count refers to.
82 | Returns:
83 | The filename to which the count is saved.
84 | """
85 | count_fname = output_file + '.num_examples.txt'
86 | with tf.gfile.GFile(count_fname, 'w') as count_writer:
87 | count_writer.write(str(count))
88 | return count_fname
89 |
90 |
91 |
92 | def main(argv):
93 | if len(argv) > 1:
94 | raise app.UsageError('Too many command-line arguments.')
95 | flags.mark_flag_as_required('input_file')
96 | flags.mark_flag_as_required('input_format')
97 | flags.mark_flag_as_required('output_tfrecord_train')
98 | flags.mark_flag_as_required('output_tfrecord_dev')
99 | flags.mark_flag_as_required('vocab_file')
100 | target_domain_name = FLAGS.domain_name
101 | entity_type_list = domain2entity_map[target_domain_name]
102 |
103 | print(curLine(), "target_domain_name:", target_domain_name, len(entity_type_list), "entity_type_list:", entity_type_list)
104 | builder = bert_example.BertExampleBuilder({}, FLAGS.vocab_file,
105 | FLAGS.max_seq_length,
106 | FLAGS.do_lower_case, slot_label_map={},
107 | entity_type_list=entity_type_list, get_entity_func=get_all_entity)
108 |
109 | num_converted = 0
110 | num_ignored = 0
111 | # ff = open("new_%s.txt" % target_domain_name, "w") # TODO
112 | with tf.python_io.TFRecordWriter(FLAGS.output_tfrecord_train) as writer_train:
113 | for i, (sources, target) in enumerate(utils.yield_sources_and_targets(
114 | FLAGS.input_file, FLAGS.input_format, target_domain_name)):
115 | logging.log_every_n(
116 | logging.INFO,
117 | f'{i} examples processed, {num_converted} converted to tf.Example.',
118 | 10000)
119 | if len(sources[0]) > 35: # skip examples whose query is too long
120 | num_ignored += 1
121 | print(curLine(), "ignore num_ignored=%d, question length=%d" % (num_ignored, len(sources[0])))
122 | continue
123 | example1, _, info_str = builder.build_bert_example(sources, target)
124 | example =example1.to_tf_example().SerializeToString()
125 | writer_train.write(example)
126 | num_converted += 1
127 | # ff.write("%d %s\n" % (i, info_str))
128 | logging.info(f'Done. {num_converted} examples converted to tf.Example, num_ignored {num_ignored} examples.')
129 | for output_file in [FLAGS.output_tfrecord_train]:
130 | count_fname = _write_example_count(num_converted, output_file=output_file)
131 | logging.info(f'Wrote:\n{output_file}\n{count_fname}')
132 | with open(FLAGS.label_map_file, "w") as f:
133 | json.dump(builder._label_map, f, ensure_ascii=False, indent=4)
134 | print(curLine(), "save %d to %s" % (len(builder._label_map), FLAGS.label_map_file))
135 | with open(FLAGS.slot_label_map_file, "w") as f:
136 | json.dump(builder.slot_label_map, f, ensure_ascii=False, indent=4)
137 | print(curLine(), "save %d to %s" % (len(builder.slot_label_map), FLAGS.slot_label_map_file))
138 |
139 | with open(FLAGS.entity_type_list_file, "w") as f:
140 | json.dump(domain2entity_map, f, ensure_ascii=False, indent=4)
141 | print(curLine(), "save %d to %s" % (len(domain2entity_map), FLAGS.entity_type_list_file))
142 |
143 |
144 | if __name__ == '__main__':
145 | app.run(main)
--------------------------------------------------------------------------------
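
Illustrative note (hypothetical numbers): the ".num_examples.txt" count written above feeds NUM_TRAIN_EXAMPLES in start.sh, which together with the batch size and epoch count determines the number of training steps, roughly:

num_train_examples = 300000   # value checked from train.tf_record.num_examples.txt
train_batch_size = 100        # TRAIN_BATCH_SIZE in start.sh
num_train_epochs = 5

steps_per_epoch = num_train_examples // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
print(steps_per_epoch, num_train_steps)   # 3000 15000
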
/confusion_words.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Confusion-word extraction; heteronyms are not considered, only homophones are handled
3 | from pypinyin import lazy_pinyin, pinyin, Style
4 | # pinyin() returns tones by default, while lazy_pinyin() drops tones but may return uncommon heteronym readings
5 | # pinyin(c, heteronym=True, style=0) lists heteronym readings without tones
6 | from Levenshtein import distance # python-Levenshtein
7 | # Edit distance (Levenshtein distance algorithm)
8 | import os
9 | from curLine_file import curLine
10 | from find_entity.acmation import entity_folder
11 |
12 | number_map = {"0":"零", "1":"一", "2":"二", "3":"三", "4":"四", "5":"五", "6":"六", "7":"七", "8":"八", "9":"九"}
13 | cons_table = {'n': 'l', 'l': 'n', 'f': 'h', 'h': 'f', 'zh': 'z', 'z': 'zh', 'c': 'ch', 'ch': 'c', 's': 'sh', 'sh': 's'}
14 | # def get_similar_pinyin(entity_type): # save the pinyin variants directly when storing
15 | # entity_info_dict = {}
16 | # entity_file = os.path.join(entity_folder, "%s.txt" % entity_type)
17 | # with open(entity_file, "r") as fr:
18 | # lines = fr.readlines()
19 | # pri = 3
20 | # if entity_type in ["song"]:
21 | # pri -= 0.5
22 | # print(curLine(), "get %d %s from %s, pri=%f" % (len(lines), entity_type, entity_file, pri))
23 | # for line in lines:
24 | # entity = line.strip()
25 | # for k,v in number_map.items():
26 | # entity.replace(k, v)
27 | # # for combination in all_combination:
28 | # if entity not in entity_info_dict: # new entity
29 | # combination = "".join(lazy_pinyin(entity)) # default: leave unconvertible characters as-is; alternative: errors="ignore"
30 | # if len(combination) < 2:
31 | # print(curLine(), "warning:", entity, "combination:", combination)
32 | # entity_info_dict[entity] = (combination, pri)
33 |
34 | def get_entityType_pinyin(entity_type):
35 | entity_info_dict = {}
36 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type)
37 | with open(entity_file, "r") as fr:
38 | lines = fr.readlines()
39 | pri = 3
40 | if entity_type in ["song"]:
41 | pri -= 0.5
42 | print(curLine(), "get %d %s from %s, pri=%f" % (len(lines), entity_type, entity_file, pri))
43 | for line in lines:
44 | entity = line.strip()
45 | for k,v in number_map.items():
46 | entity = entity.replace(k, v)
47 | # for combination in all_combination:
48 | if entity not in entity_info_dict: # new entity
49 | combination = "".join(lazy_pinyin(entity)) # default: leave unconvertible characters as-is; alternative: errors="ignore"
50 | if len(combination) < 2:
51 | print(curLine(), "warning:", entity, "combination:", combination)
52 | entity_info_dict[entity] = (combination, pri)
53 | else:
54 | combination, old_pri = entity_info_dict[entity]
55 | if pri > old_pri:
56 | entity_info_dict[entity] = (combination, pri)
57 | return entity_info_dict
58 |
59 | # Measure similarity between pinyin strings with edit distance; heteronym readings of similar entities are not considered
60 | def pinyin_similar_word_noduoyin(entity_info_dict, word):
61 | if word in entity_info_dict: # the entity exists, no correction needed
62 | return 1.0, word
63 | best_similar_word = None
64 | top_similar_score = 0
65 | try:
66 | all_combination = ["".join(lazy_pinyin(word))] # get_pinyin_combination(entity=word) #
67 | for current_combination in all_combination: # each possible pronunciation of the word
68 | if len(current_combination) == 0:
69 | print(curLine(), "word:", word)
70 | continue
71 | similar_word = None
72 | current_distance = 10000
73 | for entity,(com, pri) in entity_info_dict.items():
74 | char_ratio = 0.0
75 | d = distance(com, current_combination)*(1.0-char_ratio) + distance(entity, word) * char_ratio
76 | if d < current_distance:
77 | current_distance = d
78 | similar_word = entity
79 | # if d<=2.5:
80 | # print(curLine(),com, current_combination, distance(com, current_combination), distance(entity, word) )
81 | # print(curLine(), word, entity, similar_word, "current_distance=", current_distance)
82 |
83 |
84 | current_similar_score = 1.0 - float(current_distance) / len(current_combination)
85 | # print(curLine(), "current_combination:%s, %f" % (current_combination, current_similar_score), similar_word, current_distance)
86 | if current_similar_score > top_similar_score:
87 | # print(curLine(), current_similar_score, top_similar_score, best_similar_word, similar_word)
88 | best_similar_word = similar_word
89 | top_similar_score = current_similar_score
90 | except Exception as error:
91 | print(curLine(), "error:", error)
92 | return top_similar_score, best_similar_word
93 |
94 | # Wrapper that returns a character's initial (may be empty, e.g. 啊/呀), final, and full pinyin
95 | def my_pinyin(char):
96 | shengmu = pinyin(char, style=Style.INITIALS, strict=True)[0][0]
97 | yunmu = pinyin(char, style=Style.FINALS, strict=True)[0][0]
98 | total_pinyin = lazy_pinyin(char, errors='default')[0]
99 | if shengmu + yunmu != total_pinyin:
100 | print(curLine(), "char:", char, ",shengmu:%s, yunmu:%s" % (shengmu, yunmu), total_pinyin)
101 | return shengmu, yunmu, total_pinyin
102 |
103 | singer_pinyin = get_entityType_pinyin(entity_type="singer")
104 | print(curLine(), len(singer_pinyin), "singer_pinyin")
105 |
106 | song_pinyin = get_entityType_pinyin(entity_type="song")
107 | print(curLine(), len(song_pinyin), "song_pinyin")
108 |
109 |
110 | if __name__ == "__main__":
111 | # pinyin('中心', heteronym=True) # enable heteronym mode
112 |
113 |
114 | import confusion_words_duoyin
115 | res = pinyin_similar_word_noduoyin(song_pinyin, "体验") # confusion_words_duoyin.song_pinyin, "体验")
116 | print(curLine(), res)
117 |
118 | s = "啊饿恩昂你为什*.\"123c单身"
119 | s = "啊饿恩昂你为什单身的呢?"
120 | yunmu_list = pinyin(s, style=Style.FINALS)
121 | shengmu_list = pinyin(s, style=Style.INITIALS, strict=False) # returns initials; with strict=True some characters have no initial and give an empty string, strict=False still returns a value
122 | pinyin_list = lazy_pinyin(s, errors='default') # alternative: errors="ignore"
123 | print(curLine(), len(shengmu_list), "shengmu_list:", shengmu_list)
124 | print(curLine(), len(yunmu_list), "yunmu_list:", yunmu_list)
125 | print(curLine(), len(pinyin_list), "pinyin_list:", pinyin_list)
126 | for id, c in enumerate(s):
127 | print(curLine(), id, c, my_pinyin(c))
128 | # print(curLine(), c, pinyin(c, style=Style.INITIALS), s)
129 | # print(curLine(), c,c_pinyin)
130 | # print(curLine(), pinyin(s, heteronym=True, style=0))
131 | # # print(pinyin(c, heteronym=True, style=0))
132 |
133 | # all_combination = get_pinyin_combination(entity=s)
134 | #
135 | # for index, combination in enumerate(all_combination):
136 | # print(curLine(), index, combination)
137 | # for s in ["abc", "abs1", "fabc"]:
138 | # ed = distance("abs", s)
139 | # print(s, ed)
140 | for predict_singer in ["周杰伦", "前任", "朱姐", "奇隆"]:
141 | top_similar_score, best_similar_word = pinyin_similar_word_noduoyin(singer_pinyin, predict_singer)
142 | print(curLine(), predict_singer, top_similar_score, best_similar_word)
143 |
--------------------------------------------------------------------------------
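
Illustrative note (not part of the repository): the correction score in pinyin_similar_word_noduoyin is essentially an edit distance over toneless pinyin, normalized by the length of the query's pinyin string. A toy sketch of that scoring, assuming pypinyin and python-Levenshtein are installed:

from pypinyin import lazy_pinyin
from Levenshtein import distance

def pinyin_score(word, candidate):
    word_py = "".join(lazy_pinyin(word))
    cand_py = "".join(lazy_pinyin(candidate))
    return 1.0 - distance(cand_py, word_py) / len(word_py)

# "销愁" and "消愁" share the toneless pinyin "xiaochou", so the score is 1.0
print(pinyin_score("销愁", "消愁"))
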
/find_entity/acmation.py:
--------------------------------------------------------------------------------
1 | import json, os
2 | import re
3 | from curLine_file import curLine, normal_transformer
4 |
5 | re_phoneNum = re.compile("[0-9一二三四五六七八九十拾]+") # compiled regex for phone numbers
6 |
7 | entity_files_folder = "./entity_files"
8 | entity_folder = os.path.join(entity_files_folder, "slot-dictionaries")
9 | frequentSong = {}
10 | frequentSinger = {}
11 | with open(os.path.join(entity_files_folder,"frequentSong.json"), "r") as f:
12 | frequentSong = json.load(f)
13 |
14 | with open(os.path.join(entity_files_folder,"frequentSinger.json"), "r") as f:
15 | frequentSinger = json.load(f)
16 |
17 | # Aho-Corasick automaton, similar to a trie tree
18 | class State(object):
19 | __slots__ = ['identifier', 'symbol', 'success', 'transitions', 'parent',
20 | 'matched_keyword', 'longest_strict_suffix', 'meta_data']
21 |
22 | def __init__(self, identifier, symbol=None, parent=None, success=False):
23 | self.symbol = symbol
24 | self.identifier = identifier
25 | self.transitions = {}
26 | self.parent = parent
27 | self.success = success
28 | self.matched_keyword = None
29 | self.longest_strict_suffix = None
30 |
31 |
32 | class Result(object):
33 | __slots__ = ['keyword', 'location', 'meta_data']
34 |
35 | def __init__(self, **kwargs):
36 | for k, v in kwargs.items():
37 | setattr(self, k, v)
38 |
39 | def __str__(self):
40 | return_str = ''
41 | for k in self.__slots__:
42 | return_str += '{}:{:<20}\t'.format(k, json.dumps(getattr(self, k)))
43 | return return_str
44 |
45 | class KeywordTree(object):
46 | def __init__(self, case_insensitive=True):
47 | '''
48 | @param case_insensitive: If true, case will be ignored when searching.
49 | Setting this to true will have a positive
50 | impact on performance.
51 | Defaults to true.
52 | '''
53 | self._zero_state = State(0)
54 | self._counter = 1
55 | self._finalized = False
56 | self._case_insensitive = case_insensitive
57 |
58 | def add(self, keywords, meta_data=None):
59 | '''
60 | Add a keyword to the tree.
61 | Can only be used before finalize() has been called.
62 | Keyword should be str or unicode.
63 | '''
64 | if self._finalized:
65 | raise ValueError('KeywordTree has been finalized.' +
66 | ' No more keyword additions allowed')
67 | original_keyword = keywords
68 | if self._case_insensitive:
69 | if isinstance(keywords, list):
70 | keywords = map(str.lower, keywords)
71 | elif isinstance(keywords, str):
72 | keywords = keywords.lower()
73 | # if keywords != original_keyword:
74 | # print(curLine(), keywords, original_keyword)
75 | # input(curLine())
76 | else:
77 | raise Exception('keywords error')
78 | if len(keywords) <= 0:
79 | return
80 | current_state = self._zero_state
81 | for word in keywords:
82 | try:
83 | current_state = current_state.transitions[word]
84 | except KeyError:
85 | next_state = State(self._counter, parent=current_state,
86 | symbol=word)
87 | self._counter += 1
88 | current_state.transitions[word] = next_state
89 | current_state = next_state
90 | current_state.success = True
91 | current_state.matched_keyword = original_keyword
92 | current_state.meta_data = meta_data
93 |
94 | def search(self, text, greedy=False, cut_word=False, cut_separator=' '):
95 | '''
96 |
97 | :param text: text (or token list) to search for keywords
98 | :param greedy: if True, keep only the longest match when matches overlap at a position
99 | :param cut_word: if True, split the text by cut_separator before matching
100 | :param cut_separator: separator used when cut_word is True
101 | :return: generator of Result objects (keyword, location, meta_data)
102 | '''
103 | gen = self._search(text, cut_word, cut_separator)
104 | pre = None
105 | for result in gen:
106 | assert isinstance(result, Result)
107 | if not greedy:
108 | yield result
109 | continue
110 | if pre is None:
111 | pre = result
112 |
113 | if result.location > pre.location:
114 | yield pre
115 | pre = result
116 | continue
117 |
118 | if len(result.keyword) > len(pre.keyword):
119 | pre = result
120 | continue
121 | if pre is not None:
122 | yield pre
123 |
124 | def _search(self, text, cut_word=False, cut_separator=' '):
125 | '''
126 | Search a text for all occurrences of the added keywords.
127 | Can only be called after finalize() has been called.
128 | O(n) with n = len(text)
129 | @return: Generator used to iterate over the results.
130 | Or None if no keyword was found in the text.
131 | '''
132 | # if not self._finalized:
133 | # raise ValueError('KeywordTree has not been finalized.' +
134 | # ' No search allowed. Call finalize() first.')
135 | if self._case_insensitive:
136 | if isinstance(text, list):
137 | text = map(str.lower, text)
138 | elif isinstance(text, str):
139 | text = text.lower()
140 | else:
141 | raise Exception('context type error')
142 |
143 | if cut_word:
144 | if isinstance(text, str):
145 | text = text.split(cut_separator)
146 |
147 | current_state = self._zero_state
148 | for idx, symbol in enumerate(text):
149 | current_state = current_state.transitions.get(
150 | symbol, self._zero_state.transitions.get(symbol, self._zero_state))
151 | state = current_state
152 | while state != self._zero_state:
153 | if state.success:
154 | keyword = state.matched_keyword
155 | yield Result(**{
156 | 'keyword': keyword,
157 | 'location': idx - len(keyword) + 1,
158 | 'meta_data': state.meta_data
159 | })
160 | # yield (keyword, idx - len(keyword) + 1, state.meta_data)
161 | state = state.longest_strict_suffix
162 |
163 | def finalize(self):
164 | '''
165 | Needs to be called after all keywords have been added and
166 | before any searching is performed.
167 | '''
168 | if self._finalized:
169 | raise ValueError('KeywordTree has already been finalized.')
170 | self._zero_state.longest_strict_suffix = self._zero_state
171 | processed = set()
172 | to_process = [self._zero_state]
173 | while to_process:
174 | state = to_process.pop() # pop removes and returns the last element, so this is a depth-first traversal
175 | processed.add(state.identifier)
176 | for child in state.transitions.values():
177 | if child.identifier not in processed:
178 | self.search_lss(child)
179 | to_process.append(child)
180 | self._finalized = True
181 |
182 | def __str__(self):
183 | return "ahocorapy KeywordTree"
184 |
185 |
186 | def search_lss(self, state):
187 | if state.longest_strict_suffix is None:
188 | parent = state.parent
189 | traversed = parent.longest_strict_suffix
190 | while True:
191 | if state.symbol in traversed.transitions and \
192 | traversed.transitions[state.symbol] != state:
193 | state.longest_strict_suffix = \
194 | traversed.transitions[state.symbol]
195 | break
196 | elif traversed == self._zero_state:
197 | state.longest_strict_suffix = self._zero_state
198 | break
199 | else:
200 | traversed = traversed.longest_strict_suffix
201 | suffix = state.longest_strict_suffix
202 | if suffix.longest_strict_suffix is None:
203 | self.search_lss(suffix)
204 | for symbol, next_state in suffix.transitions.items():
205 | if (symbol not in state.transitions and
206 | suffix != self._zero_state):
207 | state.transitions[symbol] = next_state
208 |
209 | def add_to_ac(ac, entity_type, entity_before, entity_after, pri):
210 | entity_before = normal_transformer(entity_before)
211 | flag = "ignore"
212 | if entity_type == "song" and ((entity_after in frequentSong) and entity_after not in {"李白"}):
213 | return flag
214 | if entity_type == "singer" and entity_after in frequentSinger:
215 | return flag
216 | elif entity_type == "toplist" and entity_before == "首张":
217 | return flag
218 | elif entity_type == "emotion" and entity_before in {"high歌","相思","喜欢"}: # train和devs就是这么标注的
219 | return flag
220 | elif entity_type == "language" and entity_before in ["中国"]: # train和devs就是这么标注的
221 | return flag
222 | ac.add(keywords=entity_before, meta_data=(entity_after,pri))
223 | return "add success"
224 |
225 |
--------------------------------------------------------------------------------
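
Usage sketch (toy entities, illustrative priorities; importing the module assumes the JSON files it loads at import time are present): the KeywordTree above is used by add_to_ac roughly like this:

from find_entity.acmation import KeywordTree

ac = KeywordTree(case_insensitive=True)
ac.add(keywords="周杰伦", meta_data=("周杰伦", 3))    # (canonical form, priority)
ac.add(keywords="晴天", meta_data=("晴天", 2.5))
ac.finalize()                                         # must be called before search()

for res in ac.search("放一首周杰伦的晴天"):
    after, priority = res.meta_data
    print(res.keyword, res.location, after, priority)
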
/predict_param.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Slot recognition
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | from Levenshtein import distance
9 | from curLine_file import curLine
10 | from find_entity.exacter_acmation import get_all_entity
11 | from confusion_words_danyin import correct_song, correct_singer
12 | tongyin_yuzhi = 0.65
13 | char_distance_yuzhi = 0.6
14 |
15 |
16 | # Character-level similarity between two strings; larger means more similar
17 | def get_char_similarScore(predict_entity, known_entity):
18 | '''
19 | :param predict_entity: slot value recognized by the model
20 | :param known_entity: entity already in the library, whose length is used to normalize the edit distance
21 | :return: score from the normalized edit distance; larger means more similar
22 | '''
23 | d = distance(predict_entity, known_entity)
24 | score = 1.0 - float(d) / (len(known_entity)+ len(predict_entity))
25 | if score < char_distance_yuzhi:
26 | score = 0
27 | return score
28 |
29 |
30 | def get_slot_info_str_forMusic(slot_info, raw_query, entityTypeMap): # join the token list into a slot-info string
31 | # Not a plain join, because we want to drop certain single-character entities, e.g. phone_num
32 | slot_info_block = []
33 | param_list = []
34 | current_entityType = None
35 | for token in slot_info:
36 | if "<" in token and ">" in token and "/" not in token: # 一个槽位的开始
37 | if len(slot_info_block) > 0:
38 | print(curLine(), len(slot_info), "slot_info:", slot_info)
39 | assert len(slot_info_block) == 0, "".join(slot_info_block)
40 | slot_info_block = []
41 | current_entityType = token[1:-1]
42 | elif "<" in token and ">" in token and "/" in token: # 一个槽位的结束
43 | slot_info_block_str = "".join(slot_info_block)
44 | # Initialize before the loop as if no character-similar entity were found
45 | entity_before = slot_info_block_str # if no character-similar entity is found, entity_before is returned
46 | entity_after = slot_info_block_str
47 | priority = 0 # priority
48 | ignore_flag = False # whether to ignore this slot value
49 | if slot_info_block_str in {"什么歌","一首","小花","叮当","傻逼", "给你听","现在","喜欢","ye ye","没","去你娘的蛋", "j c"}: # blacklist: not a slot value
50 | ignore_flag = True # ignore this slot value
51 | else:
52 | # Compute the similarity between slot_info_block_str and each entity
53 | char_similarScore = 0 # character similarity between the predicted slot value and a library entity
54 | for entity_info in entityTypeMap[current_entityType]:
55 | current_entity_before = entity_info['before']
56 | if slot_info_block_str == current_entity_before: # exact match in the entity library
57 | entity_before = current_entity_before
58 | entity_after = entity_info['after']
59 | priority = entity_info['priority']
60 | char_similarScore = 1.0 # exact match, the slot value is in the library
61 | if slot_info_block_str != entity_after: # a correction seen before
62 | slot_info_block_str = "%s||%s" % (slot_info_block_str, entity_after)
63 | break
64 | current_char_similarScore = get_char_similarScore(predict_entity=slot_info_block_str, known_entity=current_entity_before)
65 | # print(curLine(), slot_info_block_str, current_entity_before, current_char_similarScore)
66 | if current_char_similarScore > char_similarScore: # found a literally closer entity
67 | entity_before = current_entity_before
68 | entity_after = entity_info['after']
69 | priority = entity_info['priority']
70 | char_similarScore = current_char_similarScore
71 | if priority == 0:
72 | if current_entityType not in {"singer", "song"}:
73 | ignore_flag = True
74 | if len(entity_before) < 2: # ignore this short slot value
75 | ignore_flag = True
76 | if current_entityType == "song" and entity_before in {"鱼", "云", "逃", "退", "陶", "美", "图", "默", "哭", "雪"}:
77 | ignore_flag = False # this will be checked later
78 | if ignore_flag:
79 | if entity_before not in "好点没走":
80 | print(curLine(), raw_query, "ignore entityType:%s, slot:%s, entity_before:%s" % (current_entityType, slot_info_block_str, entity_before))
81 | else:
82 | param_list.append({'entityType':current_entityType, 'before':entity_before, 'after':entity_after, 'priority':priority})
83 | slot_info_block = []
84 | current_entityType = None
85 | elif current_entityType is not None: # is in a slot
86 | slot_info_block.append(token)
87 | param_list_sorted = sorted(param_list, key=lambda item: len(item['before'])*100+item['priority'],
88 | reverse=True)
89 | slot_info_str_list = [raw_query]
90 | replace_query = raw_query
91 | for param in param_list_sorted:
92 | entity_before = param['before']
93 | replace_str = entity_before
94 | if replace_str not in replace_query:
95 | continue
96 | replace_query = replace_query.replace(replace_str, "", 1) # replace only once
97 | entityType = param['entityType']
98 |
99 | if param['priority'] == 0: # the model prediction is not in the library, so try pinyin-based correction
100 | similar_score = 0.0
101 | best_similar_word = None
102 | if entityType == "singer":
103 | similar_score, best_similar_word = correct_singer(entity_before, jichu_distance=0.001, char_ratio=0.1, char_distance=0)
104 | elif entityType == "song":
105 | similar_score, best_similar_word = correct_song(entity_before, jichu_distance=0.001, char_ratio=0.1, char_distance=0)
106 | if similar_score > tongyin_yuzhi and best_similar_word != entity_before: # correct the slot value
107 | # print(curLine(), entityType, "entity_before:",entity_before, best_similar_word, similar_score)
108 | param['after'] = best_similar_word
109 |
110 | if entity_before != param['after']:
111 | replace_str = "%s||%s" % (entity_before, param['after'])
112 | for s_index,s in enumerate(slot_info_str_list):
113 | if entity_before not in s or ("<" in s and ">" in s): # not the current slot value, or already a tagged slot that must not be replaced again
114 | continue
115 | insert_list = []
116 | start_index = s.find(entity_before)
117 | if start_index > 0:
118 | insert_list.append(s[:start_index])
119 | insert_list.append("<%s>%s</%s>" % (entityType, replace_str, entityType))
120 | end_index = start_index+len(entity_before)
121 | if end_index < len(s):
122 | insert_list.append(s[end_index:])
123 | slot_info_str_list = slot_info_str_list[:s_index] + insert_list + slot_info_str_list[s_index+1:]
124 | break # each slot value is replaced only once
125 | slot_info_str = "".join(slot_info_str_list)
126 | return slot_info_str
127 |
128 |
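For reference, the tag list consumed here is the token stream produced by get_slot_info in predict_utils.py: opening markers such as <singer>, the slot characters, closing markers such as </singer>, plus untagged tokens. A hedged usage sketch, assuming the project's modules import cleanly; the query is made up and the entityTypeMap layout is inferred from the 'before'/'after'/'priority' accesses above:

from predict_param import get_slot_info_str_forMusic

slot_info = ["<singer>", "周", "杰", "伦", "</singer>", "的", "歌"]
raw_query = "周杰伦的歌"
entityTypeMap = {
    "singer": [{"before": "周杰伦", "after": "周杰伦", "priority": 1}],
    "song": [],
}
annotated = get_slot_info_str_forMusic(slot_info, raw_query, entityTypeMap)
# expected to yield "<singer>周杰伦</singer>的歌"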
129 | def get_slot_info_str(slot_info, raw_query, entityTypeMap): # join the tag list into a string
130 | # Not a plain join: some single-character slot values (e.g. phone_num) need to be dropped
131 | slot_info_str = []
132 | slot_info_block = []
133 | current_entityType = None
134 | for token in slot_info:
135 | if "<" in token and ">" in token and "/" not in token: # 一个槽位的开始
136 | if len(slot_info_block) > 0:
137 | print(curLine(), len(slot_info), "slot_info:", slot_info)
138 | assert len(slot_info_block) == 0, "".join(slot_info_block)
139 | slot_info_block = []
140 | current_entityType = token[1:-1]
141 | elif "<" in token and ">" in token and "/" in token: # end of a slot
142 | ignore_flag = False
143 | slot_info_block_str = "".join(slot_info_block)
144 | if (current_entityType != "phone_num" or len(slot_info_block_str) > 1) and \
145 | slot_info_block_str not in {"什么歌", "一首", "小花","叮当","傻逼", "给你听","现在"}: # confirmed to be a slot value # TODO why is phone_num the exception
146 | if current_entityType in ["singer", "song"]: # homophone lookup
147 | known_entity_flag = False # assume the predicted entity is not in the entity library
148 | for entity_info in entityTypeMap[current_entityType]:
149 | if slot_info_block_str == entity_info['before']: # found in the entity library
150 | entity_after = entity_info['after']
151 | if slot_info_block_str != entity_after: # a previously seen correction
152 | slot_info_block_str = "%s||%s" % (slot_info_block_str, entity_after)
153 | known_entity_flag = True
154 | break
155 |
156 | if not known_entity_flag:  # the predicted entity is not in the library, so try pinyin-based correction
157 | # TODO multiple corrected 'after' values sharing the same pinyin are not handled yet; tones and characters could be used to filter
158 | if len(slot_info_block_str) < 2 and slot_info_block_str not in {"鱼","云","逃","退"}:
159 | ignore_flag = True # ignore single-character values
160 | elif slot_info_block_str in {"什么歌", "一首"}:
161 | ignore_flag = True # ignore blacklisted values
162 | else:
163 | if current_entityType == "singer":
164 | similar_score, best_similar_word = correct_singer(slot_info_block_str)
165 | elif current_entityType == "song":
166 | similar_score, best_similar_word = correct_song(slot_info_block_str, char_ratio=0.55, char_distance=0.0)
167 | if best_similar_word != slot_info_block_str:
168 | if similar_score > tongyin_yuzhi:
169 | # print(curLine(), current_entityType, slot_info_block_str, best_similar_word, similar_score)
170 | slot_info_block_str = "%s||%s" % (slot_info_block_str, best_similar_word)
171 | elif current_entityType in ["theme", "style", "age", "toplist", "emotion", "language", "instrument", "scene"]:
172 | entity_info_list = get_all_entity(raw_query, useEntityTypeList=[current_entityType])[current_entityType]
173 | ignore_flag = True # ignore entities not in the library
174 | min_distance = 1000
175 | houxuan = None
176 | for entity_info in entity_info_list:
177 | entity_before = entity_info["before"]
178 | if slot_info_block_str == entity_before:
179 | houxuan = entity_before
180 | min_distance = 0
181 | break
182 | if houxuan is not None and min_distance <= 1:
183 | if slot_info_block_str != houxuan:
184 | print(curLine(), len(entity_info_list), "entity_info_list:", entity_info_list, "slot_info_block_str:",
185 | slot_info_block_str)
186 | print(curLine(), "change %s from %s to %s" % (current_entityType, slot_info_block_str, houxuan))
187 | slot_info_block_str = houxuan
188 | ignore_flag = False # do not ignore
189 | if ignore_flag:
190 | slot_info_str.extend([slot_info_block_str])
191 | else:
192 | slot_info_str.extend(["<%s>" % current_entityType, slot_info_block_str, token])
193 |
194 | else: # ignore this short or blacklisted slot value
195 | print(curLine(), "ignore entityType:%s, slot:%s" % (current_entityType, slot_info_block_str))
196 | slot_info_str.append(slot_info_block_str)
197 | current_entityType = None
198 | slot_info_block = []
199 | elif current_entityType is None: # not slot
200 | slot_info_str.append(token)
201 | else: # is in a slot
202 | slot_info_block.append(token)
203 | slot_info_str = "".join(slot_info_str)
204 | return slot_info_str
--------------------------------------------------------------------------------
/predict_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Utility functions for running inference with a LaserTagger model."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | from __future__ import print_function
23 | from collections import defaultdict
24 |
25 | from find_entity.acmation import re_phoneNum
26 | from curLine_file import curLine
27 | from predict_param import get_slot_info_str, get_slot_info_str_forMusic
28 |
29 | tongyin_yuzhi = 0.7
30 | cancel_keywords = ["取消", "关闭", "停止", "结束", "关掉", "不要打", "退出", "不需要", "暂停", "谢谢你的服务"]
31 |
32 | class LaserTaggerPredictor(object):
33 | """Class for computing and realizing predictions with LaserTagger."""
34 |
35 | def __init__(self, tf_predictor,
36 | example_builder,
37 | label_map, slot_label_map,
38 | target_domain_name):
39 | """Initializes an instance of LaserTaggerPredictor.
40 |
41 | Args:
42 | tf_predictor: Loaded Tensorflow model.
43 | example_builder: BERT example builder.
44 | label_map: Mapping from tags to tag IDs.
45 | """
46 | self._predictor = tf_predictor
47 | self._example_builder = example_builder
48 | self.intent_id_2_tag = {
49 | tag_id: tag for tag, tag_id in label_map.items()
50 | }
51 | self.slot_id_2_tag = {
52 | tag_id: tag for tag, tag_id in slot_label_map.items()
53 | }
54 | self.target_domain_name = target_domain_name
55 |
56 | def predict_batch(self, sources_batch, location_batch=None, target_domain_name=None, predict_domain_batch=[], raw_query=None): # adapted from predict
57 | """Returns realized prediction for given sources."""
58 | # Predict tag IDs.
59 | keys = ['input_ids', 'input_mask', 'segment_ids', 'entity_type_ids', 'sequence_lengths']
60 | input_info = defaultdict(list)
61 | example_list = []
62 | input_tokens_list = []
63 | location = None
64 | for id, sources in enumerate(sources_batch):
65 | if location_batch is not None:
66 | location = location_batch[id] # indicates whether this position can be modified
67 | example, input_tokens, _= self._example_builder.build_bert_example(sources, location=location)
68 | if example is None:
69 | raise ValueError("Example couldn't be built.")
70 | for key in keys:
71 | input_info[key].append(example.features[key])
72 | example_list.append(example)
73 | input_tokens_list.append(input_tokens)
74 |
75 | out = self._predictor(input_info)
76 | intent_list = []
77 | slot_list = []
78 |
79 | for index, intent in enumerate(out['pred_intent']):
80 | predicted_intent_ids = intent.tolist()
81 | predict_intent = self.intent_id_2_tag[predicted_intent_ids]
82 | query = sources_batch[index][-1]
83 | slot_info = query
84 |
85 |
86 | predict_domain = predict_domain_batch[index]
87 | if predict_domain != target_domain_name:
88 | intent_list.append(predict_intent)
89 | slot_list.append(slot_info)
90 | continue
91 | for word in cancel_keywords:
92 | if word in query:
93 | if target_domain_name == 'navigation':
94 | predict_intent = 'cancel_navigation'
95 | elif target_domain_name == 'music':
96 | predict_intent = 'pause'
97 | elif target_domain_name == 'phone_call':
98 | predict_intent = 'cancel'
99 | else:
100 | raise "wrong target_domain_name:%s" % target_domain_name
101 | break
102 |
103 | if target_domain_name == 'navigation':
104 | if predict_intent != 'navigation':
105 | slot_info = query
106 | else:
107 | slot_info = self.get_slot_info(out['pred_slot'][index], example_list[index], input_tokens_list[index], query)
108 | # if "cancel" not in predict_intent:
109 | # slot_info = self.get_slot_info(out['pred_slot'][index], example_list[index], input_tokens_list[index], query)
110 | # if ">" in slot_info and "" in slot_info:
111 | # predict_intent = 'navigation'
112 | elif target_domain_name == 'phone_call':
113 | if predict_intent != 'make_a_phone_call':
114 | slot_info = query
115 | else:
116 | slot_info = self.get_slot_info(out['pred_slot'][index], example_list[index], input_tokens_list[index], query)
117 | elif target_domain_name == 'music':
118 | # slot_info = self.get_slot_info(out['pred_slot'][index], example_list[index], input_tokens_list[index], query)
119 | if predict_intent != 'play':
120 | slot_info = query
121 | else:
122 | slot_info = self.get_slot_info(out['pred_slot'][index], example_list[index], input_tokens_list[index], query)
123 | intent_list.append(predict_intent)
124 | slot_list.append(slot_info)
125 | return intent_list, slot_list
126 |
127 | def get_slot_info(self, slot, example, input_tokens, query):
128 | # Handle the UNK problem
129 | token_index_map = {} # start position in query of the word piece corresponding to each token id
130 | if "[UNK]" in input_tokens:
131 | start_index = 1
132 | for index in range(2, len(input_tokens)):
133 | if "[SEP]" in input_tokens[-index]:
134 | start_index = len(input_tokens)-index + 1
135 | break
136 | previous_id = 0
137 | for tokenizer_id, t in enumerate(input_tokens[start_index:-1], start=start_index):
138 | if tokenizer_id > 0 and "[UNK]" in input_tokens[tokenizer_id-1]:
139 | if t in query[previous_id:]:
140 | previous_id = previous_id + query[previous_id:].index(t)
141 | else: # consecutive UNKs; for now assume each covers one character
142 | previous_id += 1
143 | token_index_map[tokenizer_id] = previous_id
144 | if "[UNK]" not in t:
145 | length_t = len(t)
146 | if t.startswith("##", 0, 2):
147 | length_t -= 2
148 | previous_id += length_t
149 | predicted_slot_ids = slot.tolist()
150 | labels_mask = example.features["labels_mask"]
151 | assert len(labels_mask) == len(predicted_slot_ids)
152 | slot_info = []
153 | current_entityType = None
154 | index = -1
155 | # print(curLine(), len(predicted_slot_ids), "predicted_slot_ids:", predicted_slot_ids)
156 | for tokens, mask, slot_id in zip(input_tokens, labels_mask, predicted_slot_ids):
157 | index += 1
158 | if mask > 0:
159 | if tokens.startswith("##"):
160 | tokens = tokens[2:]
161 | elif "[UNK]" in tokens: # 处理UNK的情况
162 | previoud_id = token_index_map[index] # unk对应word开始的位置
163 | next_previoud_id = previoud_id + 1 # unk对应word结束的位置
164 | if index+1 in token_index_map:
165 | next_previoud_id = token_index_map[index+1]
166 | tokens = query[previoud_id:next_previoud_id]
167 | print(curLine(), "unk self.passage[%d,%d]=%s" % (previoud_id, next_previoud_id, tokens))
168 | predict_slot = self.slot_id_2_tag[slot_id]
169 | # print(curLine(), tokens, mask, predict_slot)
170 | # Use rules to strengthen digit recognition
171 | if current_entityType == "phone_num" and "phone_num" not in predict_slot: # we were inside a phone_num span
172 | token_numbers = re_phoneNum.findall(tokens)
173 | if len(token_numbers) > 0:
174 | first_span = token_numbers[0]
175 | if tokens.index(first_span) == 0: # the next word piece still starts with digits, so append it to the previous span
176 | slot_info.append(first_span)
177 | if len(first_span) < len(tokens): # anything beyond the digits closes the span and is appended as O
178 | slot_info.append("</%s>" % current_entityType)
179 | slot_info.append(tokens[len(first_span):])
180 | current_entityType = None
181 | continue
182 |
183 | if predict_slot == "O":
184 | if current_entityType is not None: # a span is currently open
185 | slot_info.append("</%s>" % current_entityType) # close it
186 | current_entityType = None
187 | slot_info.append(tokens)
188 | continue
189 |
190 | token_type = predict_slot[2:] # type of the current token
191 | if current_entityType is not None and token_type != current_entityType: # TODO the previous span has not ended yet
192 | slot_info.append("</%s>" % current_entityType) # force-close it
193 | current_entityType = None
194 | if "B-" in predict_slot:
195 | if token_type != current_entityType: # the previous span has ended, so add the opening tag
196 | slot_info.append("<%s>" % token_type)
197 | slot_info.append(tokens)
198 | current_entityType = token_type
199 | elif "E-" in predict_slot:
200 | if token_type == current_entityType: # already inside this span
201 | if token_type not in {"origin","destination","phone_num","contact_name","singer"} \
202 | or tokens not in {"的"}: # TODO these entity types rarely end with such characters
203 | slot_info.append(tokens)
204 | slot_info.append("</%s>" % token_type) # normal case: append the token, then close
205 | else:
206 | slot_info.append("</%s>" % token_type) # close first, then append the token
207 | slot_info.append(tokens)
208 | else: # type mismatch, treat as O
209 | slot_info.append(tokens)
210 | current_entityType = None
211 | else: # I
212 | if current_entityType != token_type and index+1 < len(predicted_slot_ids): # no preceding B, the span starts directly with I; decide based on the next tag
213 | next_predict_slot = self.slot_id_2_tag[predicted_slot_ids[index+1]]
214 | # print(curLine(), "next_predict_slot:%s, token_type=%s" % (next_predict_slot, token_type))
215 | if token_type in next_predict_slot:
216 | slot_info.append("<%s>" % token_type)
217 | current_entityType = token_type
218 | slot_info.append(tokens)
219 | if current_entityType is not None: # reached the end while a span is still open
220 | slot_info.append("</%s>" % current_entityType) # force-close it
221 | if self.target_domain_name == "music":
222 | slot_info_str = get_slot_info_str_forMusic(slot_info, raw_query=query, entityTypeMap=example.entityTypeMap)
223 | else:
224 | slot_info_str = get_slot_info_str(slot_info, raw_query=query, entityTypeMap=example.entityTypeMap)
225 | return slot_info_str
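Stripped of the UNK and phone_num special cases, the tag-to-span decoding above reduces to the following standalone sketch (the tag names and the sample query are illustrative):

def decode_spans(tokens, tags):
    out, current = [], None
    for token, tag in zip(tokens, tags):
        if tag == "O":
            if current is not None:            # close any open span
                out.append("</%s>" % current)
                current = None
            out.append(token)
            continue
        t = tag[2:]                            # strip "B-"/"I-"/"E-"
        if current is not None and t != current:
            out.append("</%s>" % current)      # force-close a mismatched span
            current = None
        if tag.startswith("B-") and current is None:
            out.append("<%s>" % t)
            current = t
        out.append(token)
        if tag.startswith("E-") and current == t:
            out.append("</%s>" % t)
            current = None
    if current is not None:
        out.append("</%s>" % current)
    return "".join(out)

# decode_spans(list("导航到天安门"),
#              ["O", "O", "O", "B-destination", "I-destination", "E-destination"])
# -> "导航到<destination>天安门</destination>"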
--------------------------------------------------------------------------------
/sari_hook.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """SARI score for evaluating paraphrasing and other text generation models.
17 |
18 | The score is introduced in the following paper:
19 |
20 | Optimizing Statistical Machine Translation for Text Simplification
21 | Wei Xu, Courtney Napoles, Ellie Pavlick, Quanze Chen and Chris Callison-Burch
22 | In Transactions of the Association for Computational Linguistics (TACL) 2015
23 | http://cs.jhu.edu/~napoles/res/tacl2016-optimizing.pdf
24 |
25 | This implementation has two differences with the GitHub [1] implementation:
26 | (1) Define 0/0=1 instead of 0 to give higher scores for predictions that match
27 | a target exactly.
28 | (2) Fix an alleged bug [2] in the deletion score computation.
29 |
30 | [1] https://github.com/cocoxu/simplification/blob/master/SARI.py
31 | (commit 0210f15)
32 | [2] https://github.com/cocoxu/simplification/issues/6
33 | """
34 |
35 | from __future__ import absolute_import
36 | from __future__ import division
37 | from __future__ import print_function
38 |
39 | import collections
40 |
41 | import numpy as np
42 | import tensorflow as tf
43 |
44 | # The paper that introduces the SARI score uses only the precision of the deleted
45 | # tokens (i.e. beta=0). To give more emphasis on recall, you may set, e.g.,
46 | # beta=1.
47 | BETA_FOR_SARI_DELETION_F_MEASURE = 0
48 |
49 |
50 | def _get_ngram_counter(ids, n):
51 | """Get a Counter with the ngrams of the given ID list.
52 |
53 | Args:
54 | ids: np.array or a list corresponding to a single sentence
55 | n: n-gram size
56 |
57 | Returns:
58 | collections.Counter with ID tuples as keys and 1s as values.
59 | """
60 | # Remove zero IDs used to pad the sequence.
61 | ids = [token_id for token_id in ids if token_id != 0]
62 | ngram_list = [tuple(ids[i:i + n]) for i in range(len(ids) + 1 - n)]
63 | ngrams = set(ngram_list)
64 | counts = collections.Counter()
65 | for ngram in ngrams:
66 | counts[ngram] = 1
67 | return counts
68 |
69 |
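A quick illustration of the padding removal and binarized counts (toy IDs):

import collections

def bigram_counts(ids):
    # mirrors _get_ngram_counter(ids, 2): drop padding zeros, count each distinct bigram once
    ids = [i for i in ids if i != 0]
    return collections.Counter({tuple(ids[k:k + 2]): 1 for k in range(len(ids) - 1)})

print(bigram_counts([3, 4, 3, 4, 0, 0]))  # Counter({(3, 4): 1, (4, 3): 1})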
70 | def _get_fbeta_score(true_positives, selected, relevant, beta=1):
71 | """Compute Fbeta score.
72 |
73 | Args:
74 | true_positives: Number of true positive ngrams.
75 | selected: Number of selected ngrams.
76 | relevant: Number of relevant ngrams.
77 | beta: 0 gives precision only, 1 gives F1 score, and Inf gives recall only.
78 |
79 | Returns:
80 | Fbeta score.
81 | """
82 | precision = 1
83 | if selected > 0:
84 | precision = true_positives / selected
85 | if beta == 0:
86 | return precision
87 | recall = 1
88 | if relevant > 0:
89 | recall = true_positives / relevant
90 | if precision > 0 and recall > 0:
91 | beta2 = beta * beta
92 | return (1 + beta2) * precision * recall / (beta2 * precision + recall)
93 | else:
94 | return 0
95 |
96 |
97 | def get_addition_score(source_counts, prediction_counts, target_counts):
98 | """Compute the addition score (Equation 4 in the paper)."""
99 | added_to_prediction_counts = prediction_counts - source_counts
100 | true_positives = sum((added_to_prediction_counts & target_counts).values())
101 | selected = sum(added_to_prediction_counts.values())
102 | # Note that in the paper the summation is done over all the ngrams in the
103 | # output rather than the ngrams in the following set difference. Since the
104 | # former does not make as much sense we compute the latter, which is also done
105 | # in the GitHub implementation.
106 | relevant = sum((target_counts - source_counts).values())
107 | return _get_fbeta_score(true_positives, selected, relevant)
108 |
109 |
110 | def get_keep_score(source_counts, prediction_counts, target_counts):
111 | """Compute the keep score (Equation 5 in the paper)."""
112 | source_and_prediction_counts = source_counts & prediction_counts
113 | source_and_target_counts = source_counts & target_counts
114 | true_positives = sum((source_and_prediction_counts &
115 | source_and_target_counts).values())
116 | selected = sum(source_and_prediction_counts.values())
117 | relevant = sum(source_and_target_counts.values())
118 | return _get_fbeta_score(true_positives, selected, relevant)
119 |
120 |
121 | def get_deletion_score(source_counts, prediction_counts, target_counts, beta=0):
122 | """Compute the deletion score (Equation 6 in the paper)."""
123 | source_not_prediction_counts = source_counts - prediction_counts
124 | source_not_target_counts = source_counts - target_counts
125 | true_positives = sum((source_not_prediction_counts &
126 | source_not_target_counts).values())
127 | selected = sum(source_not_prediction_counts.values())
128 | relevant = sum(source_not_target_counts.values())
129 | return _get_fbeta_score(true_positives, selected, relevant, beta=beta)
130 |
131 |
132 | def get_sari_score(source_ids, prediction_ids, list_of_targets,
133 | max_gram_size=4, beta_for_deletion=0):
134 | """Compute the SARI score for a single prediction and one or more targets.
135 |
136 | Args:
137 | source_ids: a list / np.array of SentencePiece IDs
138 | prediction_ids: a list / np.array of SentencePiece IDs
139 | list_of_targets: a list of target ID lists / np.arrays
140 | max_gram_size: int. largest n-gram size we care about (e.g. 3 for unigrams,
141 | bigrams, and trigrams)
142 | beta_for_deletion: beta for deletion F score.
143 |
144 | Returns:
145 | the SARI score and its three components: add, keep, and deletion scores
146 | """
147 | addition_scores = []
148 | keep_scores = []
149 | deletion_scores = []
150 | for n in range(1, max_gram_size + 1):
151 | source_counts = _get_ngram_counter(source_ids, n)
152 | prediction_counts = _get_ngram_counter(prediction_ids, n)
153 | # All ngrams in the targets with count 1.
154 | target_counts = collections.Counter()
155 | # All ngrams in the targets with count r/num_targets, where r is the number
156 | # of targets where the ngram occurs.
157 | weighted_target_counts = collections.Counter()
158 | num_nonempty_targets = 0
159 | for target_ids_i in list_of_targets:
160 | target_counts_i = _get_ngram_counter(target_ids_i, n)
161 | if target_counts_i:
162 | weighted_target_counts += target_counts_i
163 | num_nonempty_targets += 1
164 | for gram in weighted_target_counts.keys():
165 | weighted_target_counts[gram] /= num_nonempty_targets
166 | target_counts[gram] = 1
167 | keep_scores.append(get_keep_score(source_counts, prediction_counts,
168 | weighted_target_counts))
169 | deletion_scores.append(get_deletion_score(source_counts, prediction_counts,
170 | weighted_target_counts,
171 | beta_for_deletion))
172 | addition_scores.append(get_addition_score(source_counts, prediction_counts,
173 | target_counts))
174 |
175 | avg_keep_score = sum(keep_scores) / max_gram_size
176 | avg_addition_score = sum(addition_scores) / max_gram_size
177 | avg_deletion_score = sum(deletion_scores) / max_gram_size
178 | sari = (avg_keep_score + avg_addition_score + avg_deletion_score) / 3.0
179 | return sari, avg_keep_score, avg_addition_score, avg_deletion_score
180 |
181 |
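A minimal usage sketch of the pure-Python scorer above (toy IDs; the exact values depend on the n-gram overlaps):

# Toy example: the prediction keeps most of the source and adds one target token.
source_ids = [1, 2, 3, 4]
prediction_ids = [1, 2, 3, 5]
targets = [[1, 2, 3, 5], [1, 2, 5]]
sari, keep, add, deletion = get_sari_score(
    source_ids, prediction_ids, targets, max_gram_size=2,
    beta_for_deletion=BETA_FOR_SARI_DELETION_F_MEASURE)
# sari is the mean of the keep, addition and deletion components,
# each averaged over n-gram orders 1..max_gram_size.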
182 | def get_sari(source_ids, prediction_ids, target_ids, max_gram_size=4):
183 | """Computes the SARI scores from the given source, prediction and targets.
184 |
185 | Args:
186 | source_ids: A 2D tf.Tensor of size (batch_size , sequence_length)
187 | prediction_ids: A 2D tf.Tensor of size (batch_size, sequence_length)
188 | target_ids: A 3D tf.Tensor of size (batch_size, number_of_targets,
189 | sequence_length)
190 | max_gram_size: int. largest n-gram size we care about (e.g. 3 for unigrams,
191 | bigrams, and trigrams)
192 |
193 | Returns:
194 | A 4-tuple of 1D float Tensors of size (batch_size) for the SARI score and
195 | the keep, addition and deletion scores.
196 | """
197 |
198 | def get_sari_numpy(source_ids, prediction_ids, target_ids):
199 | """Iterate over elements in the batch and call the SARI function."""
200 | sari_scores = []
201 | keep_scores = []
202 | add_scores = []
203 | deletion_scores = []
204 | # Iterate over elements in the batch.
205 | for source_ids_i, prediction_ids_i, target_ids_i in zip(
206 | source_ids, prediction_ids, target_ids):
207 | sari, keep, add, deletion = get_sari_score(
208 | source_ids_i, prediction_ids_i, target_ids_i, max_gram_size,
209 | BETA_FOR_SARI_DELETION_F_MEASURE)
210 | sari_scores.append(sari)
211 | keep_scores.append(keep)
212 | add_scores.append(add)
213 | deletion_scores.append(deletion)
214 | return (np.asarray(sari_scores), np.asarray(keep_scores),
215 | np.asarray(add_scores), np.asarray(deletion_scores))
216 |
217 | sari, keep, add, deletion = tf.py_func(
218 | get_sari_numpy,
219 | [source_ids, prediction_ids, target_ids],
220 | [tf.float64, tf.float64, tf.float64, tf.float64])
221 | return sari, keep, add, deletion
222 |
223 |
224 | def sari_score(predictions, labels, features, **unused_kwargs):
225 | """Computes the SARI scores from the given source, prediction and targets.
226 |
227 | An approximate SARI scoring method since we do not glue word pieces or
228 | decode the ids and tokenize the output. By default, we use ngram order of 4.
229 | Also, this does not have beam search.
230 |
231 | Args:
232 | predictions: tensor, model predictions.
233 | labels: tensor, gold output.
234 | features: dict, containing inputs.
235 |
236 | Returns:
237 | sari: int, approx sari score
238 | """
239 | if "inputs" not in features:
240 | raise ValueError("sari_score requires inputs feature")
241 |
242 | # Convert the inputs and outputs to a [batch_size, sequence_length] tensor.
243 | inputs = tf.squeeze(features["inputs"], axis=[-1, -2])
244 | outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
245 | outputs = tf.squeeze(outputs, axis=[-1, -2])
246 |
247 | # Convert the labels to a [batch_size, 1, sequence_length] tensor.
248 | labels = tf.squeeze(labels, axis=[-1, -2])
249 | labels = tf.expand_dims(labels, axis=1)
250 |
251 | score, _, _, _ = get_sari(inputs, outputs, labels)
252 | return score, tf.constant(1.0)
253 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/run_lasertagger.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """BERT-based LaserTagger runner."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | from __future__ import print_function
23 | import os
24 | from absl import flags
25 | from termcolor import colored
26 | import tensorflow as tf
27 | import run_lasertagger_utils
28 | import utils
29 | from curLine_file import curLine
30 |
31 | FLAGS = flags.FLAGS
32 |
33 | ## Required parameters
34 | flags.DEFINE_string("training_file", None,
35 | "Path to the TFRecord training file.")
36 | flags.DEFINE_string("eval_file", None, "Path to the TFRecord dev file.")
37 | flags.DEFINE_string(
38 | "label_map_file", None,
39 | "Path to the label map file. Either a JSON file ending with '.json', that "
40 | "maps each possible tag to an ID, or a text file that has one tag per line.")
41 | flags.DEFINE_string(
42 | 'slot_label_map_file', None,
43 | 'Path to the label map file. Either a JSON file ending with ".json", that '
44 | 'maps each possible tag to an ID, or a text file that has one tag per line.')
45 | flags.DEFINE_string(
46 | "model_config_file", None,
47 | "The config json file specifying the model architecture.")
48 | flags.DEFINE_string(
49 | "output_dir", None,
50 | "The output directory where the model checkpoints will be written. If "
51 | "`init_checkpoint' is not provided when exporting, the latest checkpoint "
52 | "from this directory will be exported.")
53 |
54 | ## Other parameters
55 |
56 | flags.DEFINE_string(
57 | "init_checkpoint", None,
58 | "Initial checkpoint, usually from a pre-trained BERT model. In the case of "
59 | "exporting, one can optionally provide path to a particular checkpoint to "
60 | "be exported here.")
61 | flags.DEFINE_integer(
62 | "max_seq_length", 128, # contain CLS and SEP
63 | "The maximum total input sequence length after WordPiece tokenization. "
64 | "Sequences longer than this will be truncated, and sequences shorter than "
65 | "this will be padded.")
66 |
67 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
68 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
69 | flags.DEFINE_bool("do_export", False, "Whether to export a trained model.")
70 | flags.DEFINE_bool("eval_all_checkpoints", False, "Run through all checkpoints.")
71 | flags.DEFINE_integer(
72 | "eval_timeout", 600,
73 | "The maximum amount of time (in seconds) for eval worker to wait between "
74 | "checkpoints.")
75 |
76 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
77 | flags.DEFINE_integer("eval_batch_size", 128, "Total batch size for eval.")
78 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
79 | flags.DEFINE_float("learning_rate", 3e-5, "The initial learning rate for Adam.")
80 | flags.DEFINE_float(
81 | "warmup_proportion", 0.1,
82 | "Proportion of training to perform linear learning rate warmup for. "
83 | "E.g., 0.1 = 10% of training.")
84 |
85 | flags.DEFINE_float("num_train_epochs", 3.0,
86 | "Total number of training epochs to perform.")
87 |
88 | flags.DEFINE_float("drop_keep_prob", 0.9, "Keep probability for dropout.")
89 | flags.DEFINE_integer("kernel_size", 2, "kernel_size")
90 |
91 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
92 | "How often to save the model checkpoint.")
93 | flags.DEFINE_integer("keep_checkpoint_max", 3,
94 | "How many checkpoints to keep.")
95 | flags.DEFINE_integer("iterations_per_loop", 1000,
96 | "How many steps to make in each estimator call.")
97 | flags.DEFINE_integer(
98 | "num_train_examples", None,
99 | "Number of training examples. This is used to determine the number of "
100 | "training steps to respect the `num_train_epochs` flag.")
101 | flags.DEFINE_integer(
102 | "num_eval_examples", None,
103 | "Number of eval examples. This is used to determine the number of "
104 | "eval steps to go through the eval file once.")
105 |
106 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
107 | flags.DEFINE_string(
108 | "tpu_name", None,
109 | "The Cloud TPU to use for training. This should be either the name "
110 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
111 | "url.")
112 | flags.DEFINE_string(
113 | "tpu_zone", None,
114 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
115 | "specified, we will attempt to automatically detect the GCE project from "
116 | "metadata.")
117 | flags.DEFINE_string(
118 | "gcp_project", None,
119 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
120 | "specified, we will attempt to automatically detect the GCE project from "
121 | "metadata.")
122 | flags.DEFINE_string("master", None,
123 | "Optional address of the master for the workers.")
124 | flags.DEFINE_string("export_path", None, "Path to save the exported model.")
125 | flags.DEFINE_float("slot_ratio", 0.3, "Weight of the slot loss in the joint intent/slot loss.")
126 |
127 | flags.DEFINE_string(
128 | 'domain_name', None,
129 | 'Name of the target domain, '
130 | 'e.g. navigation, music or phone_call.')
131 | flags.DEFINE_string(
132 | 'entity_type_list_file', None, 'path of entity_type_list_file')
133 |
134 | def file_based_input_fn_builder(input_file, max_seq_length,
135 | is_training, drop_remainder, entity_type_num):
136 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
137 |
138 | name_to_features = {
139 | "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
140 | "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
141 | "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
142 | "labels": tf.FixedLenFeature([], tf.int64),
143 | "slot_labels": tf.FixedLenFeature([max_seq_length], tf.int64),
144 | "labels_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
145 | "entity_type_ids": tf.FixedLenFeature([max_seq_length*entity_type_num], tf.int64),
146 | "sequence_lengths": tf.FixedLenFeature([], tf.int64),
147 | }
148 |
149 | def _decode_record(record, name_to_features):
150 | """Decodes a record to a TensorFlow example."""
151 | example = tf.parse_single_example(record, name_to_features)
152 |
153 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
154 | # So cast all int64 to int32.
155 | for name in list(example.keys()):
156 | t = example[name]
157 | if t.dtype == tf.int64:
158 | t = tf.to_int32(t)
159 | example[name] = t
160 |
161 | return example
162 |
163 | def input_fn(params):
164 | """The actual input function."""
165 | d = tf.data.TFRecordDataset(input_file)
166 | # For training, we want a lot of parallel reading and shuffling.
167 | # For eval, we want no shuffling and parallel reading doesn't matter.
168 | if is_training:
169 | d = d.repeat() # repeat the dataset across epochs
170 | d = d.shuffle(buffer_size=100)
171 | d = d.apply(
172 | tf.contrib.data.map_and_batch(
173 | lambda record: _decode_record(record, name_to_features),
174 | batch_size=params["batch_size"],
175 | drop_remainder=drop_remainder))
176 | return d
177 |
178 | return input_fn
179 |
180 |
181 | def _calculate_steps(num_examples, batch_size, num_epochs, warmup_proportion=0):
182 | """Calculates the number of steps.
183 |
184 | Args:
185 | num_examples: Number of examples in the dataset.
186 | batch_size: Batch size.
187 | num_epochs: How many times we should go through the dataset.
188 | warmup_proportion: Proportion of warmup steps.
189 |
190 | Returns:
191 | Tuple (number of steps, number of warmup steps).
192 | """
193 | steps = int(num_examples / batch_size * num_epochs)
194 | warmup_steps = int(warmup_proportion * steps)
195 | return steps, warmup_steps
196 |
197 |
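For example, with the default train_batch_size=32, num_train_epochs=3.0 and warmup_proportion=0.1, and a hypothetical 10,000 training examples:

steps, warmup_steps = _calculate_steps(10000, 32, 3.0, warmup_proportion=0.1)
# steps = int(10000 / 32 * 3.0) = 937, warmup_steps = int(0.1 * 937) = 93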
198 | def main(_):
199 | tf.logging.set_verbosity(tf.logging.INFO)
200 | if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_export):
201 | raise ValueError("At least one of `do_train`, `do_eval` or `do_export` must be True.")
202 |
203 | model_config = run_lasertagger_utils.LaserTaggerConfig.from_json_file(
204 | FLAGS.model_config_file)
205 |
206 | if FLAGS.max_seq_length > model_config.max_position_embeddings:
207 | raise ValueError(
208 | "Cannot use sequence length %d because the BERT model "
209 | "was only trained up to sequence length %d" %
210 | (FLAGS.max_seq_length, model_config.max_position_embeddings))
211 |
212 | if not FLAGS.do_export and not os.path.exists(FLAGS.output_dir):
213 | os.makedirs(FLAGS.output_dir)
214 |
215 | num_tags = len(utils.read_label_map(FLAGS.label_map_file))
216 | num_slot_tags = len(utils.read_label_map(FLAGS.slot_label_map_file))
217 | entity_type_list = utils.read_label_map(FLAGS.entity_type_list_file)[FLAGS.domain_name]
218 |
219 | tpu_cluster_resolver = None
220 | if FLAGS.use_tpu and FLAGS.tpu_name:
221 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
222 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
223 |
224 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
225 | run_config = tf.contrib.tpu.RunConfig(
226 | cluster=tpu_cluster_resolver,
227 | master=FLAGS.master,
228 | model_dir=FLAGS.output_dir,
229 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
230 | keep_checkpoint_max=FLAGS.keep_checkpoint_max) # ,
231 | # tpu_config=tf.contrib.tpu.TPUConfig(
232 | # iterations_per_loop=FLAGS.iterations_per_loop,
233 | # per_host_input_for_training=is_per_host,
234 | # eval_training_input_configuration=tf.contrib.tpu.InputPipelineConfig.SLICED))
235 |
236 | if FLAGS.do_train:
237 | num_train_steps, num_warmup_steps = _calculate_steps(
238 | FLAGS.num_train_examples, FLAGS.train_batch_size,
239 | FLAGS.num_train_epochs, FLAGS.warmup_proportion)
240 | else:
241 | num_train_steps, num_warmup_steps = None, None
242 | entity_type_num = len(entity_type_list)
243 | print(curLine(), "num_train_steps=",num_train_steps, "learning_rate=", FLAGS.learning_rate,
244 | ",num_warmup_steps=",num_warmup_steps, ",max_seq_length=", FLAGS.max_seq_length)
245 | print(colored("%s restore from init_checkpoint:%s" % (curLine(), FLAGS.init_checkpoint), "red"))
246 | model_fn = run_lasertagger_utils.ModelFnBuilder(
247 | config=model_config,
248 | num_tags=num_tags,
249 | num_slot_tags=num_slot_tags,
250 | init_checkpoint=FLAGS.init_checkpoint,
251 | learning_rate=FLAGS.learning_rate,
252 | num_train_steps=num_train_steps,
253 | num_warmup_steps=num_warmup_steps,
254 | use_tpu=FLAGS.use_tpu,
255 | use_one_hot_embeddings=FLAGS.use_tpu,
256 | max_seq_length=FLAGS.max_seq_length,
257 | drop_keep_prob=FLAGS.drop_keep_prob,
258 | entity_type_num=entity_type_num,
259 | slot_ratio=FLAGS.slot_ratio).build()
260 |
261 | # If TPU is not available, this will fall back to normal Estimator on CPU or GPU.
262 | estimator = tf.contrib.tpu.TPUEstimator(
263 | use_tpu=FLAGS.use_tpu,
264 | model_fn=model_fn,
265 | config=run_config,
266 | train_batch_size=FLAGS.train_batch_size,
267 | eval_batch_size=FLAGS.eval_batch_size,
268 | predict_batch_size=FLAGS.predict_batch_size
269 | )
270 |
271 | if FLAGS.do_train:
272 | train_input_fn = file_based_input_fn_builder(
273 | input_file=FLAGS.training_file,
274 | max_seq_length=FLAGS.max_seq_length,
275 | is_training=True,
276 | drop_remainder=True,
277 | entity_type_num=entity_type_num)
278 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
279 |
280 | if FLAGS.do_export:
281 | tf.logging.info("Exporting the model to %s"% FLAGS.export_path)
282 | def serving_input_fn():
283 | def _input_fn():
284 | features = {
285 | "input_ids": tf.placeholder(tf.int64, [None, None]),
286 | "input_mask": tf.placeholder(tf.int64, [None, None]),
287 | "segment_ids": tf.placeholder(tf.int64, [None, None]),
288 | "entity_type_ids": tf.placeholder(tf.int64, [None, None]),
289 | "sequence_lengths": tf.placeholder(tf.int64, [None]),
290 | }
291 | return tf.estimator.export.ServingInputReceiver(
292 | features=features, receiver_tensors=features)
293 |
294 | return _input_fn
295 | if not os.path.exists(FLAGS.export_path):
296 | print(curLine(), "will make dir:%s" % FLAGS.export_path)
297 | os.makedirs(FLAGS.export_path)
298 |
299 | estimator.export_saved_model(
300 | FLAGS.export_path,
301 | serving_input_fn(),
302 | checkpoint_path=FLAGS.init_checkpoint)
303 |
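The exported SavedModel can then be loaded for inference, for instance with the TF 1.x contrib predictor; a hedged sketch (the export directory is hypothetical, and export_saved_model writes a timestamped subdirectory under it):

import os
import tensorflow as tf

export_base = "./export"  # assumed value of FLAGS.export_path
latest = max((d for d in os.listdir(export_base) if d.isdigit()), key=int)
predictor = tf.contrib.predictor.from_saved_model(os.path.join(export_base, latest))
# `predictor` is the tf_predictor that LaserTaggerPredictor in predict_utils.py
# calls with the feature dict declared in serving_input_fn above.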
304 |
305 | if __name__ == "__main__":
306 | flags.mark_flag_as_required("model_config_file")
307 | flags.mark_flag_as_required("label_map_file")
308 | flags.mark_flag_as_required("entity_type_list_file")
309 | tf.app.run()
310 |
--------------------------------------------------------------------------------
/run_lasertagger_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Utilities for building a LaserTagger TF model."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | from __future__ import print_function
23 | from bert import modeling
24 | from bert import optimization
25 | import tensorflow as tf
26 | from tensorflow.nn.rnn_cell import LSTMCell
27 | from curLine_file import curLine
28 |
29 | class LaserTaggerConfig(modeling.BertConfig):
30 | """Model configuration for LaserTagger."""
31 |
32 | def __init__(self,
33 | use_t2t_decoder=True,
34 | decoder_num_hidden_layers=1,
35 | decoder_hidden_size=768,
36 | decoder_num_attention_heads=4,
37 | decoder_filter_size=3072,
38 | use_full_attention=False,
39 | **kwargs):
40 | """Initializes an instance of LaserTagger configuration.
41 |
42 | This initializer expects both the BERT specific arguments and the
43 | Transformer decoder arguments listed below.
44 |
45 | Args:
46 | use_t2t_decoder: Whether to use the Transformer decoder (i.e.
47 | LaserTagger_AR). If False, the remaining args do not affect anything and
48 | can be set to default values.
49 | decoder_num_hidden_layers: Number of hidden decoder layers.
50 | decoder_hidden_size: Decoder hidden size.
51 | decoder_num_attention_heads: Number of decoder attention heads.
52 | decoder_filter_size: Decoder filter size.
53 | use_full_attention: Whether to use full encoder-decoder attention.
54 | **kwargs: The arguments that the modeling.BertConfig initializer expects.
55 | """
56 | super(LaserTaggerConfig, self).__init__(**kwargs)
57 | self.use_t2t_decoder = use_t2t_decoder
58 | self.decoder_num_hidden_layers = decoder_num_hidden_layers
59 | self.decoder_hidden_size = decoder_hidden_size
60 | self.decoder_num_attention_heads = decoder_num_attention_heads
61 | self.decoder_filter_size = decoder_filter_size
62 | self.use_full_attention = use_full_attention
63 |
64 |
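As in run_lasertagger.py, the configuration is normally read from a JSON file under configs/; a brief sketch (from_json_file is inherited from bert's modeling.BertConfig, and the path assumes the repo root as the working directory):

from run_lasertagger_utils import LaserTaggerConfig

model_config = LaserTaggerConfig.from_json_file("configs/lasertagger_config.json")
print(model_config.max_position_embeddings)  # used in run_lasertagger.py to validate max_seq_length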
65 | class ModelFnBuilder(object):
66 | """Class for building `model_fn` closure for TPUEstimator."""
67 |
68 | def __init__(self, config, num_tags, num_slot_tags,
69 | init_checkpoint,
70 | learning_rate, num_train_steps,
71 | num_warmup_steps, use_tpu,
72 | use_one_hot_embeddings, max_seq_length, drop_keep_prob, entity_type_num, slot_ratio):
73 | """Initializes an instance of a LaserTagger model.
74 |
75 | Args:
76 | config: LaserTagger model configuration.
77 | num_tags: Number of different tags to be predicted.
78 | init_checkpoint: Path to a pretrained BERT checkpoint (optional).
79 | learning_rate: Learning rate.
80 | num_train_steps: Number of training steps.
81 | num_warmup_steps: Number of warmup steps.
82 | use_tpu: Whether to use TPU.
83 | use_one_hot_embeddings: Whether to use one-hot embeddings for word
84 | embeddings.
85 | max_seq_length: Maximum sequence length.
86 | """
87 | self._config = config
88 | self._num_tags = num_tags
89 | self.num_slot_tags = num_slot_tags
90 | self._init_checkpoint = init_checkpoint
91 | self._learning_rate = learning_rate
92 | self._num_train_steps = num_train_steps
93 | self._num_warmup_steps = num_warmup_steps
94 | self._use_tpu = use_tpu
95 | self._use_one_hot_embeddings = use_one_hot_embeddings
96 | self._max_seq_length = max_seq_length
97 | self.drop_keep_prob = drop_keep_prob
98 | self.slot_ratio = slot_ratio
99 | self.intent_ratio = 1.0-self.slot_ratio
100 | self.entity_type_num = entity_type_num
101 | self.lstm_hidden_size = 128
102 |
103 | def _create_model(self, mode, input_ids, input_mask, segment_ids, labels,
104 | slot_labels, labels_mask, drop_keep_prob, entity_type_ids, sequence_lengths):
105 | """Creates a LaserTagger model."""
106 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
107 | model = modeling.BertModel(
108 | config=self._config,
109 | is_training=is_training,
110 | input_ids=input_ids,
111 | input_mask=input_mask,
112 | token_type_ids=segment_ids,
113 | use_one_hot_embeddings=self._use_one_hot_embeddings)
114 |
115 | final_layer = model.get_sequence_output()
116 | # final_hidden = model.get_pooled_output()
117 |
118 | if is_training:
119 | # I.e., 0.1 dropout
120 | # final_hidden = tf.nn.dropout(final_hidden, keep_prob=drop_keep_prob)
121 | final_layer = tf.nn.dropout(final_layer, keep_prob=drop_keep_prob)
122 |
123 | # Incorporate entity-type information
124 | batch_size, seq_length = modeling.get_shape_list(input_ids)
125 |
126 | self.entity_type_embedding = tf.get_variable(name="entity_type_embedding",
127 | shape=(self.entity_type_num, self._config.hidden_size),
128 | dtype=tf.float32, trainable=True,
129 | initializer=tf.random_uniform_initializer(
130 | -self._config.initializer_range*100,
131 | self._config.initializer_range*100, seed=20))
132 |
133 | with tf.init_scope():
134 | impact_weight_init = tf.constant(1.0 / self.entity_type_num, dtype=tf.float32, shape=(1, self.entity_type_num))
135 | self.impact_weight = tf.Variable(
136 | impact_weight_init, dtype=tf.float32, name="impact_weight") # impact weights for the different entity types
137 | impact_weight_matrix = tf.tile(self.impact_weight, multiples=[batch_size * seq_length, 1])
138 |
139 | entity_type_ids_matrix1 = tf.cast(tf.reshape(entity_type_ids, [batch_size * seq_length, self.entity_type_num]), dtype=tf.float32)
140 | entity_type_ids_matrix = tf.multiply(entity_type_ids_matrix1, impact_weight_matrix)
141 | entity_type_emb = tf.matmul(entity_type_ids_matrix, self.entity_type_embedding)
142 | final_layer = final_layer + tf.reshape(entity_type_emb, [batch_size, seq_length, self._config.hidden_size]) # TODO # note: 0.7071067811865476 is sqrt(2)/2
143 | # final_layer = tf.concat([final_layer, tf.reshape(entity_type_emb, [batch_size, seq_length,self._config.hidden_size])], axis=-1)
144 |
145 | if is_training:
146 | final_layer = tf.nn.dropout(final_layer, keep_prob=drop_keep_prob)
147 |
148 | (output_fw_seq, output_bw_seq), ((c_fw,h_fw),(c_bw,h_bw)) = tf.nn.bidirectional_dynamic_rnn(
149 | cell_fw=LSTMCell(self.lstm_hidden_size),
150 | cell_bw=LSTMCell(self.lstm_hidden_size),
151 | inputs=final_layer,
152 | sequence_length=sequence_lengths,
153 | dtype=tf.float32)
154 | layer_matrix = tf.concat([output_fw_seq, output_bw_seq], axis=-1)
155 | final_hidden = tf.concat([c_fw, c_bw], axis=-1)
156 |
157 |     layer_matrix = tf.contrib.layers.layer_norm(
158 | inputs=layer_matrix, begin_norm_axis=-1, begin_params_axis=-1)
159 |
160 |
161 | intent_logits = tf.layers.dense(
162 | final_hidden,
163 | self._num_tags,
164 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
165 | name="output_projection")
166 | slot_logits = tf.layers.dense(layer_matrix, self.num_slot_tags,
167 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.02), name="slot_projection")
168 |
169 |
170 | with tf.variable_scope("loss"):
171 | loss = None
172 | per_example_intent_loss = None
173 | per_example_slot_loss = None
174 | if mode != tf.estimator.ModeKeys.PREDICT:
175 | per_example_intent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
176 | labels=labels, logits=intent_logits)
177 | slot_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
178 | labels=slot_labels, logits=slot_logits)
179 | per_example_slot_loss = tf.truediv(
180 | tf.reduce_sum(slot_loss, axis=1),
181 | tf.cast(tf.reduce_sum(labels_mask, axis=1), tf.float32))
182 |
183 | # from tensorflow.contrib.crf import crf_log_likelihood
184 | # from tensorflow.contrib.crf import viterbi_decode
185 | # batch_size = tf.shape(slot_logits)[0]
186 | # print(curLine(), batch_size, tf.constant([self._max_seq_length]))
187 | # length_batch = tf.tile(tf.constant([self._max_seq_length]), [batch_size])
188 | # print(curLine(), batch_size, "length_batch:", length_batch)
189 | # per_example_slot_loss, self.transition_params = crf_log_likelihood(inputs=slot_logits,
190 | # tag_indices=slot_labels,sequence_lengths=length_batch)
191 | # print(curLine(), "per_example_slot_loss:", per_example_slot_loss) # shape=(batch_size,)
192 | # print(curLine(), "self.transition_params:", self.transition_params) # shape=(9, 9)
193 |
194 | loss = tf.reduce_mean(self.intent_ratio*per_example_intent_loss + self.slot_ratio*per_example_slot_loss)
195 | pred_intent = tf.cast(tf.argmax(intent_logits, axis=-1), tf.int32)
196 | pred_slot = tf.cast(tf.argmax(slot_logits, axis=-1), tf.int32)
197 | return (loss, per_example_slot_loss, pred_intent, pred_slot,
198 | batch_size, entity_type_emb, impact_weight_matrix, entity_type_ids_matrix, final_layer, slot_logits)
199 |
200 | def build(self):
201 | """Returns `model_fn` closure for TPUEstimator."""
202 |
203 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
204 | """The `model_fn` for TPUEstimator."""
205 |
206 | tf.logging.info("*** Features ***")
207 | for name in sorted(features.keys()):
208 | tf.logging.info(" name = %s, shape = %s", name, features[name].shape)
209 | input_ids = features["input_ids"]
210 | input_mask = features["input_mask"]
211 | segment_ids = features["segment_ids"]
212 | labels = None
213 | slot_labels = None
214 | labels_mask = None
215 | if mode != tf.estimator.ModeKeys.PREDICT:
216 | labels = features["labels"]
217 | slot_labels = features["slot_labels"]
218 | labels_mask = features["labels_mask"]
219 | (total_loss, per_example_loss, pred_intent, pred_slot,
220 | batch_size, entity_type_emb, impact_weight_matrix, entity_type_ids_matrix, final_layer, slot_logits) = self._create_model(
221 | mode, input_ids, input_mask, segment_ids, labels, slot_labels, labels_mask, self.drop_keep_prob,
222 | features["entity_type_ids"], sequence_lengths=features["sequence_lengths"])
223 |
224 | tvars = tf.trainable_variables()
225 | initialized_variable_names = {}
226 | scaffold_fn = None
227 | if self._init_checkpoint:
228 | (assignment_map, initialized_variable_names
229 | ) = modeling.get_assignment_map_from_checkpoint(tvars,
230 | self._init_checkpoint)
231 | if self._use_tpu:
232 | def tpu_scaffold():
233 | tf.train.init_from_checkpoint(self._init_checkpoint, assignment_map)
234 | return tf.train.Scaffold()
235 |
236 | scaffold_fn = tpu_scaffold
237 | else:
238 | tf.train.init_from_checkpoint(self._init_checkpoint, assignment_map)
239 |
240 | tf.logging.info("**** Trainable Variables ****")
241 | # for var in tvars:
242 | # tf.logging.info("Initializing the model from: %s",
243 | # self._init_checkpoint)
244 | # init_string = ""
245 | # if var.name in initialized_variable_names:
246 | # init_string = ", *INIT_FROM_CKPT*"
247 | # tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
248 | # init_string)
249 |
250 | output_spec = None
251 | if mode == tf.estimator.ModeKeys.TRAIN:
252 | train_op = optimization.create_optimizer(
253 | total_loss, self._learning_rate, self._num_train_steps,
254 | self._num_warmup_steps, self._use_tpu)
255 |
256 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
257 | mode=mode,
258 | loss=total_loss,
259 | train_op=train_op,
260 | scaffold_fn=scaffold_fn)
261 |
262 | elif mode == tf.estimator.ModeKeys.EVAL:
263 | def metric_fn(per_example_loss, labels, labels_mask, predictions):
264 | """Compute eval metrics."""
265 | accuracy = tf.cast(
266 |             tf.reduce_all( # tf.reduce_all acts as a logical AND: an example counts as correct only if every unmasked position is predicted correctly
267 | tf.logical_or(
268 | tf.equal(labels, predictions),
269 | ~tf.cast(labels_mask, tf.bool)),
270 | axis=1), tf.float32)
271 | return {
272 | # This is equal to the Exact score if the final realization step
273 | # doesn't introduce errors.
274 | "sentence_level_acc": tf.metrics.mean(accuracy),
275 | "eval_loss": tf.metrics.mean(per_example_loss),
276 | }
277 |
278 | eval_metrics = (metric_fn,
279 | [per_example_loss, labels, labels_mask, pred_intent])
280 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
281 | mode=mode,
282 | loss=total_loss,
283 | eval_metrics=eval_metrics,
284 | scaffold_fn=scaffold_fn)
285 | else:
286 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
287 | mode=mode, predictions={'pred_intent': pred_intent, 'pred_slot': pred_slot,
288 | 'batch_size':batch_size, 'entity_type_emb':entity_type_emb, 'impact_weight_matrix':impact_weight_matrix, 'entity_type_ids_matrix':entity_type_ids_matrix,
289 | 'final_layer':final_layer, 'slot_logits':slot_logits},
290 | # 'intent_logits':intent_logits, 'entity_type_ids_matrix':entity_type_ids_matrix},
291 | scaffold_fn=scaffold_fn)
292 | return output_spec
293 |
294 | return model_fn
295 |
--------------------------------------------------------------------------------
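
A note on the joint objective in _create_model above: the per-example intent cross-entropy and the mask-normalised per-example slot cross-entropy are mixed with intent_ratio = 1.0 - slot_ratio before averaging over the batch. The following minimal NumPy sketch is not part of the repository; joint_loss and the toy values are invented purely to illustrate that weighting in isolation.

import numpy as np

def joint_loss(intent_loss, slot_loss, labels_mask, slot_ratio=0.5):
    """Sketch of the loss weighting used in _create_model (illustration only)."""
    intent_ratio = 1.0 - slot_ratio
    # Average the per-token slot loss over the unpadded tokens of each example.
    per_example_slot_loss = slot_loss.sum(axis=1) / labels_mask.sum(axis=1)
    # Mix the two heads and average over the batch.
    return float(np.mean(intent_ratio * intent_loss + slot_ratio * per_example_slot_loss))

# Toy batch: 2 examples, sequence length 3 (values invented for the example).
intent_loss = np.array([0.2, 0.8])
slot_loss = np.array([[0.1, 0.3, 0.0],
                      [0.5, 0.2, 0.4]])
labels_mask = np.array([[1, 1, 0],
                        [1, 1, 1]])
print(joint_loss(intent_loss, slot_loss, labels_mask, slot_ratio=0.4))  # ~0.413
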
/predict_main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Compute realized predictions for a dataset."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | from absl import app
24 | from absl import flags
25 | from absl import logging
26 | import math, time
27 | from termcolor import colored
28 | import tensorflow as tf
29 |
30 | import bert_example
31 | import predict_utils
32 | import csv
33 | import utils
34 | from find_entity import exacter_acmation
35 | from curLine_file import curLine, normal_transformer, other_tag
36 |
37 | FLAGS = flags.FLAGS
38 | flags.DEFINE_string(
39 | 'input_file', None,
40 | 'Path to the input file containing examples for which to compute '
41 | 'predictions.')
42 | flags.DEFINE_enum(
43 | 'input_format', None, ['wikisplit', 'discofuse','qa', 'nlu'],
44 | 'Format which indicates how to parse the input_file.')
45 | flags.DEFINE_string(
46 | 'output_file', None,
47 | 'Path to the CSV file where the predictions are written to.')
48 | flags.DEFINE_string(
49 | 'submit_file', None,
50 |     'Path to the CSV file where the submission results are written to.')
51 | flags.DEFINE_string(
52 | 'label_map_file', None,
53 | 'Path to the label map file. Either a JSON file ending with ".json", that '
54 | 'maps each possible tag to an ID, or a text file that has one tag per line.')
55 | flags.DEFINE_string(
56 | 'slot_label_map_file', None,
57 |     'Path to the slot label map file. Either a JSON file ending with ".json", that '
58 |     'maps each possible slot tag to an ID, or a text file that has one tag per line.')
59 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
60 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
61 | flags.DEFINE_bool(
62 | 'do_lower_case', False,
63 | 'Whether to lower case the input text. Should be True for uncased '
64 | 'models and False for cased models.')
65 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.')
66 | flags.DEFINE_string(
67 | 'domain_name', None,
68 |     'Name of the target domain; must be one of "navigation", '
69 |     '"phone_call" or "music".')
70 | flags.DEFINE_string(
71 | 'entity_type_list_file', None, 'path of entity_type_list_file')
72 |
73 | def main(argv):
74 | if len(argv) > 1:
75 | raise app.UsageError('Too many command-line arguments.')
76 | flags.mark_flag_as_required('input_file')
77 | flags.mark_flag_as_required('input_format')
78 | flags.mark_flag_as_required('output_file')
79 | flags.mark_flag_as_required('label_map_file')
80 | flags.mark_flag_as_required('vocab_file')
81 | flags.mark_flag_as_required('saved_model')
82 | label_map = utils.read_label_map(FLAGS.label_map_file)
83 | slot_label_map = utils.read_label_map(FLAGS.slot_label_map_file)
84 | target_domain_name = FLAGS.domain_name
85 | print(curLine(), "target_domain_name:", target_domain_name)
86 | assert target_domain_name in ["navigation", "phone_call", "music"]
87 | entity_type_list = utils.read_label_map(FLAGS.entity_type_list_file)[FLAGS.domain_name]
88 |
89 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
90 | FLAGS.max_seq_length,
91 | FLAGS.do_lower_case, slot_label_map=slot_label_map,
92 | entity_type_list=entity_type_list, get_entity_func=exacter_acmation.get_all_entity)
93 | predictor = predict_utils.LaserTaggerPredictor(
94 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
95 | label_map, slot_label_map, target_domain_name=target_domain_name)
96 | print(colored("%s saved_model:%s" % (curLine(), FLAGS.saved_model), "red"))
97 |
98 | ##### test
99 | print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
100 |
101 |
102 | domain_list = []
103 | slot_info_list = []
104 | intent_list = []
105 |
106 | predict_domain_list = []
107 | previous_pred_slot_list = []
108 | previous_pred_intent_list = []
109 | sources_list = []
110 | predict_batch_size = 64
111 | limit = predict_batch_size * 1500 # 5184 # 10001 #
112 | with tf.gfile.GFile(FLAGS.input_file) as f:
113 | reader = csv.reader(f)
114 | session_list = []
115 | for row_id, line in enumerate(reader):
116 | if len(line) == 1:
117 | line = line[0].strip().split("\t")
118 |             if len(line) > 4: # the row carries gold annotations
119 | (sessionId, raw_query, predDomain, predIntent, predSlot, domain, intent, slot) = line
120 | domain_list.append(domain)
121 | intent_list.append(intent)
122 | slot_info_list.append(slot)
123 | else:
124 | (sessionId, raw_query, predDomainIntent, predSlot) = line
125 | if "." in predDomainIntent:
126 | predDomain,predIntent = predDomainIntent.split(".")
127 | else:
128 | predDomain,predIntent = predDomainIntent, predDomainIntent
129 | if "忘记电话" in raw_query:
130 | predDomain = "phone_call" # rule
131 | if "专用道" in raw_query:
132 | predDomain = "navigation" # rule
133 | predict_domain_list.append(predDomain)
134 | previous_pred_slot_list.append(predSlot)
135 | previous_pred_intent_list.append(predIntent)
136 | query = normal_transformer(raw_query)
137 | if query != raw_query:
138 | print(curLine(), len(query), "query: ", query)
139 | print(curLine(), len(raw_query), "raw_query:", raw_query)
140 |
141 | sources = []
142 | if row_id > 0 and sessionId == session_list[row_id - 1][0]:
143 | sources.append(session_list[row_id - 1][1]) # last query
144 | sources.append(query)
145 | session_list.append((sessionId, raw_query))
146 | sources_list.append(sources)
147 |
148 | if len(sources_list) >= limit:
149 | print(colored("%s stop reading at %d to save time" %(curLine(), limit), "red"))
150 | break
151 |
152 |     number = len(sources_list) # total number of examples
153 |
154 | predict_intent_list = []
155 | predict_slot_list = []
156 | predict_batch_size = min(predict_batch_size, number)
157 | batch_num = math.ceil(float(number) / predict_batch_size)
158 | start_time = time.time()
159 | num_predicted = 0
160 |     write_mode = 'a'
161 |     if len(domain_list) > 0: # gold annotations available: overwrite instead of appending
162 |         write_mode = 'w'
163 |     with tf.gfile.Open(FLAGS.output_file, write_mode) as writer:
164 | # if len(domain_list) > 0: # 有标注
165 | # writer.write("\t".join(["sessionId", "query", "predDomain", "predIntent", "predSlot", "domain", "intent", "Slot"]) + "\n")
166 | for batch_id in range(batch_num):
167 | # if batch_id <= 48:
168 | # continue
169 | sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
170 | predict_domain_batch = predict_domain_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
171 | predict_intent_batch, predict_slot_batch = predictor.predict_batch(sources_batch=sources_batch, target_domain_name=target_domain_name, predict_domain_batch=predict_domain_batch)
172 | assert len(predict_intent_batch) == len(sources_batch)
173 | num_predicted += len(predict_intent_batch)
174 | for id, [predict_intent, predict_slot_info, sources] in enumerate(zip(predict_intent_batch, predict_slot_batch, sources_batch)):
175 | sessionId, raw_query = session_list[batch_id * predict_batch_size + id]
176 | predict_domain = predict_domain_list[batch_id * predict_batch_size + id]
177 | # if predict_domain == "music":
178 | # predict_slot_info = raw_query
179 |                 # if predict_intent == "play": # the model predicted the play intent but no slot was found; use the AC automaton to improve recall
180 | # predict_intent_rule, predict_slot_info = rules(raw_query, predict_domain, target_domain_name)
181 | # # if predict_intent_rule in {"pause", "next"}:
182 | # # predict_intent = predict_intent_rule
183 | # if "<" in predict_slot_info_rule : # and "<" not in predict_slot_info:
184 | # predict_slot_info = predict_slot_info_rule
185 | # print(curLine(), "predict_slot_info_rule:", predict_slot_info_rule)
186 | # print(curLine())
187 |
188 |                 if predict_domain != target_domain_name: # not this model's domain; keep the previous (rule-based) prediction
189 | predict_intent = previous_pred_intent_list[batch_id * predict_batch_size + id]
190 | predict_slot_info = previous_pred_slot_list[batch_id * predict_batch_size + id]
191 | # else:
192 | # print(curLine(), predict_intent, "predict_slot_info:", predict_slot_info)
193 | predict_intent_list.append(predict_intent)
194 | predict_slot_list.append(predict_slot_info)
195 |                     if len(domain_list) > 0: # gold annotations available
196 | domain = domain_list[batch_id * predict_batch_size + id]
197 | intent = intent_list[batch_id * predict_batch_size + id]
198 | slot = slot_info_list[batch_id * predict_batch_size + id]
199 | domain_flag = "right"
200 | if domain != predict_domain:
201 | domain_flag = "wrong"
202 | writer.write("\t".join([sessionId, raw_query, predict_domain, predict_intent, predict_slot_info, domain, intent, slot]) + "\n") # , domain_flag
203 | if batch_id % 5 == 0:
204 | cost_time = (time.time() - start_time) / 60.0
205 | print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
206 | (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time))
207 | cost_time = (time.time() - start_time) / 60.0
208 | print(
209 | f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.')
210 |
211 |
212 | if FLAGS.submit_file is not None:
213 | import collections, os
214 | domain_counter = collections.Counter()
215 | if os.path.exists(path=FLAGS.submit_file):
216 | os.remove(FLAGS.submit_file)
217 |         with open(FLAGS.submit_file, 'w', encoding='UTF-8') as f:
218 | writer = csv.writer(f, dialect='excel')
219 | # writer.writerow(["session_id", "query", "intent", "slot_annotation"]) # TODO
220 | for example_id, sources in enumerate(sources_list):
221 | sessionId, raw_query = session_list[example_id]
222 | predict_domain = predict_domain_list[example_id]
223 | predict_intent = predict_intent_list[example_id]
224 | predict_domain_intent = other_tag
225 | domain_counter.update([predict_domain])
226 | slot = raw_query
227 | if predict_domain != other_tag:
228 | predict_domain_intent = "%s.%s" % (predict_domain, predict_intent)
229 | slot = predict_slot_list[example_id]
230 | # if predict_domain == "navigation": # TODO TODO
231 | # predict_domain_intent = other_tag
232 | # slot = raw_query
233 | line = [sessionId, raw_query, predict_domain_intent, slot]
234 | writer.writerow(line)
235 | print(curLine(), "example_id=", example_id)
236 | print(curLine(), "domain_counter:", domain_counter)
237 | cost_time = (time.time() - start_time) / 60.0
238 |         num_predicted = example_id + 1
239 |         print(curLine(), "%s cost %f min" % (target_domain_name, cost_time))
240 | print(
241 | f'{curLine()} {num_predicted} predictions saved to:{FLAGS.submit_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.')
242 |
243 |
244 |
245 | def rules(raw_query, predict_domain, target_domain_name):
246 | predict_intent = predict_domain # OTHERS
247 | slot_info = raw_query
248 | if predict_domain == "navigation":
249 | predict_intent = 'navigation'
250 | if "打开" in raw_query:
251 | predict_intent = "open"
252 | elif "开始" in raw_query:
253 | predict_intent = "start_navigation"
254 | for word in predict_utils.cancel_keywords:
255 | if word in raw_query:
256 | predict_intent = 'cancel_navigation'
257 | break
258 | # slot_info = raw_query
259 | # if predict_intent == 'navigation': TODO
260 | slot_info = exacter_acmation.get_slot_info(raw_query, domain=predict_domain)
261 | # if predict_intent != 'navigation': # TODO
262 | # print(curLine(), "slot_info:", slot_info)
263 | elif predict_domain == 'music':
264 | predict_intent = 'play'
265 | for word in predict_utils.cancel_keywords:
266 | if word in raw_query:
267 | predict_intent = 'pause'
268 | break
269 | for word in ["下一", "换一首", "换一曲", "切歌", "其他歌"]:
270 | if word in raw_query:
271 | predict_intent = 'next'
272 | break
273 | slot_info = exacter_acmation.get_slot_info(raw_query, domain=predict_domain)
274 |         if predict_intent not in ['play','pause'] and slot_info != raw_query: # revise the intent based on the slot, e.g. "换一首高安的红尘情歌"
275 | print(curLine(), predict_intent, slot_info)
276 | predict_intent = 'play'
277 | # if predict_intent != 'play': # 换一首高安的红尘情歌
278 | # print(curLine(), predict_intent, slot_info)
279 | elif predict_domain == 'phone_call':
280 | predict_intent = 'make_a_phone_call'
281 | for word in predict_utils.cancel_keywords:
282 | if word in raw_query:
283 | predict_intent = 'cancel'
284 | break
285 | slot_info = exacter_acmation.get_slot_info(raw_query, domain=predict_domain)
286 | return predict_intent, slot_info
287 |
288 | if __name__ == '__main__':
289 | app.run(main)
290 |
--------------------------------------------------------------------------------
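
A note on the prediction loop in predict_main.py above: when two consecutive rows share a sessionId, the previous raw query is prepended to sources as context for the current query. The stripped-down sketch below isolates that grouping logic; build_sources and the toy rows are illustrative only, with no BERT example building or query normalisation.

def build_sources(rows):
    """rows: list of (session_id, query) tuples in file order (sketch only)."""
    sources_list = []
    for i, (session_id, query) in enumerate(rows):
        sources = []
        if i > 0 and session_id == rows[i - 1][0]:
            sources.append(rows[i - 1][1])  # previous query of the same session as context
        sources.append(query)
        sources_list.append(sources)
    return sources_list

rows = [("s1", "导航到公司"), ("s1", "取消导航"), ("s2", "播放一首歌")]
print(build_sources(rows))
# [['导航到公司'], ['导航到公司', '取消导航'], ['播放一首歌']]
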
/confusion_words_danyin.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Confusion-word extraction. Polyphonic characters are not modelled; each character is assumed to have a single pronunciation, so only homophone confusions are handled.
3 | from pypinyin import lazy_pinyin, pinyin, Style
4 | from Levenshtein import distance # python-Levenshtein
5 | # edit distance (Levenshtein distance)
6 | import copy
7 | import sys
8 | import os, json
9 | from find_entity.acmation import entity_files_folder, entity_folder
10 | from curLine_file import curLine
11 |
12 | number_map = {"0":"十", "1":"一", "2":"二", "3":"三", "4":"四", "5":"五", "6":"六", "7":"七", "8":"八", "9":"九"}
13 | cons_table = {'n': 'l', 'l': 'n', 'f': 'h', 'h': 'f', 'zh': 'z', 'z': 'zh', 'c': 'ch', 'ch': 'c', 's': 'sh', 'sh': 's'}
14 | vowe_table = {'ai': 'ei', 'an': 'ang', 'en': 'eng', 'in': 'ing',
15 | 'ei': 'ai', 'ang': 'an', 'eng': 'en', 'ing': 'in'}
16 | stop_chars = ["的", "了", "啦", "着","们", "呀", "啊", "地", "呢", ".", ",", ",", "。", "-", "_,", "(", ")", "(", "在", "儿"] # TODO reconsider "儿"
17 |
18 | # non-standard pinyin spellings, see http://www.pinyin.info/rules/initials_finals.html
19 | w_pinyin_map = {"wu":"u", "wa":"ua", "wo":"uo", "wai":"uai", "wei":"ui", "wan":"uan", "wang":"uang", "wen":"un", "weng":"ueng"}
20 | y_pinyin_map = {"yi":"i", "ya":"ia", "ye":"ie", "yao":"iao", "you":"iu", "yan":"ian", "yang":"iang", "yin":"in","ying":"ing", "yong":"iong"}
21 | def my_lazy_pinyin(word, style=Style.NORMAL, errors='default', strict=True):
22 | rtn_list = lazy_pinyin(word, style=style, errors=errors, strict=strict)
23 |     # normalise non-standard pinyin spellings
24 |     if style == Style.NORMAL: # return the list of full syllables
25 | for id, total_pinyin in enumerate(rtn_list):
26 | if total_pinyin[0]=="w" and total_pinyin in w_pinyin_map:
27 | rtn_list[id] = w_pinyin_map[total_pinyin]
28 | elif total_pinyin[0]=="y" and total_pinyin in y_pinyin_map:
29 | rtn_list[id] = y_pinyin_map[total_pinyin]
30 |     elif style == Style.INITIALS: # return only the list of initials
31 | total_pinyin_list = lazy_pinyin(word, style=Style.NORMAL, errors=errors, strict=strict)
32 | for id, (shengmu, total_pinyin) in enumerate(zip(rtn_list, total_pinyin_list)):
33 | if total_pinyin[0]=="w" and total_pinyin in w_pinyin_map:
34 | rtn_list[id] = ""
35 | elif total_pinyin[0]=="y" and total_pinyin in y_pinyin_map:
36 | rtn_list[id] = ""
37 | else:
38 | raise Exception("wrong style")
39 | return rtn_list
40 |
41 |
42 | def add_func(entity_info_dict, entity_type, raw_entity, entity_variation, char_distance):
43 | shengmu_list, yunmu_list, pinyin_list = my_pinyin(word=entity_variation)
44 | assert len(shengmu_list) == len(yunmu_list) == len(pinyin_list)
45 | combination = "".join(pinyin_list)
46 | # if len(combination) < 3:
47 | # print(curLine(), "warning:", raw_entity, "combination:", combination)
48 | entity_info_dict[raw_entity]["combination_list"].append([combination, 0, char_distance, entity_variation])
49 |     base_distance = 1.0 # distance incurred by substituting one initial or final
50 | for shengmu_index, shengmu in enumerate(shengmu_list):
51 |         if shengmu not in cons_table: # no confusion pair for this initial
52 | continue
53 | shengmu_list_variation = copy.copy(shengmu_list)
54 |         shengmu_list_variation[shengmu_index] = cons_table[shengmu] # substitute one initial
55 | combination_variation = "".join(
56 | ["%s%s" % (shengmu, yunmu) for shengmu, yunmu in zip(shengmu_list_variation, yunmu_list)])
57 | entity_info_dict[raw_entity]["combination_list"].append([combination_variation, base_distance, char_distance, entity_variation])
58 | if entity_type in ["song"]:
59 | for yunmu_index, yunmu in enumerate(yunmu_list):
60 |             if yunmu not in vowe_table: # no confusion pair for this final
61 | continue
62 | yunmu_list_variation = copy.copy(yunmu_list)
63 |             yunmu_list_variation[yunmu_index] = vowe_table[yunmu] # substitute one final
64 | combination_variation = "".join(
65 | ["%s%s" % (shengmu, yunmu) for shengmu, yunmu in zip(shengmu_list, yunmu_list_variation)])
66 | entity_info_dict[raw_entity]["combination_list"].append([combination_variation, base_distance, char_distance, entity_variation])
67 | return
68 |
69 | def add_pinyin(raw_entity, entity_info_dict, pri, entity_type):
70 | entity = raw_entity
71 | for k, v in number_map.items():
72 | entity = entity.replace(k, v)
73 |     if raw_entity not in entity_info_dict: # new entity
74 | entity_info_dict[raw_entity] = {"priority": pri, "combination_list": []}
75 | add_func(entity_info_dict, entity_type, raw_entity, entity_variation=entity, char_distance=0)
76 | for char in stop_chars:
77 | if char in entity:
78 | entity_variation = entity.replace(char, "")
79 | char_distance = len(entity) - len(entity_variation)
80 | add_func(entity_info_dict, entity_type, raw_entity, entity_variation, char_distance=char_distance)
81 | else:
82 | old_pri = entity_info_dict[raw_entity]["priority"]
83 | if pri > old_pri:
84 | entity_info_dict[raw_entity]["priority"] = pri
85 | return
86 |
87 | def get_entityType_pinyin(entity_type):
88 | entity_info_dict = {}
89 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type)
90 | with open(entity_file, "r") as fr:
91 | lines = fr.readlines()
92 | priority = 3
93 | if entity_type in ["song"]:
94 | priority -= 0.5
95 | print(curLine(), "get %d %s from %s, priority=%f" % (len(lines), entity_type, entity_file, priority))
96 | for line in lines:
97 | raw_entity = line.strip()
98 | add_pinyin(raw_entity, entity_info_dict, priority, entity_type)
99 |
100 |     ### TODO entities mined from the annotated corpus
101 | entity_file = os.path.join(entity_files_folder, "%s.json" % entity_type)
102 | with open(entity_file, "r") as fr:
103 | current_entity_dict = json.load(fr)
104 | print(curLine(), "get %d %s from %s, priority=%f" % (len(current_entity_dict), entity_type, entity_file, priority))
105 | for entity_before, entity_after_times in current_entity_dict.items():
106 | entity_after = entity_after_times[0]
107 | priority = 4
108 | if entity_type in ["song"]:
109 | priority -= 0.5
110 | add_pinyin(entity_after, entity_info_dict, priority, entity_type)
111 | return entity_info_dict
112 |
113 | # Measure the similarity between pinyin strings with edit distance; polyphonic readings of candidate entities are not considered.
114 | def pinyin_similar_word_danyin(entity_info_dict, word, jichu_distance=0.05, char_ratio=0.55, char_distance=1.0):
115 |     if word in entity_info_dict: # the entity already exists; no correction needed
116 | return 1.0, word
117 | best_similar_word = None
118 | top_similar_score = 0
119 |     if "0" in word: # first try replacing "0" with 零; the replacement with 十 is handled below via number_map
120 | word_ling = word.replace("0", "零")
121 | top_similar_score_ling, best_similar_word_ling = pinyin_similar_word_danyin(
122 | entity_info_dict, word_ling, jichu_distance, char_ratio, char_distance)
123 | if top_similar_score_ling >= 0.99999:
124 | return top_similar_score_ling, best_similar_word_ling
125 | elif top_similar_score_ling >= top_similar_score:
126 | top_similar_score = top_similar_score_ling
127 | best_similar_word = best_similar_word_ling
128 | for k,v in number_map.items():
129 | word = word.replace(k, v)
130 |     if word in entity_info_dict: # the entity already exists; no correction needed
131 | return 1.0, word
132 | all_combination = [["".join(my_lazy_pinyin(word)), 0, 0, word]]
133 | for char in stop_chars:
134 | if char in word:
135 | entity_variation = word.replace(char, "")
136 |             c = (len(word) - len(entity_variation)) * char_distance # TODO
137 | all_combination.append( ["".join(my_lazy_pinyin(entity_variation)), 0, c, entity_variation] )
138 |
139 |     for current_combination, basebase, c1, entity_variation in all_combination: # pronunciation variants of the current word
140 | if len(current_combination) == 0:
141 | continue
142 | similar_word = None
143 | current_distance = sys.maxsize
144 | for entity, priority_comList in entity_info_dict.items():
145 | priority = priority_comList["priority"]
146 | com_list = priority_comList["combination_list"]
147 | for com, a, c, entity_variation_ku in com_list:
148 | d = basebase + jichu_distance*a + char_distance*c+c1 - priority*0.01 + \
149 | distance(com, current_combination)*(1.0-char_ratio) + distance(entity_variation_ku,entity_variation) * char_ratio
150 | if d < current_distance:
151 | current_distance = d
152 | similar_word = entity
153 | current_similar_score = 1.0 - float(current_distance) / len(current_combination)
154 | if current_similar_score > top_similar_score:
155 | best_similar_word = similar_word
156 | top_similar_score = current_similar_score
157 | return top_similar_score, best_similar_word
158 |
159 | # Helper that returns each character's initial (may be empty, e.g. for 啊 or 呀), final, and full pinyin.
160 | def my_pinyin(word):
161 | shengmu_list = my_lazy_pinyin(word, style=Style.INITIALS, strict=False)
162 | total_pinyin_list = my_lazy_pinyin(word, errors='default')
163 | yunmu_list = []
164 | for shengmu, total_pinyin in zip(shengmu_list, total_pinyin_list):
165 | yunmu = total_pinyin[total_pinyin.index(shengmu) + len(shengmu):]
166 | yunmu_list.append(yunmu)
167 | return shengmu_list, yunmu_list, total_pinyin_list
168 |
169 |
170 | singer_pinyin = get_entityType_pinyin(entity_type="singer")
171 | print(curLine(), len(singer_pinyin), "singer_pinyin")
172 |
173 | song_pinyin = get_entityType_pinyin(entity_type="song")
174 | print(curLine(), len(song_pinyin), "song_pinyin")
175 |
176 | def correct_song(entity_before, jichu_distance=0.001, char_ratio=0.53, char_distance=0.1): # char_ratio=0.48):
177 | top_similar_score, best_similar_word = pinyin_similar_word_danyin(
178 | song_pinyin, entity_before, jichu_distance, char_ratio, char_distance)
179 | return top_similar_score, best_similar_word
180 |
181 | def correct_singer(entity_before, jichu_distance=0.001, char_ratio=0.5, char_distance=0.1): # char_ratio=0.1, char_distance=1.0):
182 | top_similar_score, best_similar_word = pinyin_similar_word_danyin(
183 | singer_pinyin, entity_before, jichu_distance, char_ratio, char_distance)
184 | return top_similar_score, best_similar_word
185 |
186 |
187 |
188 | def test(entity_type, pinyin_ku, score_threshold, jichu_distance, char_ratio=0.0, char_distance=0.1): # evaluate correction accuracy
189 |     # add the provided entity lexicon to a set
190 | entity_set = set()
191 | entity_file = os.path.join(entity_folder, "%s.txt" % entity_type)
192 | with open(entity_file, "r") as fr:
193 | lines = fr.readlines()
194 | for line in lines:
195 | raw_entity = line.strip()
196 | entity_set.add(raw_entity)
197 |     # the entities extracted from the annotated corpus serve as the test set
198 | entity_file = os.path.join(entity_files_folder, "%s.json" % entity_type)
199 | with open(entity_file, "r") as fr:
200 | current_entity_dict = json.load(fr)
201 | test_num = 0.0
202 | wrong_num = 0.0
203 | right_num = 0.0
204 | for entity_before, entity_after_times in current_entity_dict.items():
205 | entity_after = entity_after_times[0]
206 | test_num += 1
207 | top_similar_score, best_similar_word = pinyin_similar_word_danyin(pinyin_ku, entity_before, jichu_distance, char_ratio, char_distance)
208 | predict_after = entity_before
209 | if top_similar_score > score_threshold:
210 | predict_after = best_similar_word
211 | if predict_after != entity_after:
212 | wrong_num += 1
213 |         elif entity_before != predict_after: # corrected successfully
214 | right_num += 1
215 | return wrong_num, test_num, right_num
216 |
217 |
218 | if __name__ == "__main__":
219 | for word in ["霍元甲", "梁梁"]:
220 | a = correct_song(entity_before=word, jichu_distance=1.0, char_ratio=0.48)
221 | print(curLine(), word, a)
222 | for word in ["霍元甲", "m c天佑的歌曲", "v.a"]:
223 | a = correct_singer(entity_before=word, jichu_distance=1.0, char_ratio=0.1)
224 | print(curLine(), word, a)
225 |
226 | s = "啊饿恩昂你为什*.\"123c单身"
227 | s = "啊饿恩昂你为什单身的万瓦呢?"
228 | s = "打扰一样呆大衣虽四有挖"
229 | strict = False
230 |     # yunmu_list = my_lazy_pinyin(s, style=Style.FINALS, strict=strict) # derive the finals from the initials (the first part of each syllable)
231 |     shengmu_list = my_lazy_pinyin(s, style=Style.INITIALS, strict=strict) # returns the initials; with strict=True characters without a standard initial yield an empty string, with strict=False an initial is still returned
232 |     pinyin_list = my_lazy_pinyin(s, errors='default') # or errors='ignore'
233 | print(curLine(), len(shengmu_list), "shengmu_list:", shengmu_list)
234 | # print(curLine(), len(yunmu_list), "yunmu_list:", yunmu_list)
235 | print(curLine(), len(pinyin_list), "pinyin_list:", pinyin_list)
236 | print(curLine(), my_pinyin(s))
237 | for id, c in enumerate(s):
238 | print(curLine(), id, c, my_pinyin(c))
239 |
240 |
241 |
242 | best_score = -sys.maxsize
243 | pinyin_ku = None
244 |
245 |
246 |
247 | # # # # song
248 | entity_type = "song"
249 | if entity_type == "song":
250 | pinyin_ku = song_pinyin
251 | elif entity_type == "singer":
252 | pinyin_ku = singer_pinyin
253 | for char_distance in [0.1]:
254 | for char_ratio in [0.4, 0.53, 0.55, 0.68]: #
255 | for score_threshold in [0.45, 0.65]: #
256 | for jichu_distance in [0.001]: # [0.005, 0.01, 0.03, 0.05]:#
257 | wrong_num, test_num, right_num = test(entity_type, pinyin_ku, score_threshold, jichu_distance, char_ratio, char_distance=char_distance)
258 | score = 0.01*right_num - wrong_num
259 | print("char_distance=%f, threshold=%f, jichu_distance=%f, char_ratio=%f, wrong_num=%d, right_num=%d, score=%f" %
260 | (char_distance, score_threshold, jichu_distance, char_ratio, wrong_num, right_num, score))
261 | if score > best_score:
262 | print(curLine(), "score=%f, best_score=%f" % (score, best_score))
263 | best_score = score
264 |
265 | # # singer
266 | # entity_type = "singer" #
267 | # if entity_type == "song":
268 | # pinyin_ku = song_pinyin
269 | # elif entity_type == "singer":
270 | # pinyin_ku = singer_pinyin
271 | # for char_distance in [0.05, 0.1]:
272 | # for char_ratio in [0.5]:
273 | # for score_threshold in [0.45, 0.65]:
274 | # for jichu_distance in [0.001]: # [0.005, 0.01, 0.03, 0.05]:#
275 | # wrong_num, test_num, right_num = test(entity_type, pinyin_ku, score_threshold, jichu_distance, char_ratio, char_distance)
276 | # score = 0.01*right_num - wrong_num
277 | # print("char_distance=%f, threshold=%f, jichu_distance=%f, char_ratio=%f, wrong_num=%d, right_num=%d, score=%f" %
278 | # (char_distance, score_threshold, jichu_distance, char_ratio, wrong_num, right_num, score))
279 | # if score > best_score:
280 | # print(curLine(), "score=%f, best_score=%f" % (score, best_score))
281 | # best_score = score
282 |
283 |
--------------------------------------------------------------------------------
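
A note on pinyin_similar_word_danyin above: its score blends the edit distance between pinyin strings (weighted by 1 - char_ratio) with the edit distance between the surface forms (weighted by char_ratio), then turns the combined distance into a similarity by normalising with the pinyin length. The sketch below keeps only that core idea; it drops the priority, stop-character and initial/final-confusion terms, so its numbers will not match the full implementation, and simple_pinyin_similarity plus the candidate names are purely illustrative.

from Levenshtein import distance  # python-Levenshtein
from pypinyin import lazy_pinyin

def simple_pinyin_similarity(query, candidates, char_ratio=0.5):
    """Blend pinyin edit distance with surface-form edit distance (sketch only)."""
    query_pinyin = "".join(lazy_pinyin(query))
    best_word, best_score = None, float("-inf")
    for cand in candidates:
        cand_pinyin = "".join(lazy_pinyin(cand))
        # Weighted mix of pronunciation distance and character distance.
        d = (distance(cand_pinyin, query_pinyin) * (1.0 - char_ratio)
             + distance(cand, query) * char_ratio)
        score = 1.0 - d / len(query_pinyin)
        if score > best_score:
            best_word, best_score = cand, score
    return best_score, best_word

# A homophone slip ("周结论") should map back to the intended entity ("周杰伦").
print(simple_pinyin_similarity("周结论", ["周杰伦", "刘德华"]))
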