├── README.md ├── data ├── smp_2019_task1_train.json ├── test.json └── train.json ├── evaluation.py ├── lstm_crf_layer.py ├── main.py ├── model_config.py ├── patterns.py ├── report ├── READ.ME ├── SMP2019ECDT任务1技术报告-出门问问信息科技有限公司 (1).pdf ├── SMP2019ECDT技术报告-北京沃丰时代数据科技有限公司.pdf └── coffeeNLU小队技术报告_final.pdf ├── res.out └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # bert-joint-NLU 2 | 使用bert做领域分类、意图识别和槽位填充任务 3 | -------------------------------------------------------------------------------- /data/test.json: -------------------------------------------------------------------------------- 1 | [{"text": "打开相机这"}, {"text": "国际象棋开局"}, {"text": "打开淘宝购物"}, {"text": "搜狗"}, {"text": "打开uc浏览器"}, {"text": "帮我打开人人"}, {"text": "打开酷狗并随机播放"}, {"text": "赶集"}, {"text": "从合肥到上海可以到哪坐车?"}, {"text": "从台州到金华的汽车。"}, {"text": "从西安到石嘴山的汽车票。"}, {"text": "云浮在哪里"}, {"text": "从新加坡花园怎么去宁溪路"}, {"text": "到合肥市逍遥津公园怎么走"}, {"text": "我去杨家坪该怎么走"}, {"text": "河北在哪里"}, {"text": "导航到石化小学"}, {"text": "湘潭到常德怎么走"}, {"text": "去永康怎么走"}, {"text": "去湘江"}, {"text": "带我去新都会"}, {"text": "导航到江门"}, {"text": "帮我查一下淮安到石家庄的火车"}, {"text": "上饶到南昌的火车票"}, {"text": "查询泉州到上海的动车票"}, {"text": "湖北荆州到黄石的火车"}, {"text": "帮我定一张去镇江的动车票"}, {"text": "帮我查一下赣州到杭州的火车票"}, {"text": "从武汉到杭州的火车"}, {"text": "查询湛江到广州的火车票"}, {"text": "查一下合肥到淮南明天的火车"}, {"text": "查询昆明到合肥的火车"}, {"text": "四月六号南宁到上海的火车"}, {"text": "到北京的火车班次"}, {"text": "搜索热门电影。"}, {"text": "最近在播什么电影?"}, {"text": "最近有什么热门的电影?"}, {"text": "哦哦,最新的电影。"}, {"text": "打电话给陈正成"}, {"text": "打电话给贾青子"}, {"text": "给万结打电话"}, {"text": "打电话给谢文豪"}, {"text": "打电话给张志国"}, {"text": "给金华打电话"}, {"text": "打电话麦亚伦"}, {"text": "打电话给王宇婕"}, {"text": "发短信给张三说“画皮2在哪个台播出”"}, {"text": "发送短信给陈群"}, {"text": "给严巧发短信"}, {"text": "把王建春的电话号码发给张学政"}, {"text": "给刘俊发短信"}, {"text": "把张三的电话发给李四"}, {"text": "给德师傅发短信"}, {"text": "查看短消息"}, {"text": "给张金波发短消息今天中午一起吃饭"}, {"text": "发条短信给杨桑"}, {"text": "给杨斌发短信"}, {"text": "把莫若的电话号码发送给反面教材"}, {"text": "查找张伟军"}, {"text": "福寿鱼怎么煮比较好吃?"}, {"text": "有什么特色美食。"}, {"text": "炒四季豆。"}, {"text": "自演艺,羊肉怎么做?"}, {"text": "牛骨汤"}, {"text": "培根金针菇怎么做?"}, {"text": "好吧,我想知道稻香排骨是怎么做的?"}, {"text": "烤鸡翅怎么做?"}, {"text": "牛肉的做法。"}, {"text": "梅菜扣肉怎么做啊?"}, {"text": "教我做馒头。"}, {"text": "送黄炒鸡蛋怎么做?"}, {"text": "怎么炒南瓜?"}, {"text": "做蛋挞怎么做"}, {"text": "鲜肉玉米羹怎么做?"}, {"text": "我说白鱼怎么做"}, {"text": "干锅香辣虾"}, {"text": "番红烧肉的做法。"}, {"text": "那做红烧鱼需要哪些材料啊?"}, {"text": "蒜茄子怎么做?"}, {"text": "萝卜炖排骨汤怎么做?"}, {"text": "南瓜饼的做法。"}, {"text": "小军儿炖土豆该怎么做?"}, {"text": "甲鱼,邓鸡,如何制作。"}, {"text": "山寨鱼怎么做?"}, {"text": "煮面怎么做?"}, {"text": "清蒸鱼怎么做?吴?"}, {"text": "我想听回锅肉做法。"}, {"text": "怎么做卤猪脚"}, {"text": "猪肚鸡怎么做啊?"}, {"text": "鱼香肉丝,怎么炒?"}, {"text": "酸辣藕丁的做法。"}, {"text": "海星的做法。"}, {"text": "猪舌头怎么做比较啊子?"}, {"text": "嗯,知道红烧肉怎么做?"}, {"text": "木瓜汤怎么做"}, {"text": "做红烧鱼的步骤。"}, {"text": "到了有鲈鱼怎么做?"}, {"text": "民间小炒肉的做法。"}, {"text": "糖醋排骨怎么做?好了额,你。"}, {"text": "煮水饺。"}, {"text": "凉拌海蜇怎么做好吃?"}, {"text": "我想答复这个邮件"}, {"text": "替我答复这条邮件"}, {"text": "发邮件给晶晶"}, {"text": "我要回复这个邮件"}, {"text": "今天湖南卫视的电子节目单"}, {"text": "浙江卫视今晚有中国好声音吗?"}, {"text": "哪个电视台在演叶问。"}, {"text": "读书频道上个星期五有播放新闻联播吗"}, {"text": "今天什么好看电影?"}, {"text": "回放昨晚的焦点访谈"}, {"text": "财经频道今天有什么节目"}, {"text": "给我搜搜什么电视台什么时间有转播神探高伦布"}, {"text": "15号的我愿意"}, {"text": "回看上周末早上8点的湖南卫视"}, {"text": "湖南卫视节目录"}, {"text": "CCTV8高清电视剧"}, {"text": "电视台在播什么节目"}, {"text": "BTV生活节目选择回放"}, {"text": "八月十一日从厦门飞往上海的航班"}, {"text": "明天去成都的航班有几趟"}, {"text": "查一下澳门到吉隆坡的飞机"}, {"text": "明天去桂林的航班。"}, {"text": "给我定一张株洲到深圳的机票"}, {"text": "查询大大后天广州到武汉的航班"}, {"text": 
"十月四号从广州到北京的飞机票多少钱"}, {"text": "帮我订一张杭州到上海的机票"}, {"text": "从厦门到昆明的飞机"}, {"text": "到北京的飞机。"}, {"text": "如何治疗脚气"}, {"text": "得了灰指甲怎么办呢?"}, {"text": "近视眼怎么治疗"}, {"text": "得了皮肤病怎么办?"}, {"text": "唇腭裂。"}, {"text": "前列腺炎该怎么治?"}, {"text": "关节痛怎么办?"}, {"text": "大肠癌"}, {"text": "鼻息肉"}, {"text": "帮我查一下上一期的大乐透"}, {"text": "上一期15选5的中奖号码是多少?"}, {"text": "帮我查一下2013年2月8日开奖的15选5中奖号码"}, {"text": "中超比赛。"}, {"text": "中超的比赛时间。"}, {"text": "我想看足球中超,第25轮的比赛时间。"}, {"text": "帮我查一下周杰伦的歌"}, {"text": "唱一首两只蝴蝶。"}, {"text": "放歌,天亮了。"}, {"text": "来首只爱你一个。"}, {"text": "来一首有点舍不得"}, {"text": "我想听周杰伦的菊花台"}, {"text": "播放眼色"}, {"text": "来一首听你听我"}, {"text": "我要听那首最美的时光"}, {"text": "我要听还是会这首歌"}, {"text": "是否还有其他关于钓鱼岛的新闻"}, {"text": "加,听新闻。"}, {"text": "总计今天的新闻。"}, {"text": "答复,我要听新闻。"}, {"text": "打开腾讯头条新闻"}, {"text": "我也听新闻。"}, {"text": "我想听,新闻。"}, {"text": "我想今天新闻最近有什么新的新闻?"}, {"text": "姨,最近有什么新闻。"}, {"text": "我要看现代言情小说"}, {"text": "搜索小说校园全能高手。"}, {"text": "陈忠实的爱情小说《白鹿原》"}, {"text": "你能搜到小说不?"}, {"text": "来首山口。"}, {"text": "思春夜喜雨下一句是。"}, {"text": "背诵一首唐诗"}, {"text": "月咏牡丹的下一句,是什么?"}, {"text": "长风吹月渡海来下一句是什么?"}, {"text": "念一首唐诗给我听。"}, {"text": "朗诵李白的静夜思。"}, {"text": "念首诗听一下。"}, {"text": "杜甫的诗。"}, {"text": "呦呦鹿鸣的下一句。"}, {"text": "来一首李煜的词。"}, {"text": "若相惜不弃下一句是什么?"}, {"text": "背诵唐诗三百首静夜思。"}, {"text": "给我来首诗。"}, {"text": "所以说李白的诗。"}, {"text": "何以解忧的下一句是什么?"}, {"text": "请帮我调频90.2连云港经济广播电台"}, {"text": "收听安徽广播电台。"}, {"text": "给我讲一个谜语吧!"}, {"text": "字谜你。"}, {"text": "查个字谜。"}, {"text": "打一个灯谜啊,你。"}, {"text": "查询中国石化的股价"}, {"text": "人福医药股票"}, {"text": "平安保险的股票价格"}, {"text": "万科a股票"}, {"text": "查询股票00230"}, {"text": "永辉超市股份的价格"}, {"text": "查寻科大讯飞的股"}, {"text": "招商银行港股的股价"}, {"text": "查询一汽轿车股票"}, {"text": "翻译李淼"}, {"text": "晚安的英语"}, {"text": "请翻译苹果"}, {"text": "你今天去了哪里用英文怎么讲"}, {"text": "翻译慷慨激昂"}, {"text": "翻译我今天要去打网球"}, {"text": "我记住了用英文怎么讲"}, {"text": "小老虎用英语怎么讲"}, {"text": "早上好英语怎么讲"}, {"text": "我想看高尔夫网球频道"}, {"text": "搜索湖南卫视直播。"}, {"text": "南京新闻频道"}, {"text": "找湖南卫视"}, {"text": "陕西二套"}, {"text": "我看呀中央五台"}, {"text": "调到新疆综艺"}, {"text": "天津电视台的国际频道,拜托了哈"}, {"text": "找一下非诚勿扰娱乐节目"}, {"text": "恶熊出没"}, {"text": "那金花和她的女婿"}, {"text": "锁梦楼"}, {"text": "我是特种兵二"}, {"text": "我要看巴拉小魔仙"}, {"text": "蝎子王蝎子王"}, {"text": "武间道"}, {"text": "2012年美国公告牌音乐大奖颁奖礼"}, {"text": "我想看张艺谋导演的电影"}, {"text": "当婆婆遇上妈"}, {"text": "你们什么搞笑剧"}, {"text": "智慧树"}, {"text": "我要当八路军"}, {"text": "铁甲威虫"}, {"text": "海棉宝宝"}, {"text": "我想看电影超人"}, {"text": "我要看西游记啊"}, {"text": "我想看喜羊羊"}, {"text": "刘三姐"}, {"text": "你好我想看电影"}, {"text": "电视剧粘豆包"}, {"text": "查找90年代好评的武侠电影"}, {"text": "下半年的电视剧"}, {"text": "今天有雨吗"}, {"text": "请打开网页"}, {"text": "索尼爱立信官网"}, {"text": "打开新浪网我肏"}, {"text": "打开昆明三六零手机网"}, {"text": "打开五五两性网站"}, {"text": "把我打开淘宝的网页"}, {"text": "新浪汽车"}, {"text": "打开会说话的汤姆猫"}, {"text": "请打开qq"}, {"text": "打开相机我想你"}, {"text": "打开系统软件清理"}, {"text": "打开个性闹钟"}, {"text": "百度百科"}, {"text": "手抓饼的做法。"}, {"text": "国足最新赛程。"}, {"text": "面条怎么样煮啊?"}, {"text": "搜索,西米露的做法。"}, {"text": "《武动乾坤》"}, {"text": "带我去丹阳市眼镜市场"}, {"text": "从我这里到科大讯飞走高速路线"}, {"text": "导航到望江西路上去"}, {"text": "去汽车南站怎么走"}, {"text": "到环都大酒店怎么走"}, {"text": "把李文鼎的号码发给谢服全"}, {"text": "把鱼苗的电话号码发给徐景明"}, {"text": "把许琦彪的号码发给徐鹏"}, {"text": "把上官围的电话发给戴勋"}, {"text": "把李志强的号码发给贾洪鉴"}, {"text": "把李会计的电话发给小江"}, {"text": "把张玉娟的手机号码发送给吴伟"}, {"text": "发邮件"}, {"text": "发短信给王小露的"}, {"text": "给刘志发短信"}, {"text": "发短信给王勋连说我很好"}, {"text": "给马翊桐发信息"}, {"text": "给杨李鹏发短信"}, {"text": "我要答复邮件"}, {"text": "全部应答"}, {"text": "回看安徽卫视烽火佳人"}, {"text": "大乐透中奖号码为"}, {"text": "上一期双色球开什么"}, {"text": "帮我买两注双色球"}, {"text": "搜索大乐透的中奖号码"}, {"text": "查看双色球开奖"}, {"text": 
"本期七星彩的中奖号码是多少?"}, {"text": "武汉理工大学在哪"}, {"text": "查查科大讯飞在哪里"}, {"text": "昆山大润发在哪里"}, {"text": "查一下沿途的科大讯飞"}, {"text": "灌阳县位置"}, {"text": "帮我查一下我所在的位置"}, {"text": "沿途有没有加油站"}, {"text": "CCTV6电影"}, {"text": "我想听昨天晚上的新闻。"}, {"text": "今天有什么军事新闻"}, {"text": "播放断点"}, {"text": "跳换到湖南卫视"}, {"text": "我想听星月神话"}, {"text": "放一首青花瓷哦哦"}, {"text": "读句诗来听听。"}, {"text": "分手时,背一首诗吧?"}, {"text": "朗诵诗词。"}, {"text": "呵呵,背首诗来听。"}, {"text": "读一首古诗给我听吧!"}, {"text": "打给曾玮文"}, {"text": "我给你打电话行吗"}, {"text": "呼叫小惠"}, {"text": "给木九超打电话"}, {"text": "打电话给阿敏"}, {"text": "挑逗你真好玩儿英语怎么说"}, {"text": "你去玩手机了吧用英语怎么说"}, {"text": "大家今天心情如何用英语怎么说"}, {"text": "笑话用英语怎么说"}, {"text": "点击说话的英语怎么写"}, {"text": "茶用英语怎么说呀"}, {"text": "赶集网"}, {"text": "现在打开你的官网"}, {"text": "中国新华网打开"}, {"text": "进入当乐网"}, {"text": "我要上中学学科网"}, {"text": "打开江苏移动网上营业厅网页"}, {"text": "我想写一个新邮件"}, {"text": "我想写一个邮件"}, {"text": "新建联系人18622625490"}, {"text": "给我写一个新邮件"}, {"text": "帮我写一条新邮件"}, {"text": "添加一条通讯录姓名张守刚号码是13811725158"}, {"text": "新建联系人天天"}, {"text": "我想要转发这条邮件"}, {"text": "找一首歌叫斑马斑马"}, {"text": "中国银行是跌了还是涨了"}, {"text": "人造鸡蛋的配方。"}, {"text": "打开应用漫画"}, {"text": "请打电话给沈晨浩"}, {"text": "但是我想上腾讯网"}, {"text": "科大讯飞今天的股票"}, {"text": "一个鬼子都不留"}, {"text": "打开音乐播放器"}, {"text": "上海到合肥怎么坐汽车?"}, {"text": "九江到景德镇的汽车。"}, {"text": "在合肥怎么坐去南京的汽车"}, {"text": "太原到广州的车。"}, {"text": "从杭州到黄山的汽车。"}, {"text": "帮我查一下温州到汉中的汽车。"}, {"text": "从北海到东莞的汽车。"}, {"text": "查一下运城到临汾的汽车。"}, {"text": "从东莞去北海的汽车。"}, {"text": "查看济南到东营的,汽车。"}, {"text": "从无锡市到西安市的汽车。"}, {"text": "宁波到岳阳的汽车。"}, {"text": "到无锡到北京的火车有几趟了"}, {"text": "明天去蚌埠的车票"}, {"text": "帮我搜索一下宜春到长沙的火车"}, {"text": "查询广州到海南的航班"}, {"text": "查一下那个南京到昆明的火车"}, {"text": "从合肥去西安的火车"}, {"text": "去重庆的机票多少钱。"}, {"text": "我要查找雷迪波尔电影院最新上映的大片"}, {"text": "现在电视台在放什么节目"}, {"text": "现在有什么好的电视剧。"}, {"text": "湖北卫视今天下午六点以后开始播男生女生向前冲"}, {"text": "给我查下明天凌晨在放妈妈来了"}, {"text": "今天晚上有没有中国好声音呢?"}, {"text": "现在看看有最新的新闻吗?"}, {"text": "上海东方卫视现在在播放什么?"}, {"text": "把高建清的号码发给潘青华"}, {"text": "把千哥的电话号码发给萌萌"}, {"text": "做面包的配方。"}, {"text": "糖醋藕怎么做?"}, {"text": "红烧芋头做法。"}, {"text": "请问红烧茄子怎么做?"}, {"text": "国画怎样做土豆烧鸡块"}, {"text": "要怎么煮饭。"}, {"text": "板栗如何烧?"}, {"text": "搜索西红柿的做法。"}, {"text": "狗光鱼怎么做?"}, {"text": "我说吃飞蟹,怎么做?"}, {"text": "秋葵怎么做?"}, {"text": "花格等蛋的做法。"}, {"text": "发短信给姐姐说晚上一起吃面条"}, {"text": "发短信给子健晚上回家吃饭"}, {"text": "翻译牛"}, {"text": "发送短信给老婆我在三号楼三零幺刚才那个教室关门了"}, {"text": "我要和妹妹睡觉英语怎么说"}, {"text": "翻译光大银行"}, {"text": "看今天的新闻。"}, {"text": "帮我查一查今天的新闻。"}, {"text": "安徽电视台12月18号晚上10:10的电视剧"}, {"text": "买三注星期四开奖的七乐彩"}, {"text": "帮我搜一下中超3月份的对阵情况"}, {"text": "今天的新闻有什么?"}, {"text": "CCTV9明晚有转播仁心解码2呢"}, {"text": "中央四频道火箭兜风"}, {"text": "12月9日北京卫视节目单"}, {"text": "今天晚上天津卫视放什么年岁月"}, {"text": "今天电影频道演什么"}, {"text": "中央一套昨天的新闻联播"}, {"text": "央视少儿频道的节目的"}, {"text": "把唐玉明的手机号码发给李康"}, {"text": "搜索国外大片"}, {"text": "查找校园小说沙漏第二部"}, {"text": "美国热门的恐怖电影"}, {"text": "小破孩动画片"}, {"text": "2013年亚洲冠军联赛恒广州恒大比赛时间。"}, {"text": "贺岁片"}, {"text": "查询郑州到上海的航班。"}, {"text": "杭州到北京的火车班次"}, {"text": "查询广州到北京的火车"}, {"text": "洛阳至西安的火车"}, {"text": "帮我查一下明天广州到长沙的航班"}, {"text": "九月二十号广州飞北京的航班"}, {"text": "明天去贵阳的火车"}, {"text": "明天南京到徐州的火车"}, {"text": "查询后天成都到伦敦的航班"}, {"text": "明天去滁州的火车"}, {"text": "明天去深圳的飞机"}, {"text": "从观音桥到重庆市图书馆怎么走"}, {"text": "湿疹怎么处理?"}, {"text": "甲亢怎么办?"}, {"text": "甲亢怎么治疗?"}, {"text": "手癣怎么治疗?"}, {"text": "脊髓损伤。"}, {"text": "甲型肝炎。"}, {"text": "淋病怎么办"}, {"text": "帮我看一下到科大讯飞最快的是哪条路线"}, {"text": "从这里去科大讯飞如何去"}, {"text": "到步行街不走高架的路线是哪条"}, {"text": "新东站怎么走"}, {"text": "帮我找到去安徽黄山的路"}, {"text": "导航去科大讯飞"}, {"text": "查一下从上海到安徽东至县的路线"}, {"text": "广东省英德在哪里"}, {"text": "龙口西路在哪里"}, 
{"text": "六安市飞云卫生院在哪里"}, {"text": "松江钢材城在哪里"}, {"text": "迎宾路在哪"}, {"text": "移动公司在哪里呀"}, {"text": "美国大使馆在哪"}, {"text": "去常州武进路线"}, {"text": "厦门到福建建阳的火车是几点呢"}, {"text": "查询厦门到武夷山的航班"}, {"text": "到太湖怎么走"}, {"text": "导航到奉贤"}, {"text": "我要去姜堰"}, {"text": "中超A组比赛结果分别是多少"}, {"text": "中超的比赛预告。"}, {"text": "中超赛事预告,啊!"}, {"text": "广州恒大比赛比分。"}, {"text": "中超比赛,结果。"}, {"text": "昨天恒大比赛结果是什么?"}, {"text": "中超赛事预告,从在哪里看呢?"}, {"text": "胡歌忘记时间"}, {"text": "我想听蔡依林的倒带这首歌"}, {"text": "来一首不可以。"}, {"text": "来首出师表"}, {"text": "宫崎骏动画电影"}, {"text": "来点张艺谋出演的节目"}, {"text": "陈奕迅有什么歌啊"}, {"text": "我要看李连杰的电影"}, {"text": "徐良唱的歌"}, {"text": "张绍刚的综艺节目"}, {"text": "告诉我西安的位置"}, {"text": "今天杭州有什么新闻"}, {"text": "泰州新闻。"}, {"text": "合肥明天的湿度如何"}, {"text": "龙岩新闻"}, {"text": "昨天的北京新闻"}, {"text": "我要听昆明新闻。"}, {"text": "我想听搜狐新闻"}, {"text": "有什么好看的校园小说"}, {"text": "背首杜甫的诗啊?"}, {"text": "读一首毛泽东的诗。"}, {"text": "那就来一首李白的诗词吗?"}, {"text": "李白的诗词"}, {"text": "唐代诗人李白。"}, {"text": "让读一下沁园春雪毛泽东。"}, {"text": "来一首白居易的长恨歌"}, {"text": "我想听首唐诗。"}, {"text": "恶念首唐诗。"}, {"text": "背唐诗李白侠客行"}, {"text": "来首宋词。"}, {"text": "你背首唐诗给我听听啊"}, {"text": "给我搜索唐诗。"}, {"text": "雄关漫道真如铁下一句是什么?"}, {"text": "抽刀断水水更流的下一句是什么?"}, {"text": "天王盖地虎的下一句是什么?"}, {"text": "火气冲天,下一句是什么?"}, {"text": "牧童骑黄牛的下一句是什么?"}, {"text": "造化钟神秀的下一句是什么?"}, {"text": "添加一条通讯录姓名葛勇手机号码189571642496"}, {"text": "代码000800的股票涨幅是多少"}, {"text": "请帮我调频89.6武汉交通广播电台"}, {"text": "请帮我调频97.5湖南人民广播电台文艺台"}, {"text": "搜索第10频道"}, {"text": "请帮我调频105山东人民广播电台生活频道"}, {"text": "请帮我调频89.7上虞人民广播电台"}, {"text": "查查一下连江到湛江的火车票"}, {"text": "帮我查一下临海到上海的火车"}, {"text": "姜堰到南京的火车"}, {"text": "请查询丰城到南昌的火车列车列车时刻表"}, {"text": "潜山到福州的火车"}, {"text": "宁乡到长沙怎么走"}, {"text": "晋江到武汉的航班"}, {"text": "帮我订明天早晨的机票"}, {"text": "智能手机英语怎么说"}, {"text": "很高兴见到你英文怎么讲"}, {"text": "我不吃英文怎么讲"}, {"text": "我想出去旅游英文怎么说"}, {"text": "电焊机用英语怎么说"}, {"text": "高清北京文艺台"}, {"text": "山东台高清"}, {"text": "高清动作电影"}, {"text": "甄嬛传高清"}, {"text": "CCTV6电影高清"}, {"text": "东方卫视高清频道"}, {"text": "高清央视综艺"}, {"text": "我想看周星驰的喜剧电影喜剧之王"}, {"text": "情景剧电视剧"}, {"text": "动画片巴啦啦小魔仙"}, {"text": "美国电影少年派的奇幻漂流"}, {"text": "最新电影"}, {"text": "最近有什么好看得"}, {"text": "今天温度是多少"}, {"text": "腾讯股票昨日收盘价"}, {"text": "爱在何方第三期"}] -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 9/5/19 10:54 AM 4 | # @Author : zchai 5 | # -*- coding: utf-8 -*- 6 | import json 7 | import codecs 8 | import sys 9 | 10 | ''' 11 | Calculate the sentence accuracy 12 | Json file format: { 13 | "text": "", 14 | "domain": "", 15 | "intent": "", 16 | "slots": { 17 | "name": "" 18 | } 19 | } 20 | ''' 21 | def sentence_acc(truth_dict_list, pred_dict_list): 22 | assert len(truth_dict_list) == len(pred_dict_list) 23 | 24 | acc_num = 0 25 | total_num = len(truth_dict_list) 26 | for truth_dic, pred_dic in zip(truth_dict_list, pred_dict_list): 27 | 28 | # Determine if the domain and intent are correct 29 | if truth_dic['domain'] != pred_dic['domain'] \ 30 | or truth_dic['intent'] != pred_dic['intent'] \ 31 | or len(truth_dic['slots']) != len(pred_dic['slots']): 32 | print("true: ", truth_dic) 33 | print("pre: ", pred_dic) 34 | continue 35 | else: 36 | # Determine if the slots_key and slots_value are correct 37 | flag = True 38 | for key, value in truth_dic['slots'].items(): 39 | if key not in pred_dic['slots']: 40 | flag = False 41 | break # if there is a key not in predict, flag set as false 42 | elif pred_dic['slots'][key] != 
truth_dic['slots'][key]:
43 |                 flag = False  # if any slot value does not match, set the flag to False
44 |                 break
45 | 
46 |         if flag:
47 |             acc_num += 1
48 |         #else:
49 |             #print("true: ", truth_dic)
50 |             #print("pre: ", pred_dic)
51 | 
52 |     return float(acc_num) / float(total_num)
53 | 
54 | def domain_acc(truth_dict_list, pred_dict_list):
55 |     assert len(truth_dict_list) == len(pred_dict_list)
56 |     acc_num = 0
57 |     total_num = len(truth_dict_list)
58 |     for truth_dic, pred_dic in zip(truth_dict_list, pred_dict_list):
59 |         if truth_dic['domain'] == pred_dic['domain']:
60 |             acc_num += 1
61 | 
62 |     return float(acc_num) / float(total_num)
63 | 
64 | 
65 | def intent_acc(truth_dict_list, pred_dict_list):
66 |     assert len(truth_dict_list) == len(pred_dict_list)
67 |     acc_num = 0
68 |     total_num = len(truth_dict_list)
69 |     for truth_dic, pred_dic in zip(truth_dict_list, pred_dict_list):
70 |         # intent counts as correct only if the domain is also correct
71 |         if truth_dic['intent'] == pred_dic['intent'] and truth_dic['domain'] == pred_dic['domain']:
72 |             acc_num += 1
73 | 
74 |     return float(acc_num) / float(total_num)
75 | 
76 | def slots_acc(truth_dict_list, pred_dict_list):
77 |     assert len(truth_dict_list) == len(pred_dict_list)
78 |     acc_num = 0
79 |     total_num = 0
80 |     for truth_dic, pred_dic in zip(truth_dict_list, pred_dict_list):
81 |         total_num += len(truth_dic['slots'])
82 |         for key, value in truth_dic['slots'].items():
83 |             if key not in pred_dic['slots']:
84 |                 continue
85 |             elif pred_dic['slots'][key] == truth_dic['slots'][key]:
86 |                 acc_num += 1
87 | 
88 |     return float(acc_num) / float(total_num)
89 | 
90 | def slots_f(truth_dict_list, pred_dict_list):
91 |     assert len(truth_dict_list) == len(pred_dict_list)
92 |     correct, p_denominator, r_denominator = 0, 0, 0
93 |     for truth_dic, pred_dic in zip(truth_dict_list, pred_dict_list):
94 |         r_denominator += len(truth_dic['slots'])
95 |         p_denominator += len(pred_dic['slots'])
96 |         for key, value in truth_dic['slots'].items():
97 |             if key not in pred_dic['slots']:
98 |                 continue
99 |             elif pred_dic['slots'][key] == truth_dic['slots'][key] and \
100 |                     truth_dic['domain'] == pred_dic['domain'] and \
101 |                     truth_dic['intent'] == pred_dic['intent']:
102 |                 correct += 1
103 |     precision = float(correct) / p_denominator if p_denominator else 0.0
104 |     recall = float(correct) / r_denominator if r_denominator else 0.0
105 |     f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
106 | 
107 |     return f1
108 | 
109 | if __name__ == '__main__':
110 |     if len(sys.argv) < 3:
111 |         print('Too few args for this script')
112 |         sys.exit(1)
113 | 
114 |     with codecs.open(sys.argv[1], 'r', encoding='utf-8') as f:
115 |         fp_truth = json.loads(f.read())
116 | 
117 |     with codecs.open(sys.argv[2], 'r', encoding='utf-8') as f_pred:
118 |         fp_pred = json.loads(f_pred.read())
119 | 
120 |     domain_accuracy = domain_acc(fp_truth, fp_pred)
121 |     intent_accuracy = intent_acc(fp_truth, fp_pred)
122 |     slots_f1 = slots_f(fp_truth, fp_pred)  # renamed so the function slots_f is not shadowed
123 | 
124 |     sentence_accuracy = sentence_acc(fp_truth, fp_pred)
125 | 
126 |     print('Domain sentence accuracy : %f' % domain_accuracy)
127 |     print('Intent sentence accuracy : %f' % intent_accuracy)
128 |     print('Slots f score : %f' % slots_f1)
129 |     print('Avg sentence accuracy : %f' % sentence_accuracy)
--------------------------------------------------------------------------------
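A minimal usage sketch for evaluation.py: it expects two JSON files in the format of the docstring above, ground truth first and predictions second. The file names here are placeholders (data/test_truth.json is a hypothetical labeled counterpart of data/test.json; res_pred.json is the output of main.py), and the same functions can be called directly:

import json
import codecs
from evaluation import sentence_acc, domain_acc, intent_acc, slots_f

with codecs.open('data/test_truth.json', 'r', encoding='utf-8') as f:   # hypothetical labeled test set
    truth = json.load(f)
with codecs.open('res_pred.json', 'r', encoding='utf-8') as f:          # predictions written by main.py
    pred = json.load(f)

print('domain acc  :', domain_acc(truth, pred))
print('intent acc  :', intent_acc(truth, pred))    # intent is only correct if the domain is too
print('slots f1    :', slots_f(truth, pred))       # micro-F1 over (key, value) pairs, gated on domain+intent
print('sentence acc:', sentence_acc(truth, pred))  # exact match of domain, intent and all slots;
                                                   # also prints every mismatch (the true:/pre: lines in res.out)
--------------------------------------------------------------------------------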
/lstm_crf_layer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import rnn
3 | from tensorflow.contrib import crf
4 | 
5 | class BLSTM_CRF(object):
6 |     def __init__(self, config):
7 |         """
8 |         BLSTM-CRF network
9 |         :param embedded_chars: embedding input from the fine-tuned encoder
10 |         :param hidden_unit: number of hidden units in the LSTM
11 |         :param cell_type: RNN cell type (LSTM or GRU; DICNN may be added in the future)
12 |         :param num_layers: number of RNN layers
13 |         :param dropout_rate: dropout keep probability (see the note in add_blstm_crf_layer)
14 |         :param initializers: variable initializer class
15 |         :param num_labels: number of labels
16 |         :param seq_length: maximum sequence length
17 |         :param labels: gold labels
18 |         :param lengths: [batch_size] true length of each sequence in the batch
19 |         :param is_training: whether this is the training phase
20 |         """
21 | 
22 |         self.hidden_unit = config["hidden_unit"]
23 |         self.dropout_rate = config["dropout_rate"]
24 |         self.cell_type = config["cell_type"]
25 |         self.num_layers = config["num_layers"]
26 |         self.embedded_chars = config["embedded_chars"]
27 |         self.initializers = config["initializers"]
28 |         self.seq_length = config["seq_length"]
29 |         self.num_labels = config["num_labels"]
30 |         self.labels = config["labels"]
31 |         self.lengths = config["lengths"]
32 |         self.embedding_dims = self.embedded_chars.shape[-1].value
33 |         self.is_training = config["is_training"]
34 | 
35 |     def add_blstm_crf_layer(self, crf_only):
36 |         """
37 |         blstm-crf
38 |         """
39 |         if self.is_training:
40 |             # note: in TF 1.x the second argument of tf.nn.dropout is keep_prob,
41 |             # so dropout_rate is effectively a keep probability here
42 |             self.embedded_chars = tf.nn.dropout(self.embedded_chars, self.dropout_rate)
43 | 
44 |         if crf_only:
45 |             logits = self.project_crf_layer(self.embedded_chars)
46 |         else:
47 |             # blstm
48 |             lstm_output = self.blstm_layer(self.embedded_chars)
49 |             # projection
50 |             logits = self.project_bilstm_layer(lstm_output)
51 | 
52 |         # crf
53 |         loss, trans = self.crf_layer(logits)
54 | 
55 |         # CRF decode: pred_ids is the single most probable tag path
56 |         if self.is_training:
57 |             return (loss, logits, trans, None)
58 | 
59 |         pred_ids, _ = crf.crf_decode(potentials=logits, transition_params=trans, sequence_length=self.lengths)
60 | 
61 |         return (None, logits, None, pred_ids)
62 | 
63 |     def project_crf_layer(self, embedding_chars, name=None):
64 |         """
65 |         hidden layer between input layer and logits
66 |         :param embedding_chars: [batch_size, num_steps, emb_size]
67 |         :return: [batch_size, num_steps, num_tags]
68 |         """
69 |         with tf.variable_scope("project" if not name else name):
70 |             with tf.variable_scope("logits"):
71 |                 W = tf.get_variable("W", shape=[self.embedding_dims, self.num_labels],
72 |                                     dtype=tf.float32, initializer=self.initializers.xavier_initializer())
73 |                 b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32,
74 |                                     initializer=tf.zeros_initializer())
75 |                 output = tf.reshape(self.embedded_chars,
76 |                                     shape=[-1, self.embedding_dims])  # [batch_size * num_steps, embedding_dims]
77 |                 pred = tf.tanh(tf.nn.xw_plus_b(output, W, b))
78 | 
79 |         return tf.reshape(pred, [-1, self.seq_length, self.num_labels])
80 | 
81 |     def crf_layer(self, logits):
82 |         """
83 |         calculate crf loss
84 |         :param logits: [batch_size, num_steps, num_tags]
85 |         :return: scalar loss and the transition matrix
86 |         """
87 |         with tf.variable_scope("crf_loss"):
88 |             trans = tf.get_variable(
89 |                 "transitions",
90 |                 shape=[self.num_labels, self.num_labels],
91 |                 initializer=self.initializers.xavier_initializer())
92 |             if self.labels is None or not self.is_training:
93 |                 return None, trans
94 |             else:
95 |                 log_likelihood, trans = tf.contrib.crf.crf_log_likelihood(
96 |                     inputs=logits,
97 |                     tag_indices=self.labels,
98 |                     transition_params=trans,
99 |                     sequence_lengths=self.lengths)
100 | 
101 |                 return tf.reduce_mean(-log_likelihood), trans
102 | 
103 |     def blstm_layer(self, embedding_chars):
104 |         with tf.variable_scope("rnn_layer"):
105 |             cell_fw = rnn.LSTMCell(self.hidden_unit)
106 |             cell_bw = rnn.LSTMCell(self.hidden_unit)
107 |             if self.is_training and self.dropout_rate is not None:
108 |                 cell_fw = rnn.DropoutWrapper(cell_fw, output_keep_prob=self.dropout_rate)
109 |                 cell_bw = rnn.DropoutWrapper(cell_bw, output_keep_prob=self.dropout_rate)
110 |             outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, embedding_chars, dtype=tf.float32)
111 | 
112 |             outputs = tf.concat(outputs, axis=2)
113 | 
114 |         return outputs
115 | 
116 |     def project_bilstm_layer(self, lstm_outputs, name=None):
117 |         '''
118 |         project the BiLSTM outputs to label logits
119 |         '''
120 |         with tf.variable_scope("project" if not name else name):
121 |             with tf.variable_scope("logits"):
122 |                 W = tf.get_variable("W", shape=[self.hidden_unit * 2, self.num_labels],
123 |                                     dtype=tf.float32, initializer=self.initializers.xavier_initializer())
124 |                 b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32,
125 |                                     initializer=tf.zeros_initializer())
126 | 
127 |                 outputs = tf.reshape(lstm_outputs, shape=[-1, self.hidden_unit * 2])
128 | 
129 |                 pred = tf.nn.xw_plus_b(outputs, W, b)
130 | 
131 |         return tf.reshape(pred, [-1, self.seq_length, self.num_labels])
132 |     '''
133 |     # alternative: projection with an extra tanh hidden layer (kept for reference)
134 |     with tf.variable_scope("hidden"):
135 |         W1 = tf.get_variable("W1", shape=[self.hidden_unit * 2, self.hidden_unit],
136 |                              dtype=tf.float32, initializer=self.initializers.xavier_initializer())
137 |         b1 = tf.get_variable("b1", shape=[self.hidden_unit], dtype=tf.float32,
138 |                              initializer=tf.zeros_initializer)
139 |         output = tf.reshape(lstm_outputs, shape=[-1, self.hidden_unit * 2])
140 |         hidden = tf.tanh(tf.nn.xw_plus_b(output, W1, b1))
141 |     with tf.variable_scope("logits"):
142 |         W2 = tf.get_variable("W2", shape=[self.hidden_unit, self.num_labels],
143 |                              dtype=tf.float32, initializer=self.initializers.xavier_initializer())
144 |         b2 = tf.get_variable("b2", shape=[self.num_labels], dtype=tf.float32,
145 |                              initializer=tf.zeros_initializer)
146 |         pred = tf.nn.xw_plus_b(hidden, W2, b2)
147 |     return tf.reshape(pred, [-1, self.seq_length, self.num_labels])
148 |     '''
--------------------------------------------------------------------------------
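For orientation, a minimal sketch of how BLSTM_CRF can be attached to BERT's sequence output. model_fn_builder itself is not included in this excerpt, so sequence_output, input_mask, slot_ids, num_slot_labels, is_training and the crf_only choice are placeholders and assumptions, not the repo's actual wiring:

# a sketch, not the repo's model_fn (see caveats above)
from tensorflow.contrib.layers.python.layers import initializers

lengths = tf.reduce_sum(input_mask, axis=1)          # true length of each sequence in the batch
blstm_crf = BLSTM_CRF({
    "hidden_unit":    100,                           # cf. "lstm_size" in model_config.py
    "dropout_rate":   0.5,                           # used as a keep probability (see note above)
    "cell_type":      "lstm",
    "num_layers":     1,
    "embedded_chars": sequence_output,               # [batch, seq_len, hidden] from BERT
    "initializers":   initializers,
    "seq_length":     128,
    "num_labels":     num_slot_labels,
    "labels":         slot_ids,                      # [batch, seq_len] gold tag ids (None at predict time)
    "lengths":        lengths,
    "is_training":    is_training,
})
loss, logits, trans, pred_ids = blstm_crf.add_blstm_crf_layer(crf_only=False)
--------------------------------------------------------------------------------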
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import json
3 | import os
4 | import sys
5 | import re
6 | import numpy as np
7 | from model_config import joint_config as config
8 | from utils import Joint_Processor
9 | from utils import file_based_convert_examples_to_features, file_based_input_fn_builder
10 | from utils import model_fn_builder, convert_examples_to_features
11 | from utils import input_fn_builder, get_slot_name
12 | from bert import modeling, tokenization, optimization
13 | from patterns import code_pattern
14 | 
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
16 | 
17 | def main(test_file='test.json'):
18 |     tf.logging.set_verbosity(tf.logging.INFO)
19 |     # 1. set up the data processor
20 |     processors = {
21 |         'joint': Joint_Processor
22 |     }
23 | 
24 |     task_name = config['task_name'].lower()
25 |     if task_name not in processors:
26 |         raise ValueError("Task not found: %s" % task_name)
27 |     processor = processors[task_name]()
28 | 
29 |     # 1.1 fetch the label maps (built from the training data, or reloaded from disk at test time)
30 |     id2domain, domain2id, id2intent, intent2id, id2slot, slot2id, domain_w, intent_w = \
31 |         processor.get_labels(config["data_dir"],
32 |                              "train" if config['do_train'] else "test")
33 | 
34 |     # build the tokenizer
35 |     tokenizer = tokenization.FullTokenizer(
36 |         vocab_file=config['vocab_file'], do_lower_case=config['do_lower_case'])
37 | 
38 |     train_examples = None
39 |     num_train_steps = None
40 |     num_warmup_steps = None
41 |     save_checkpoints_steps = config['save_checkpoints_steps']
42 | 
43 |     # 1.2 read the training data and convert it to InputExample format
44 |     if config['do_train']:
45 |         tf.logging.info("***** Loading training examples *****")
46 |         train_examples = processor.get_train_examples(config['data_dir'])
47 |         num_train_steps = int(len(train_examples) / config['train_batch_size'] * config['num_train_epochs'])
48 |         num_warmup_steps = int(num_train_steps * config['warmup_proportion'])
49 |         save_checkpoints_steps = int(len(train_examples) / config['train_batch_size']) + 1
50 | 
51 |         train_file = os.path.join(config['data_dir'], 'train.tf_record')
52 |         # write the examples to a TFRecord file for fast reading
53 |         file_based_convert_examples_to_features(train_examples, domain2id, intent2id, slot2id,
54 |                                                 config['max_seq_length'], tokenizer, train_file)
55 | 
56 |         # input pipeline
57 |         train_input_fn = file_based_input_fn_builder(
58 |             input_file=train_file,
59 |             seq_length=config['max_seq_length'],
60 |             is_training=True,
61 |             drop_remainder=False)
62 | 
63 |     # 2. create the model
64 |     # 2.1 runtime configuration
65 |     bert_config = modeling.BertConfig.from_json_file(config['bert_config_file'])
66 | 
67 |     tf_cfg = tf.ConfigProto()
68 |     tf_cfg.gpu_options.per_process_gpu_memory_fraction = 0.8
69 | 
70 |     run_config = tf.estimator.RunConfig(
71 |         model_dir=config['output_dir'],
72 |         save_checkpoints_steps=save_checkpoints_steps,
73 |         keep_checkpoint_max=1,
74 |         session_config=tf_cfg,
75 |         log_step_count_steps=100,)
76 | 
77 |     # 2.2 build the model_fn
78 |     model_fn = model_fn_builder(
79 |         bert_config=bert_config,
80 |         num_domain=len(domain2id),
81 |         num_intent=len(intent2id),
82 |         num_slot=len(slot2id),
83 |         init_checkpoint=config['init_checkpoint'],
84 |         learning_rate=config['learning_rate'],
85 |         num_train_steps=num_train_steps,
86 |         num_warmup_steps=num_warmup_steps,
87 |         use_tpu=config['use_tpu'],
88 |         use_one_hot_embeddings=config['use_tpu'],
89 |         do_serve=config['do_serve'],
90 |         domain_w=domain_w,
91 |         intent_w=intent_w)
92 | 
93 |     estimator = tf.estimator.Estimator(
94 |         model_fn=model_fn,
95 |         config=run_config,
96 |     )
97 | 
98 |     # 3. training
99 |     if config['do_train']:
100 |         tf.logging.info("***** Running training *****")
101 |         tf.logging.info("  Num examples = %d", len(train_examples))
102 |         tf.logging.info("  Batch size = %d", config['train_batch_size'])
103 |         tf.logging.info("  Num steps = %d", num_train_steps)
104 |         if config['do_eval']:
105 |             # there is no eval pipeline in this script: eval_input_fn and eval_steps
106 |             # would have to be built (like train_input_fn above) before
107 |             # tf.estimator.train_and_evaluate could be called here
108 |             raise NotImplementedError('do_eval is not wired up; build eval_input_fn/eval_steps first')
109 |         else:
110 |             estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
111 | 
112 |         return None
113 | 
114 |     # 4. prediction
115 |     # 4.1 load the test data
116 |     if config['do_predict']:
117 |         tf.logging.info("***** Loading test examples *****")
118 |         test_examples = processor.get_test_examples(test_file)
119 |         num_actual_predict_examples = len(test_examples)
120 |         tf.logging.info("the number of test_examples is %d" % len(test_examples))
121 |         test_features = convert_examples_to_features(test_examples, domain2id,
122 |                                                       intent2id, slot2id, config['max_seq_length'], tokenizer)
123 |         tf.logging.info("the number of test_features is %d" % len(test_features))
124 | 
125 |         predict_input_fn = input_fn_builder(
126 |             features=test_features,
127 |             seq_length=config['max_seq_length'],
128 |             is_training=False,
129 |             drop_remainder=False,
130 |         )
131 |         result = estimator.predict(input_fn=predict_input_fn)
132 |         pred_results = []
133 |         for pred_line, prediction in zip(test_examples, result):
134 |             data = {}
135 |             data['text'] = pred_line.text
136 |             domain_pred = prediction["domain_pred"]
137 |             intent_pred = prediction["intent_pred"]
138 |             slot_pred = prediction["slot_pred"]
139 |             data['domain'] = id2domain[domain_pred]
140 |             # NB: intents were stored via str() in get_labels, so this sentinel
141 |             # must match that stringification
142 |             data['intent'] = id2intent[intent_pred] if id2intent[intent_pred] != 'NaN' else np.nan
143 | 
144 |             # map the predicted tag ids back onto the raw characters,
145 |             # skipping padding and the [CLS]/[SEP] markers
146 |             idx = 0
147 |             len_seq = len(pred_line.text)
148 |             slot_labels = []
149 |             for sid in slot_pred:
150 |                 if idx >= len_seq:
151 |                     break
152 |                 if sid == 0:
153 |                     continue
154 |                 cur_slot = id2slot[sid]
155 |                 if cur_slot in ['[CLS]', '[SEP]']:
156 |                     continue
157 |                 slot_labels.append(cur_slot)
158 |                 idx += 1
159 | 
160 |             data['slots'] = get_slot_name(pred_line.text, slot_labels)
161 | 
162 |             # rule-based post-processing: pull digit slots (phone numbers, stock
163 |             # codes, radio frequencies) out with the regexes from patterns.py
164 |             for p in code_pattern:
165 |                 m = re.match(p, data['text'])  # renamed from `result` to avoid shadowing the predictions iterator
166 |                 if m:
167 |                     data['slots']['code'] = m.group(1)
168 |                     break
169 |             pred_results.append(data)
170 | 
171 |         json.dump(pred_results, open(sys.argv[2], 'w', encoding='utf8'), ensure_ascii=False)
172 | 
173 | 
174 | if __name__ == '__main__':
175 |     test_file = sys.argv[1]
176 |     print(test_file)
177 |     main(test_file)
178 | 
--------------------------------------------------------------------------------
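A run recipe, reconstructed from the sys.argv handling above rather than from any repo documentation (data/test_truth.json is a hypothetical labeled counterpart of data/test.json):

python main.py data/test.json res_pred.json
python evaluation.py data/test_truth.json res_pred.json

main.py writes its predictions to the second argument; evaluation.py then prints the four metrics captured in res.out below. To train instead, set "do_train" to True (and "do_predict" to False) in model_config.py and rerun main.py; checkpoints are written to output_dir.
--------------------------------------------------------------------------------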
/model_config.py:
--------------------------------------------------------------------------------
1 | joint_config = {
2 |     "data_dir"               : './data',
3 |     "train_file"             : 'train.json',
4 |     #"train_file"            : 'smp_train_ehance.json',
5 |     #"train_file"            : 'smp_2019_task1_train.json',
6 |     "bert_config_file"       : './chinese_L-12_H-768_A-12/bert_config.json',
7 |     "task_name"              : 'joint',
8 |     "vocab_file"             : 'chinese_L-12_H-768_A-12/vocab.txt',
9 |     "output_dir"             : 'output/',
10 |     "init_checkpoint"        : './chinese_L-12_H-768_A-12/bert_model.ckpt',
11 |     "do_lower_case"          : True,
12 |     "max_seq_length"         : 128,
13 |     "do_train"               : False,
14 |     "do_eval"                : False,
15 |     "do_predict"             : True,
16 |     "do_serve"               : False,
17 |     "train_batch_size"       : 32,
18 |     "eval_batch_size"        : 32,
19 |     "predict_batch_size"     : 16,
20 |     "learning_rate"          : 5e-5,
21 |     "num_train_epochs"       : 40.0,
22 |     "warmup_proportion"      : 0.1,
23 |     "save_checkpoints_steps" : 1000,
24 |     "iterations_per_loop"    : 1000,
25 |     "use_tpu"                : False,
26 |     "tpu_name"               : None,
27 |     "tpu_zone"               : None,
28 |     "gcp_project"            : None,
29 |     "master"                 : None,
30 |     "num_tpu_cores"          : 8,
31 |     "label_file"             : 'labels_map',
32 |     "export_dir"             : None,
33 |     "lstm_size"              : 100,
34 |     "cell"                   : "lstm",
35 |     "num_layers"             : 1,
36 |     "dropout_rate"           : 0.5,
37 | }
38 | 
--------------------------------------------------------------------------------
/patterns.py:
--------------------------------------------------------------------------------
1 | # regexes for pulling numeric "code" slots (stock codes, channel numbers,
2 | # phone numbers, radio frequencies) out of the raw text as post-processing
3 | code_pattern = [
4 |     "查询股票(\d+)",
5 |     "搜索第(\d+)频道",
6 |     ".*?(\d+)频道*",        # trailing * makes the final 道 optional
7 |     ".*号码是(\d+).*",
8 |     ".*调频(\d+.\d+).*",    # the unescaped . also lets integer frequencies like 105 match
9 |     ".*新建联系人(\d+).*",
10 |     "代码(\d+)的股票"
11 | ]
--------------------------------------------------------------------------------
/report/READ.ME:
--------------------------------------------------------------------------------
1 | SMP 2019: the 3rd Evaluation of Chinese Human-Computer Dialogue Technology (ECDT) has concluded successfully
2 | https://mp.weixin.qq.com/s/P36g3Q_e1q47rOpCuzTNAw
3 | 
--------------------------------------------------------------------------------
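To make the rule-based post-processing concrete, a small sketch applying code_pattern from patterns.py above to utterances taken from data/test.json (first match wins, mirroring the loop in main.py):

import re
from patterns import code_pattern

for text in ["查询股票00230", "请帮我调频90.2连云港经济广播电台", "新建联系人18622625490"]:
    for p in code_pattern:
        m = re.match(p, text)
        if m:
            print(text, '-> code =', m.group(1))
            break
# expected codes: 00230, 90.2, 18622625490
--------------------------------------------------------------------------------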
/report/SMP2019ECDT任务1技术报告-出门问问信息科技有限公司 (1).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaopp123/bert-joint-NLU/57259b7835ab3da058555145bd1951d68391cbd1/report/SMP2019ECDT任务1技术报告-出门问问信息科技有限公司 (1).pdf -------------------------------------------------------------------------------- /report/SMP2019ECDT技术报告-北京沃丰时代数据科技有限公司.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaopp123/bert-joint-NLU/57259b7835ab3da058555145bd1951d68391cbd1/report/SMP2019ECDT技术报告-北京沃丰时代数据科技有限公司.pdf -------------------------------------------------------------------------------- /report/coffeeNLU小队技术报告_final.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaopp123/bert-joint-NLU/57259b7835ab3da058555145bd1951d68391cbd1/report/coffeeNLU小队技术报告_final.pdf -------------------------------------------------------------------------------- /res.out: -------------------------------------------------------------------------------- 1 | 525 525 2 | true: {'text': '赶集', 'domain': 'app', 'intent': 'LAUNCH', 'slots': {'name': '赶集'}} 3 | pre: {'text': '赶集', 'domain': 'video', 'intent': 'QUERY', 'slots': {'name': '赶集'}} 4 | true: {'text': '从新加坡花园怎么去宁溪路', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'endLoc_poi': '宁溪路', 'startLoc_poi': '新加坡花园'}} 5 | pre: {'text': '从新加坡花园怎么去宁溪路', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'endLoc_poi': '新加坡花园宁溪路'}} 6 | true: {'text': '到合肥市逍遥津公园怎么走', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'endLoc_city': '合肥', 'endLoc_poi': '逍遥津公园'}} 7 | pre: {'text': '到合肥市逍遥津公园怎么走', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'endLoc_city': '合肥', 'endLoc_province': '市', 'endLoc_poi': '逍遥津公园'}} 8 | true: {'text': '湖北荆州到黄石的火车', 'domain': 'train', 'intent': 'QUERY', 'slots': {'endLoc_city': '黄石', 'startLoc_city': '荆州', 'startLoc_province': '湖北'}} 9 | pre: {'text': '湖北荆州到黄石的火车', 'domain': 'train', 'intent': 'QUERY', 'slots': {'startLoc_city': '湖北荆州', 'endLoc_city': '黄石'}} 10 | true: {'text': '发短信给张三说“画皮2在哪个台播出”', 'domain': 'message', 'intent': 'SEND', 'slots': {'content': '画皮2在哪个台播出', 'name': '张三'}} 11 | pre: {'text': '发短信给张三说“画皮2在哪个台播出”', 'domain': 'message', 'intent': 'SEND', 'slots': {'name': '张三“画皮2'}} 12 | true: {'text': '查看短消息', 'domain': 'message', 'intent': 'VIEW', 'slots': {}} 13 | pre: {'text': '查看短消息', 'domain': 'message', 'intent': 'SEND', 'slots': {}} 14 | true: {'text': '查找张伟军', 'domain': 'contacts', 'intent': nan, 'slots': {'name': '张伟军'}} 15 | pre: {'text': '查找张伟军', 'domain': 'video', 'intent': 'QUERY', 'slots': {'name': '张伟军'}} 16 | true: {'text': '有什么特色美食。', 'domain': 'cookbook', 'intent': 'QUERY', 'slots': {'keyword': '特色美食'}} 17 | pre: {'text': '有什么特色美食。', 'domain': 'cookbook', 'intent': 'QUERY', 'slots': {'category': '特', 'keyword': '色美食'}} 18 | true: {'text': '回放昨晚的焦点访谈', 'domain': 'epg', 'intent': 'LOOK_BACK', 'slots': {'datetime_time': '昨晚', 'name': '焦点访谈'}} 19 | pre: {'text': '回放昨晚的焦点访谈', 'domain': 'epg', 'intent': 'QUERY', 'slots': {'datetime_time': '昨晚', 'name': '焦点访谈'}} 20 | true: {'text': 'BTV生活节目选择回放', 'domain': 'epg', 'intent': 'LOOK_BACK', 'slots': {'tvchannel': 'BTV生活'}} 21 | pre: {'text': 'BTV生活节目选择回放', 'domain': 'epg', 'intent': 'QUERY', 'slots': {'tvchannel': 'BTV生活'}} 22 | true: {'text': '上一期15选5的中奖号码是多少?', 'domain': 'lottery', 'intent': 'NUMBER_QUERY', 'slots': {'name': '15选5', 'relIssue': '上一期'}} 23 | pre: {'text': '上一期15选5的中奖号码是多少?', 'domain': 'lottery', 
'intent': 'NUMBER_QUERY', 'slots': {'name': '15选5'}} 24 | true: {'text': '我想看足球中超,第25轮的比赛时间。', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超'}} 25 | pre: {'text': '我想看足球中超,第25轮的比赛时间。', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '足球中超', 'datetime_date': '25'}} 26 | true: {'text': '帮我查一下周杰伦的歌', 'domain': 'music', 'intent': 'SEARCH', 'slots': {'artist': '周杰伦'}} 27 | pre: {'text': '帮我查一下周杰伦的歌', 'domain': 'music', 'intent': 'PLAY', 'slots': {'artist': '周杰伦'}} 28 | true: {'text': '播放眼色', 'domain': 'music', 'intent': 'PLAY', 'slots': {'song': '眼色'}} 29 | pre: {'text': '播放眼色', 'domain': 'video', 'intent': 'QUERY', 'slots': {'name': '眼色'}} 30 | true: {'text': '打开腾讯头条新闻', 'domain': 'news', 'intent': 'PLAY', 'slots': {'category': '头条', 'media': '腾讯'}} 31 | pre: {'text': '打开腾讯头条新闻', 'domain': 'news', 'intent': 'PLAY', 'slots': {'name': '腾讯头条'}} 32 | true: {'text': '收听安徽广播电台。', 'domain': 'radio', 'intent': 'LAUNCH', 'slots': {'location_province': '安徽', 'name': '广播电台'}} 33 | pre: {'text': '收听安徽广播电台。', 'domain': 'radio', 'intent': 'LAUNCH', 'slots': {'name': '安徽广播电台'}} 34 | true: {'text': '搜索湖南卫视直播。', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': '湖南卫视'}} 35 | pre: {'text': '搜索湖南卫视直播。', 'domain': 'epg', 'intent': 'QUERY', 'slots': {'tvchannel': '湖南卫视'}} 36 | true: {'text': '天津电视台的国际频道,拜托了哈', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'category': '国际', 'name': '天津电视台'}} 37 | pre: {'text': '天津电视台的国际频道,拜托了哈', 'domain': 'epg', 'intent': 'QUERY', 'slots': {'tvchannel': '天津电视台', 'category': '国', 'name': '际'}} 38 | true: {'text': '找一下非诚勿扰娱乐节目', 'domain': 'video', 'intent': 'QUERY', 'slots': {'category': '节目', 'name': '非诚勿扰', 'tag': '娱乐'}} 39 | pre: {'text': '找一下非诚勿扰娱乐节目', 'domain': 'video', 'intent': 'QUERY', 'slots': {'name': '非诚勿扰', 'category': '娱乐节目'}} 40 | true: {'text': '2012年美国公告牌音乐大奖颁奖礼', 'domain': 'video', 'intent': 'QUERY', 'slots': {'datetime_date': '2012年', 'name': '美国公告牌音乐大奖颁奖礼'}} 41 | pre: {'text': '2012年美国公告牌音乐大奖颁奖礼', 'domain': 'lottery', 'intent': 'QUERY', 'slots': {'datetime_date': '2012年', 'area': '美国', 'name': '公告牌音乐大奖'}} 42 | true: {'text': '你们什么搞笑剧', 'domain': 'video', 'intent': 'QUERY', 'slots': {'category': '搞笑剧'}} 43 | pre: {'text': '你们什么搞笑剧', 'domain': 'video', 'intent': 'QUERY', 'slots': {'tag': '搞笑', 'category': '剧'}} 44 | true: {'text': '查找90年代好评的武侠电影', 'domain': 'video', 'intent': 'QUERY', 'slots': {'category': '电影', 'decade': '90年代', 'payment': '好评', 'tag': '武侠'}} 45 | pre: {'text': '查找90年代好评的武侠电影', 'domain': 'video', 'intent': 'QUERY', 'slots': {'datetime_date': '90年代', 'tag': '武侠', 'category': '电影'}} 46 | true: {'text': '今天有雨吗', 'domain': 'weather', 'intent': 'QUERY', 'slots': {'datetime_date': '今天', 'questionWord': '有', 'subfocus': '雨'}} 47 | pre: {'text': '今天有雨吗', 'domain': 'weather', 'intent': 'QUERY', 'slots': {'datetime_date': '今天'}} 48 | true: {'text': '新浪汽车', 'domain': 'website', 'intent': 'OPEN', 'slots': {'name': '新浪汽车'}} 49 | pre: {'text': '新浪汽车', 'domain': 'app', 'intent': 'LAUNCH', 'slots': {'name': '新浪汽车'}} 50 | true: {'text': '百度百科', 'domain': 'app', 'intent': 'LAUNCH', 'slots': {'name': '百度百科'}} 51 | pre: {'text': '百度百科', 'domain': 'website', 'intent': 'OPEN', 'slots': {'name': '百度百科'}} 52 | true: {'text': '带我去丹阳市眼镜市场', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'endLoc_area': '丹阳', 'endLoc_poi': '眼镜市场'}} 53 | pre: {'text': '带我去丹阳市眼镜市场', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'endLoc_poi': '丹阳市眼镜市场'}} 54 | true: {'text': '上一期双色球开什么', 'domain': 'lottery', 'intent': 'NUMBER_QUERY', 'slots': {'name': 
'双色球', 'relIssue': '上一期'}} 55 | pre: {'text': '上一期双色球开什么', 'domain': 'lottery', 'intent': 'QUERY', 'slots': {'datetime_date': '上一', 'name': '双色球'}} 56 | true: {'text': '昆山大润发在哪里', 'domain': 'map', 'intent': 'POSITION', 'slots': {'location_area': '昆山', 'location_poi': '大润发'}} 57 | pre: {'text': '昆山大润发在哪里', 'domain': 'map', 'intent': 'POSITION', 'slots': {'location_poi': '昆山大润发'}} 58 | true: {'text': '沿途有没有加油站', 'domain': 'map', 'intent': 'POSITION', 'slots': {'location_poi': '加油站'}} 59 | pre: {'text': '沿途有没有加油站', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'endLoc_poi': '加油站'}} 60 | true: {'text': 'CCTV6电影', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': 'CCTV6电影'}} 61 | pre: {'text': 'CCTV6电影', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': 'CCTV6电', 'category': '影'}} 62 | true: {'text': '播放断点', 'domain': 'music', 'intent': 'PLAY', 'slots': {'song': '断点'}} 63 | pre: {'text': '播放断点', 'domain': 'video', 'intent': 'QUERY', 'slots': {'name': '断点'}} 64 | true: {'text': '分手时,背一首诗吧?', 'domain': 'poetry', 'intent': 'DEFAULT', 'slots': {}} 65 | pre: {'text': '分手时,背一首诗吧?', 'domain': 'poetry', 'intent': 'DEFAULT', 'slots': {'keyword': '分'}} 66 | true: {'text': '新建联系人18622625490', 'domain': 'contacts', 'intent': 'CREATE', 'slots': {'code': '18622625490'}} 67 | pre: {'text': '新建联系人18622625490', 'domain': 'app', 'intent': 'LAUNCH', 'slots': {'name': '18622625490'}} 68 | true: {'text': '添加一条通讯录姓名张守刚号码是13811725158', 'domain': 'contacts', 'intent': 'CREATE', 'slots': {'code': '13811725158', 'name': '张守刚'}} 69 | pre: {'text': '添加一条通讯录姓名张守刚号码是13811725158', 'domain': 'email', 'intent': 'LAUNCH', 'slots': {'name': '通讯录张守刚13811725158'}} 70 | true: {'text': '新建联系人天天', 'domain': 'contacts', 'intent': 'CREATE', 'slots': {'name': '天天'}} 71 | pre: {'text': '新建联系人天天', 'domain': 'app', 'intent': 'LAUNCH', 'slots': {'name': '新建联系人天天'}} 72 | true: {'text': '找一首歌叫斑马斑马', 'domain': 'music', 'intent': 'SEARCH', 'slots': {'song': '斑马斑马'}} 73 | pre: {'text': '找一首歌叫斑马斑马', 'domain': 'music', 'intent': 'PLAY', 'slots': {'song': '斑马斑马'}} 74 | true: {'text': '人造鸡蛋的配方。', 'domain': 'cookbook', 'intent': nan, 'slots': {'dishName': '人造鸡蛋'}} 75 | pre: {'text': '人造鸡蛋的配方。', 'domain': 'cookbook', 'intent': 'QUERY', 'slots': {'dishName': '人造鸡蛋'}} 76 | true: {'text': '打开应用漫画', 'domain': 'app', 'intent': 'LAUNCH', 'slots': {'name': '漫画'}} 77 | pre: {'text': '打开应用漫画', 'domain': 'video', 'intent': 'QUERY', 'slots': {'category': '漫画'}} 78 | true: {'text': '现在电视台在放什么节目', 'domain': 'epg', 'intent': 'QUERY', 'slots': {'category': '节目', 'datetime_time': '现在'}} 79 | pre: {'text': '现在电视台在放什么节目', 'domain': 'epg', 'intent': 'QUERY', 'slots': {'datetime_time': '现在'}} 80 | true: {'text': '现在有什么好的电视剧。', 'domain': 'epg', 'intent': 'QUERY', 'slots': {'category': '电视剧', 'datetime_time': '现在'}} 81 | pre: {'text': '现在有什么好的电视剧。', 'domain': 'video', 'intent': 'QUERY', 'slots': {'category': '电视剧'}} 82 | true: {'text': '现在看看有最新的新闻吗?', 'domain': 'news', 'intent': 'PLAY', 'slots': {'datetime_time': '现在', 'keyword': '最新'}} 83 | pre: {'text': '现在看看有最新的新闻吗?', 'domain': 'news', 'intent': 'PLAY', 'slots': {'datetime_time': '现在'}} 84 | true: {'text': '今天晚上天津卫视放什么年岁月', 'domain': 'epg', 'intent': 'QUERY', 'slots': {'datetime_date': '今天', 'datetime_time': '晚上', 'tvchannel': '天津卫视'}} 85 | pre: {'text': '今天晚上天津卫视放什么年岁月', 'domain': 'epg', 'intent': 'QUERY', 'slots': {'datetime_date': '今天', 'datetime_time': '晚上', 'tvchannel': '天津卫视', 'name': '年岁月'}} 86 | true: {'text': '美国热门的恐怖电影', 'domain': 'video', 'intent': 'QUERY', 'slots': {'area': '美国', 
'category': '电影', 'popularity': '热门', 'tag': '恐怖'}} 87 | pre: {'text': '美国热门的恐怖电影', 'domain': 'video', 'intent': 'QUERY', 'slots': {'area': '美国', 'tag': '恐怖', 'category': '电影'}} 88 | true: {'text': '从观音桥到重庆市图书馆怎么走', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'endLoc_poi': '重庆市图书馆', 'startLoc_poi': '观音桥'}} 89 | pre: {'text': '从观音桥到重庆市图书馆怎么走', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'startLoc_poi': '观音桥', 'endLoc_city': '重庆', 'endLoc_poi': '市图书馆'}} 90 | true: {'text': '厦门到福建建阳的火车是几点呢', 'domain': 'train', 'intent': 'QUERY', 'slots': {'endLoc_area': '建阳', 'endLoc_province': '福建', 'startLoc_city': '厦门'}} 91 | pre: {'text': '厦门到福建建阳的火车是几点呢', 'domain': 'train', 'intent': 'QUERY', 'slots': {'startLoc_city': '厦门', 'endLoc_city': '福建建阳'}} 92 | true: {'text': '我要去姜堰', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'endLoc_area': '姜堰'}} 93 | pre: {'text': '我要去姜堰', 'domain': 'map', 'intent': 'ROUTE', 'slots': {'endLoc_area': '姜', 'endLoc_city': '堰'}} 94 | true: {'text': '中超A组比赛结果分别是多少', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超', 'type': '结果'}} 95 | pre: {'text': '中超A组比赛结果分别是多少', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超'}} 96 | true: {'text': '中超的比赛预告。', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超', 'type': '预告'}} 97 | pre: {'text': '中超的比赛预告。', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超'}} 98 | true: {'text': '中超赛事预告,啊!', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超', 'type': '预告'}} 99 | pre: {'text': '中超赛事预告,啊!', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超'}} 100 | true: {'text': '广州恒大比赛比分。', 'domain': 'match', 'intent': 'QUERY', 'slots': {'name': '广州恒大', 'type': '比分'}} 101 | pre: {'text': '广州恒大比赛比分。', 'domain': 'match', 'intent': 'QUERY', 'slots': {'name': '广州恒大'}} 102 | true: {'text': '中超比赛,结果。', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超', 'type': '结果'}} 103 | pre: {'text': '中超比赛,结果。', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超'}} 104 | true: {'text': '昨天恒大比赛结果是什么?', 'domain': 'match', 'intent': 'QUERY', 'slots': {'datetime_date': '昨天', 'name': '恒大', 'type': '结果'}} 105 | pre: {'text': '昨天恒大比赛结果是什么?', 'domain': 'match', 'intent': 'QUERY', 'slots': {'datetime_date': '昨天', 'name': '恒大'}} 106 | true: {'text': '中超赛事预告,从在哪里看呢?', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超', 'type': '预告'}} 107 | pre: {'text': '中超赛事预告,从在哪里看呢?', 'domain': 'match', 'intent': 'QUERY', 'slots': {'category': '中超'}} 108 | true: {'text': '胡歌忘记时间', 'domain': 'music', 'intent': 'PLAY', 'slots': {'artist': '胡歌', 'song': '忘记时间'}} 109 | pre: {'text': '胡歌忘记时间', 'domain': 'video', 'intent': 'QUERY', 'slots': {'artist': '胡歌', 'name': '忘记时间'}} 110 | true: {'text': '宫崎骏动画电影', 'domain': 'video', 'intent': 'QUERY', 'slots': {'artist': '宫崎骏', 'category': '电影', 'tag': '动画'}} 111 | pre: {'text': '宫崎骏动画电影', 'domain': 'video', 'intent': 'QUERY', 'slots': {'name': '宫崎', 'artist': '骏', 'tag': '动画', 'category': '电影'}} 112 | true: {'text': '陈奕迅有什么歌啊', 'domain': 'music', 'intent': 'SEARCH', 'slots': {'artist': '陈奕迅'}} 113 | pre: {'text': '陈奕迅有什么歌啊', 'domain': 'music', 'intent': 'PLAY', 'slots': {'artist': '陈奕迅'}} 114 | true: {'text': '合肥明天的湿度如何', 'domain': 'weather', 'intent': 'QUERY', 'slots': {'datetime_date': '明天', 'location_city': '合肥', 'subfocus': '湿度'}} 115 | pre: {'text': '合肥明天的湿度如何', 'domain': 'weather', 'intent': 'QUERY', 'slots': {'location_city': '合肥', 'datetime_date': '明天'}} 116 | true: {'text': '我想听搜狐新闻', 'domain': 'news', 'intent': 'PLAY', 
'slots': {'media': '搜狐'}} 117 | pre: {'text': '我想听搜狐新闻', 'domain': 'app', 'intent': 'LAUNCH', 'slots': {'name': '搜狐新闻'}} 118 | true: {'text': '有什么好看的校园小说', 'domain': 'novel', 'intent': 'QUERY', 'slots': {'category': '校园', 'popularity': '好看'}} 119 | pre: {'text': '有什么好看的校园小说', 'domain': 'novel', 'intent': 'QUERY', 'slots': {'category': '校园'}} 120 | true: {'text': '添加一条通讯录姓名葛勇手机号码189571642496', 'domain': 'contacts', 'intent': 'CREATE', 'slots': {'code': '189571642496', 'name': '葛勇'}} 121 | pre: {'text': '添加一条通讯录姓名葛勇手机号码189571642496', 'domain': 'email', 'intent': 'LAUNCH', 'slots': {'name': '通讯录葛勇189571642496'}} 122 | true: {'text': '晋江到武汉的航班', 'domain': 'flight', 'intent': 'QUERY', 'slots': {'endLoc_city': '武汉', 'startLoc_area': '晋江'}} 123 | pre: {'text': '晋江到武汉的航班', 'domain': 'flight', 'intent': 'QUERY', 'slots': {'startLoc_city': '晋', 'startLoc_area': '江', 'endLoc_city': '武汉'}} 124 | true: {'text': '高清北京文艺台', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': '北京文艺台', 'resolution': '高清'}} 125 | pre: {'text': '高清北京文艺台', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': '高清北京文艺台'}} 126 | true: {'text': '山东台高清', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': '山东台', 'resolution': '高清'}} 127 | pre: {'text': '山东台高清', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': '山东台高清'}} 128 | true: {'text': '高清动作电影', 'domain': 'video', 'intent': 'QUERY', 'slots': {'category': '电影', 'resolution': '高清', 'tag': '动作'}} 129 | pre: {'text': '高清动作电影', 'domain': 'epg', 'intent': 'QUERY', 'slots': {'name': '高清', 'tag': '动作', 'category': '电影'}} 130 | true: {'text': '甄嬛传高清', 'domain': 'video', 'intent': 'QUERY', 'slots': {'name': '甄嬛传', 'resolution': '高清'}} 131 | pre: {'text': '甄嬛传高清', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': '甄嬛传高清'}} 132 | true: {'text': 'CCTV6电影高清', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': 'CCTV6电影', 'resolution': '高清'}} 133 | pre: {'text': 'CCTV6电影高清', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': 'CCTV6电影高清'}} 134 | true: {'text': '东方卫视高清频道', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': '东方卫视', 'resolution': '高清'}} 135 | pre: {'text': '东方卫视高清频道', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': '东方卫视高清频道'}} 136 | true: {'text': '高清央视综艺', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': '央视综艺', 'resolution': '高清'}} 137 | pre: {'text': '高清央视综艺', 'domain': 'tvchannel', 'intent': 'PLAY', 'slots': {'name': '高清央视综艺'}} 138 | true: {'text': '最新电影', 'domain': 'video', 'intent': 'QUERY', 'slots': {'category': '电影', 'timeDescr': '最新'}} 139 | pre: {'text': '最新电影', 'domain': 'video', 'intent': 'QUERY', 'slots': {'category': '电影'}} 140 | true: {'text': '最近有什么好看得', 'domain': 'video', 'intent': 'QUERY', 'slots': {'scoreDescr': '好看', 'timeDescr': '最近'}} 141 | pre: {'text': '最近有什么好看得', 'domain': 'video', 'intent': 'QUERY', 'slots': {}} 142 | true: {'text': '今天温度是多少', 'domain': 'weather', 'intent': 'QUERY', 'slots': {'datetime_date': '今天', 'subfocus': '温度'}} 143 | pre: {'text': '今天温度是多少', 'domain': 'weather', 'intent': 'QUERY', 'slots': {'datetime_date': '今'}} 144 | true: {'text': '腾讯股票昨日收盘价', 'domain': 'stock', 'intent': 'CLOSEPRICE_QUERY', 'slots': {'name': '腾讯', 'yesterday': '昨日'}} 145 | pre: {'text': '腾讯股票昨日收盘价', 'domain': 'stock', 'intent': 'QUERY', 'slots': {'name': '腾讯'}} 146 | Domain sentence accuracy : 0.963810 147 | Intent sentence accuracy : 0.944762 148 | Slots f score : 0.796247 149 | Avg sentence accuracy : 0.742857 150 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import collections
3 | import os
4 | import pickle
5 | import numpy as np
6 | import tensorflow as tf
7 | from bert import modeling, tokenization, optimization
8 | from model_config import joint_config as config
9 | from tensorflow.contrib.layers.python.layers import initializers
10 | from lstm_crf_layer import BLSTM_CRF
11 | 
12 | def get_slot_name(text, slot_label):
13 |     """Merge character-level tags back into {slot_name: value} pairs."""
14 |     slots = {}
15 |     for i, slot in enumerate(slot_label):
16 |         if slot == 'O':
17 |             continue
18 |         else:
19 |             _, slot_name = slot.split('-', 1)  # split once, in case a tag ever contains '-'
20 |             if slot_name in slots:
21 |                 slots[slot_name] += text[i]
22 |             else:
23 |                 slots[slot_name] = text[i]
24 | 
25 |     return slots
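# worked example (annotation, not in the original file): with the all-B tagging
# scheme used below, get_slot_name('打开相机', ['O', 'O', 'B-name', 'B-name'])
# returns {'name': '相机'} -- consecutive characters carrying the same slot name
# are simply concatenated, which is why the scheme can drop I- tags entirely.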
26 | 
27 | 
28 | class InputExample(object):
29 |     def __init__(self, guid, text, domain=None, intent=None, slots=None):
30 |         self.guid = guid
31 |         self.text = text
32 |         self.slots = slots
33 |         self.domain = domain
34 |         self.intent = intent
35 | 
36 | class PaddingInputExample(object):
37 |     """Fake example so the num input examples is a multiple of the batch size.
38 |     When running eval/predict on the TPU, we need to pad the number of examples
39 |     to be a multiple of the batch size, because the TPU requires a fixed batch
40 |     size. The alternative is to drop the last batch, which is bad because it means
41 |     the entire output data won't be generated.
42 |     We use this class instead of `None` because treating `None` as padding
43 |     batches could cause silent errors.
44 |     """
45 | 
46 | class InputFeatures(object):
47 |     """A single set of features of data."""
48 |     def __init__(self,
49 |                  input_ids,
50 |                  input_mask,
51 |                  segment_ids,
52 |                  domain_id,
53 |                  intent_id,
54 |                  slot_id,
55 |                  is_real_example=True):
56 |         self.input_ids = input_ids
57 |         self.input_mask = input_mask
58 |         self.segment_ids = segment_ids
59 |         self.domain_id = domain_id
60 |         self.intent_id = intent_id
61 |         self.slot_id = slot_id
62 |         self.is_real_example = is_real_example
63 | 
64 | 
65 | class Joint_Processor(object):
66 |     def get_train_examples(self, data_dir):
67 |         return self._create_examples(
68 |             self._read_json(os.path.join(data_dir, config['train_file'])), "train")
69 | 
70 |     def get_test_examples(self, test_file):
71 |         """Read the given test file."""
72 |         return self._create_examples(
73 |             self._read_json(os.path.join("", test_file)), "test")
74 | 
75 |     def get_dev_examples(self, data_dir):
76 |         pass
77 | 
78 |     def get_labels(self, data_dir, set_type):
79 |         '''
80 |         Build the domain, intent and slot label maps from the training data.
81 |         '''
82 |         if set_type == 'train':
83 |             data_list = self._read_json(os.path.join(data_dir, config['train_file']))
84 |             domain_labels = set([data['domain'] for data in data_list])
85 |             intent_labels = set([str(data['intent']) for data in data_list])
86 | 
87 |             slots_labels = set()
88 |             for data in data_list:
89 |                 for slot in data['slots']:
90 |                     slots_labels.add("B-%s" % slot)
91 |                     # no "I-%s" tags: every character of a slot value is tagged B-
92 |             slots_labels = list(slots_labels)
93 | 
94 |             id2domain = {i: label for i, label in enumerate(domain_labels)}
95 |             domain2id = {label: i for i, label in id2domain.items()}
96 | 
97 |             id2intent = {i: label for i, label in enumerate(intent_labels)}
98 |             intent2id = {label: i for i, label in id2intent.items()}
99 | 
100 |             # count class frequencies for loss weighting
101 |             domain_d = {}
102 |             intent_d = {}
103 |             for data in data_list:
104 |                 if data['domain'] not in domain_d:
105 |                     domain_d[data['domain']] = 1
106 |                 else:
107 |                     domain_d[data['domain']] += 1
108 | 
109 |             for data in data_list:
110 |                 # compare the stringified intent (NaN intents stringify to 'nan')
111 |                 if str(data['intent']) not in intent_d:
112 |                     intent_d[str(data['intent'])] = 1
113 |                 else:
114 |                     intent_d[str(data['intent'])] += 1
115 | 
116 |             # smoothed inverse-frequency weights: rare classes get larger weights
117 |             domain_w = [1] * len(domain2id)
118 |             intent_w = [1] * len(intent2id)
119 |             for key in domain2id:
120 |                 domain_w[domain2id[key]] = len(data_list) / (len(domain2id) + domain_d[key])
121 |             for key in intent2id:
122 |                 intent_w[intent2id[key]] = len(data_list) / (len(intent2id) + intent_d[key])
123 | 
124 |             # slot ids 0-3 are reserved for padding and the special markers
125 |             id2slot = {i: label for i, label in enumerate(slots_labels, 4)}
126 |             id2slot[0] = '[PAD]'
127 |             id2slot[1] = '[CLS]'
128 |             id2slot[2] = '[SEP]'
129 |             id2slot[3] = 'O'
130 |             slot2id = {label: i for i, label in id2slot.items()}
131 | 
132 |             # persist the maps so prediction runs can reload them
133 |             with open(config['label_file'], 'wb') as fw:
134 |                 pickle.dump([id2domain, domain2id, id2intent, intent2id, id2slot, slot2id], fw)
135 |         else:
136 |             # at prediction time, reload the label maps saved during training
137 |             with open(config['label_file'], 'rb') as fr:
138 |                 id2domain, domain2id, id2intent, intent2id, id2slot, slot2id = pickle.load(fr)
139 | 
140 |             domain_w = [1] * len(domain2id)
141 |             intent_w = [1] * len(intent2id)
142 | 
143 |         return id2domain, domain2id, id2intent, intent2id, id2slot, slot2id, domain_w, intent_w
144 | 
145 |     @classmethod
146 |     def _read_json(cls, input_file):
147 |         """read json data"""
148 |         with open(input_file, "r") as f:
149 |             return json.load(f)
150 | 
151 |     @classmethod
152 |     def _get_slot_label(cls, text, slot):
153 |         """Project {slot_name: value} onto a character-level tag string.
154 |         Every character of a value is tagged B-<name>; no I- tags are used,
155 |         and get_slot_name above reassembles values by concatenation."""
156 |         tag = ['O'] * len(text)
157 |         for k, v in slot.items():
158 |             index = text.find(v)
159 |             if index == -1:
160 |                 continue
161 |             tag[index] = 'B-%s' % k
162 |             if len(v) > 1:
163 |                 for i in range(len(v) - 1):
164 |                     tag[index + i + 1] = 'B-%s' % k
165 | 
166 |         return ' '.join(tag)
167 | 
168 |     def _create_examples(self, data_list, set_type):
169 |         examples = []
170 |         for (i, data) in enumerate(data_list):
171 |             guid = "%s-%s" % (set_type, i)
172 |             text = tokenization.convert_to_unicode(data['text'])
173 |             if set_type == "test":
174 |                 # unlabeled test data gets placeholder labels
175 |                 domain = "label-test"
176 |                 intent = "label-test"
177 |                 slots = "label-test"
178 |             else:
179 |                 domain = tokenization.convert_to_unicode(data['domain'])
180 |                 intent = tokenization.convert_to_unicode(str(data['intent']))
181 |                 slots = self._get_slot_label(data['text'], data['slots'])
182 |                 slots = tokenization.convert_to_unicode(slots)
183 |             examples.append(
184 |                 InputExample(guid=guid, text=text, domain=domain, intent=intent, slots=slots))
185 |         if set_type != "test":
186 |             np.random.shuffle(examples)
187 | 
188 |         return examples
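# worked example (annotation, not in the original file):
#   Joint_Processor._get_slot_label('打开相机', {'name': '相机'})
#   -> 'O O B-name B-name'
# note that text.find(v) tags only the first occurrence of a value, and values
# that do not appear verbatim in the raw text (index == -1) are silently skipped.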
189 | 
190 | def convert_single_example(ex_index, example, domain2id, intent2id, slot2id, max_seq_length,
191 |                            tokenizer):
192 |     """Converts a single `InputExample` into a single `InputFeatures`."""
193 |     if isinstance(example, PaddingInputExample):
194 |         return InputFeatures(
195 |             input_ids=[0] * max_seq_length,
196 |             input_mask=[0] * max_seq_length,
197 |             segment_ids=[0] * max_seq_length,
198 |             domain_id=0,
199 |             intent_id=0,
200 |             slot_id=[0] * max_seq_length,
201 |             is_real_example=False)
202 | 
203 |     # character-level "tokenization" (not tokenizer.tokenize on the whole text),
204 |     # so slot tags stay aligned 1:1 with the raw characters; spaces become '$'
205 |     tokens_text = [t if t != ' ' else '$' for t in example.text]
206 |     # test and training labels are handled differently
207 |     slot_label = []
208 |     if example.domain == 'label-test':
209 |         domain_id = 0
210 |         intent_id = 0
211 |         slot_label = ['O'] * len(tokens_text)
212 |     else:
213 |         domain_id = domain2id[example.domain]
214 |         intent_id = intent2id[example.intent]
215 |         slot_label = example.slots.split()
216 | 
217 |     if len(tokens_text) > max_seq_length - 2:
218 |         # truncate, leaving room for [CLS] and [SEP]
219 |         tokens_text = tokens_text[0: max_seq_length - 2]
220 |         slot_label = slot_label[0: max_seq_length - 2]
221 | 
222 |     assert len(slot_label) == len(tokens_text)
223 | 
224 |     tokens = []
225 |     segment_ids = []
226 |     slot_id = []
227 |     tokens.append('[CLS]')
228 |     segment_ids.append(0)
229 |     slot_id.append(slot2id['[CLS]'])
230 |     for i, token in enumerate(tokens_text):
231 |         # assumes tokenize() returns at least one sub-token per character
232 |         tokens.append(tokenizer.tokenize(token)[0])
233 |         segment_ids.append(0)
234 |         slot_id.append(slot2id[slot_label[i]])
235 |     tokens.append('[SEP]')
236 |     segment_ids.append(0)
237 |     slot_id.append(slot2id['[SEP]'])
238 | 
239 |     input_ids = tokenizer.convert_tokens_to_ids(tokens)
240 |     input_mask = [1] * len(input_ids)
241 | 
242 |     # pad everything out to max_seq_length
243 |     while len(input_ids) < max_seq_length:
244 |         input_ids.append(0)
245 |         input_mask.append(0)
246 |         segment_ids.append(0)
247 |         slot_id.append(0)
248 | 
249 |     assert len(input_ids) == max_seq_length
250 |     assert len(input_mask) == max_seq_length
251 |     assert len(segment_ids) == max_seq_length
252 |     assert len(slot_id) == max_seq_length
253 | 
254 |     feature = InputFeatures(
255 |         input_ids=input_ids,
256 |         input_mask=input_mask,
257 |         segment_ids=segment_ids,
258 |         domain_id=domain_id,
259 |         intent_id=intent_id,
260 |         slot_id=slot_id,
261 |         is_real_example=True)
262 | 
263 |     return feature
264 | 
265 | def file_based_convert_examples_to_features(examples, domain2id, intent2id, slot2id,
266 |                                             max_seq_length, tokenizer, output_file, task_name="domain"):
267 |     """Convert a set of `InputExample`s to a TFRecord file."""
268 |     writer = tf.python_io.TFRecordWriter(output_file)
269 |     for (ex_id, example) in enumerate(examples):
270 |         if ex_id % 500 == 0:
271 |             tf.logging.info("Writing example %d of %d" % (ex_id, len(examples)))
272 |         feature = convert_single_example(ex_id, example, domain2id, intent2id, slot2id,
273 |                                          max_seq_length, tokenizer)
274 | 
275 |         def create_int_feature(values):
276 |             f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
277 |             return f
278 | 
279 |         features = collections.OrderedDict()
280 |         features['input_ids'] = create_int_feature(feature.input_ids)
281 |         features['input_mask'] = create_int_feature(feature.input_mask)
282 |         features['segment_ids'] = create_int_feature(feature.segment_ids)
283 |         features['domain_id'] = create_int_feature([feature.domain_id])
284 |         features['intent_id'] = create_int_feature([feature.intent_id])
285 |         features['slot_id'] = create_int_feature(feature.slot_id)
286 |         features['is_real_example'] = create_int_feature([int(feature.is_real_example)])
287 | 
288 |         tf_example = tf.train.Example(features=tf.train.Features(feature=features))
289 |         writer.write(tf_example.SerializeToString())
290 | 
291 |     writer.close()
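# worked example (annotation, not in the original file): for text '打开相机' with
# slot string 'O O B-name B-name' and max_seq_length=8, the feature comes out as
#   tokens:  [CLS] 打 开 相 机 [SEP]            (plus padding)
#   slot_id: [1, 3, 3, b, b, 2, 0, 0]           where b = slot2id['B-name'],
#                                               0=[PAD], 1=[CLS], 2=[SEP], 3='O'
# i.e. one slot id per raw character, bracketed by the [CLS]/[SEP] ids.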
def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder, task_name="domain"):
    """Creates an `input_fn` closure to be passed to the Estimator."""
    name_to_features = {
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "domain_id": tf.FixedLenFeature([], tf.int64),
        "intent_id": tf.FixedLenFeature([], tf.int64),
        "slot_id": tf.FixedLenFeature([seq_length], tf.int64),
        "is_real_example": tf.FixedLenFeature([], tf.int64),
    }

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
        # so cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        # all tasks share the same batch size
        batch_size = config['train_batch_size'] if is_training else config['predict_batch_size']

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            tf.contrib.data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))
        return d

    return input_fn
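# A hedged sketch of consuming the TFRecord through this input_fn; `estimator`
# and `num_train_steps` are assumptions standing in for main.py's wiring:
#
#   train_input_fn = file_based_input_fn_builder(
#       input_file="train.tf_record", seq_length=128,
#       is_training=True, drop_remainder=True)
#   estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)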
def convert_examples_to_features(examples, domain2id, intent2id, slot2id,
                                 max_seq_length, tokenizer):
    """At prediction time, convert the loaded `InputExample`s into `InputFeatures`."""
    features_list = []
    for ex_index, example in enumerate(examples):
        feature = convert_single_example(ex_index, example, domain2id,
                                         intent2id, slot2id, max_seq_length, tokenizer)
        features_list.append(feature)

    return features_list


def input_fn_builder(features, seq_length, is_training, drop_remainder):
    """Builds an in-memory input_fn used at prediction time."""
    all_input_ids = []
    all_input_mask = []
    all_segment_ids = []
    all_domain_ids = []
    all_intent_ids = []
    all_slot_ids = []

    for feature in features:
        all_input_ids.append(feature.input_ids)
        all_input_mask.append(feature.input_mask)
        all_segment_ids.append(feature.segment_ids)
        all_domain_ids.append(feature.domain_id)
        all_intent_ids.append(feature.intent_id)
        all_slot_ids.append(feature.slot_id)

    def input_fn(params):
        """The actual input function."""
        batch_size = config['predict_batch_size']
        num_examples = len(features)
        d = tf.data.Dataset.from_tensor_slices({
            "input_ids": tf.constant(all_input_ids,
                                     shape=[num_examples, seq_length], dtype=tf.int32),
            "input_mask": tf.constant(all_input_mask,
                                      shape=[num_examples, seq_length], dtype=tf.int32),
            "segment_ids": tf.constant(all_segment_ids,
                                       shape=[num_examples, seq_length], dtype=tf.int32),
            "domain_id": tf.constant(all_domain_ids,
                                     shape=[num_examples], dtype=tf.int32),
            "intent_id": tf.constant(all_intent_ids,
                                     shape=[num_examples], dtype=tf.int32),
            "slot_id": tf.constant(all_slot_ids,
                                   shape=[num_examples, seq_length], dtype=tf.int32)
        })
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)
        d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)

        return d

    return input_fn


def domain_classification(model, domain_id, num_domain, is_training, domain_w):
    '''Domain classification head on top of BERT's pooled [CLS] output.'''
    # pooled output, [batch, hidden_size]
    tf.logging.info("domain_classification...")
    output_layer = model.get_pooled_output()
    hidden_size = output_layer.shape[-1].value
    domain_output_weights = tf.get_variable("domain_output_weights",
                                            [num_domain, hidden_size],
                                            initializer=tf.truncated_normal_initializer(stddev=0.02))
    domain_output_bias = tf.get_variable("domain_output_bias",
                                         [num_domain], initializer=tf.zeros_initializer())

    # fixed (non-trainable) class weights computed from label frequencies
    domain_w = tf.get_variable("domain_w", initializer=domain_w, dtype=tf.float32, trainable=False)

    with tf.variable_scope("domain_loss"):
        if is_training:
            output_layer = tf.nn.dropout(output_layer, keep_prob=config['dropout_rate'])
        domain_logits = tf.matmul(output_layer, domain_output_weights, transpose_b=True)
        domain_logits = tf.nn.bias_add(domain_logits, domain_output_bias)
        domain_probabilities = tf.nn.softmax(domain_logits, axis=-1)
        domain_log_probs = tf.nn.log_softmax(domain_logits, axis=-1)
        domain_predictions = tf.argmax(domain_logits, axis=-1)
        domain_one_hot_labels = tf.one_hot(domain_id, depth=num_domain, dtype=tf.float32)
        # classes are imbalanced, so the cross-entropy is weighted per class
        domain_per_example_loss = -tf.reduce_sum(domain_one_hot_labels * domain_log_probs * domain_w, axis=-1)

        domain_loss = tf.reduce_mean(domain_per_example_loss)

    return domain_loss, domain_probabilities, domain_predictions


def intent_classification(model, intent_id, num_intent, is_training, intent_w):
    '''Intent classification head on top of BERT's pooled [CLS] output.'''
    # pooled output, [batch, hidden_size]
    output_layer = model.get_pooled_output()
    hidden_size = output_layer.shape[-1].value
    intent_output_weights = tf.get_variable("intent_output_weights",
                                            [num_intent, hidden_size],
                                            initializer=tf.truncated_normal_initializer(stddev=0.02))
    intent_output_bias = tf.get_variable("intent_output_bias",
                                         [num_intent], initializer=tf.zeros_initializer())

    # fixed (non-trainable) class weights computed from label frequencies
    intent_w = tf.get_variable("intent_w", initializer=intent_w, dtype=tf.float32, trainable=False)

    with tf.variable_scope("intent_loss"):
        if is_training:
            output_layer = tf.nn.dropout(output_layer, keep_prob=config['dropout_rate'])
        intent_logits = tf.matmul(output_layer, intent_output_weights, transpose_b=True)
        intent_logits = tf.nn.bias_add(intent_logits, intent_output_bias)
        intent_probabilities = tf.nn.softmax(intent_logits, axis=-1)
        intent_log_probs = tf.nn.log_softmax(intent_logits, axis=-1)
        intent_predictions = tf.argmax(intent_logits, axis=-1)
        intent_one_hot_labels = tf.one_hot(intent_id, depth=num_intent, dtype=tf.float32)
        # classes are imbalanced, so the loss is weighted per class
        intent_per_example_loss = -tf.reduce_sum(intent_one_hot_labels * intent_log_probs * intent_w, axis=-1)

        intent_loss = tf.reduce_mean(intent_per_example_loss)

    return intent_loss, intent_probabilities, intent_predictions
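# Worked example of the weighted cross-entropy used by both heads above: with
# the one-hot label y = [0, 1, 0], log-probabilities log_p = [-2.3, -0.1, -3.0]
# and class weights w = [1.0, 2.5, 1.0], the per-example loss is
#   -(0*(-2.3)*1.0 + 1*(-0.1)*2.5 + 0*(-3.0)*1.0) = 0.25
# so mistakes on rare classes (large w) cost proportionally more.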
def slot_filling(model, lengths, slot_id, num_slot, is_training):
    '''Slot filling with a BiLSTM-CRF head on top of BERT's sequence output.'''
    # token-level embeddings, [batch_size, seq_length, embedding_size]
    embedding = model.get_sequence_output()

    max_seq_length = embedding.shape[1].value

    # BiLSTM + CRF output layer
    blstm_crf_config = {
        "embedded_chars": embedding,
        "hidden_unit": config['lstm_size'],
        "cell_type": config['cell'],
        "num_layers": config['num_layers'],
        "dropout_rate": config['dropout_rate'],
        "initializers": initializers,
        "num_labels": num_slot,
        "seq_length": max_seq_length,
        "labels": slot_id,
        "lengths": lengths,
        "is_training": is_training
    }

    blstm_crf = BLSTM_CRF(blstm_crf_config)
    loss, logits, trans, pred_ids = blstm_crf.add_blstm_crf_layer(crf_only=False)

    return loss, logits, trans, pred_ids


def create_model(bert_config, is_training, input_ids, input_mask,
                 segment_ids, domain_id, intent_id, slot_id, num_domain,
                 num_intent, num_slot, use_one_hot_embeddings, domain_w, intent_w):
    '''Create the joint classification and sequence-labeling model.'''
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    # true sequence lengths: count the non-zero (non-padding) input ids
    used = tf.sign(tf.abs(input_ids))
    lengths = tf.reduce_sum(used, axis=1)

    # domain classification
    domain_loss, domain_probabilities, domain_pred = \
        domain_classification(model, domain_id, num_domain, is_training, domain_w)

    # intent recognition
    intent_loss, intent_probabilities, intent_pred = \
        intent_classification(model, intent_id, num_intent, is_training, intent_w)

    # slot filling
    slot_loss, slot_logits, trans, slot_pred = slot_filling(model, lengths, slot_id, num_slot, is_training)

    # the three losses are returned separately so the caller can weight them
    return domain_loss, intent_loss, slot_loss, domain_pred, intent_pred, slot_pred
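# Worked example of the length computation in create_model: for a padded row
# of input_ids such as [5, 8, 2, 0, 0] (ids chosen only for illustration),
# tf.sign(tf.abs(.)) gives [1, 1, 1, 0, 0] and the row sum gives length 3, so
# the CRF decodes only over the three real positions of that sequence.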
def model_fn_builder(bert_config, num_domain, num_intent, num_slot, init_checkpoint,
                     learning_rate, num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings, do_serve, domain_w, intent_w):

    # `labels` is required by the Estimator model_fn signature but unused here:
    # all targets travel inside `features`
    def model_fn(features, labels, mode, params):
        tf.logging.info("***features***")
        input_ids = features['input_ids']
        input_mask = features['input_mask']
        segment_ids = features['segment_ids']
        domain_id = features['domain_id']
        intent_id = features['intent_id']
        slot_id = features['slot_id']
        # flags padding examples (is_real_example == 0) so they can be ignored
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(domain_id), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        domain_loss, intent_loss, slot_loss, domain_pred, intent_pred, slot_pred = \
            create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                         domain_id, intent_id, slot_id, num_domain, num_intent, num_slot,
                         use_one_hot_embeddings, np.array(domain_w, dtype=np.float32),
                         np.array(intent_w, dtype=np.float32))

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        # initialise from the pretrained BERT checkpoint
        if init_checkpoint:
            (assignment_map, initialized_variable_names) = \
                modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            # joint loss over the three tasks; slot filling is up-weighted by 2
            total_loss = domain_loss + intent_loss + 2 * slot_loss

            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

            # build the training EstimatorSpec
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold=scaffold_fn)
        else:
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={"domain_pred": domain_pred,
                             "intent_pred": intent_pred,
                             "slot_pred": slot_pred},
                scaffold=scaffold_fn)

        return output_spec

    return model_fn
--------------------------------------------------------------------------------