├── .gitignore ├── LICENSE ├── MANIFEST.in ├── PULL_REQUEST_TEMPLATE.md ├── README.md ├── convlab2 ├── __init__.py ├── dialog_agent │ ├── __init__.py │ ├── agent.py │ ├── env.py │ └── session.py ├── dst │ ├── __init__.py │ ├── dst.py │ ├── rule │ │ ├── __init__.py │ │ └── crosswoz │ │ │ ├── __init__.py │ │ │ ├── dst.py │ │ │ └── evaluate.py │ └── trade │ │ ├── __init__.py │ │ ├── crosswoz │ │ ├── EWC_train.py │ │ ├── GEM_train.py │ │ ├── README.md │ │ ├── cnembedding.py │ │ ├── demo.py │ │ ├── evaluate.py │ │ ├── fine_tune.py │ │ ├── models │ │ │ └── TRADE.py │ │ ├── trade.py │ │ ├── train.py │ │ └── utils │ │ │ ├── config.py │ │ │ ├── fix_label.py │ │ │ ├── logger.py │ │ │ ├── mapping.pair │ │ │ ├── masked_cross_entropy.py │ │ │ ├── measures.py │ │ │ ├── multi-bleu.perl │ │ │ ├── utils_multiWOZ_DST.py │ │ │ └── utils_temp.py │ │ └── trade.py ├── nlg │ ├── __init__.py │ ├── nlg.py │ ├── sclstm │ │ ├── __init__.py │ │ ├── bleu.py │ │ ├── crosswoz │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── config │ │ │ │ ├── config.cfg │ │ │ │ └── config_usr.cfg │ │ │ ├── evaluate.py │ │ │ ├── generate_resources.py │ │ │ ├── loader │ │ │ │ └── dataset_woz.py │ │ │ ├── sc_lstm.py │ │ │ └── train.py │ │ └── model │ │ │ ├── layers │ │ │ └── decoder_deep.py │ │ │ ├── lm_deep.py │ │ │ └── masked_cross_entropy.py │ └── template │ │ ├── __init__.py │ │ └── crosswoz │ │ ├── __init__.py │ │ ├── auto_system_template_nlg.json │ │ ├── auto_user_template_nlg.json │ │ ├── evaluate.py │ │ ├── generate_auto_template.py │ │ ├── manual_system_template_nlg.json │ │ ├── manual_user_template_nlg.json │ │ └── nlg.py ├── nlu │ ├── __init__.py │ ├── jointBERT │ │ ├── __init__.py │ │ ├── crosswoz │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── analyse.py │ │ │ ├── configs │ │ │ │ ├── crosswoz_all.json │ │ │ │ ├── crosswoz_all_context.json │ │ │ │ ├── crosswoz_all_context_fr.json │ │ │ │ └── crosswoz_all_fr.json │ │ │ ├── nlu.py │ │ │ ├── postprocess.py │ │ │ └── preprocess.py │ │ ├── dataloader.py │ │ ├── jointBERT.py │ │ ├── test.py │ │ └── train.py │ └── nlu.py ├── policy │ ├── README.md │ ├── __init__.py │ ├── mle │ │ ├── __init__.py │ │ ├── crosswoz │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── config.json │ │ │ ├── evaluate.py │ │ │ ├── loader.py │ │ │ ├── mle.py │ │ │ └── train.py │ │ ├── loader.py │ │ ├── mle.py │ │ └── train.py │ ├── policy.py │ ├── rlmodule.py │ ├── rule │ │ ├── __init__.py │ │ └── crosswoz │ │ │ ├── __init__.py │ │ │ ├── evaluate.py │ │ │ └── rule_simulator.py │ ├── vec.py │ └── vector │ │ ├── __init__.py │ │ ├── dataset.py │ │ └── vector_crosswoz.py ├── task │ ├── __init__.py │ └── crosswoz │ │ ├── __init__.py │ │ ├── attraction_generator.py │ │ ├── goal_generator.py │ │ ├── hotel_generator.py │ │ ├── metro_generator.py │ │ ├── reorder.py │ │ ├── restaurant_generator.py │ │ ├── sentence_generator.py │ │ └── taxi_generator.py └── util │ ├── __init__.py │ ├── allennlp_file_utils.py │ ├── crosswoz │ ├── __init__.py │ ├── dbquery.py │ ├── lexicalize.py │ └── state.py │ ├── file_util.py │ ├── module.py │ └── train_util.py ├── data └── crosswoz │ ├── README.md │ ├── database │ ├── attraction_db.json │ ├── database.md │ ├── hotel_db.json │ ├── metro_db.json │ ├── restaurant_db.json │ └── taxi_db.json │ ├── gen_da_voc.py │ ├── sys_da_voc.json │ ├── test.json.zip │ ├── train.json.zip │ ├── usr_da_voc.json │ └── val.json.zip ├── example.png ├── requirements-dev.txt ├── requirements.txt ├── result.png ├── setup.cfg ├── setup.py └── web ├── .editorconfig ├── .env.example ├── .gitignore ├── 
README.md ├── data_labelling ├── __init__.py ├── admin.py ├── app.py ├── match_making │ ├── __init__.py │ ├── helpers.py │ └── match.py ├── models.py ├── redis.py ├── results │ └── input │ │ └── .gitkeep ├── routes │ ├── __init__.py │ ├── match.py │ ├── misc.py │ ├── room.py │ └── services.py ├── settings.example.py ├── static │ ├── css │ │ ├── bootstrap.min.css │ │ └── fontawesome.all.css │ ├── data │ │ ├── attraction_db.json │ │ ├── hotel_db.json │ │ ├── metro_db.json │ │ └── restaurant_db.json │ ├── js │ │ ├── axios.min.js │ │ ├── bootstrap.min.js │ │ ├── jquery.min.js │ │ ├── jquery.min.map │ │ ├── jquery.slim.min.js │ │ ├── jquery.slim.min.map │ │ ├── polyfill.min.js │ │ ├── popper.min.js │ │ ├── socket.io.min.js │ │ └── vue.min.js │ └── webfonts │ │ ├── fa-brands-400.eot │ │ ├── fa-brands-400.svg │ │ ├── fa-brands-400.ttf │ │ ├── fa-brands-400.woff │ │ ├── fa-brands-400.woff2 │ │ ├── fa-regular-400.eot │ │ ├── fa-regular-400.svg │ │ ├── fa-regular-400.ttf │ │ ├── fa-regular-400.woff │ │ ├── fa-regular-400.woff2 │ │ ├── fa-solid-900.eot │ │ ├── fa-solid-900.svg │ │ ├── fa-solid-900.ttf │ │ ├── fa-solid-900.woff │ │ └── fa-solid-900.woff2 ├── templates │ ├── admin │ │ └── index.html │ ├── base.html │ ├── chatbox.html │ ├── clientside.html │ ├── dashboard.html │ ├── heading.html │ ├── index.html │ ├── login.html │ ├── match.html │ ├── register.html │ ├── room.html │ └── systemside.html └── utils.py ├── ecosystem.config.js ├── example_goal.json ├── requirements.txt ├── resetdb.py ├── run.py ├── setup.bat └── setup.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | 4 | .DS_Store 5 | 6 | # pycharm 7 | .idea 8 | 9 | # vscode 10 | .vscode 11 | 12 | # data 13 | data/**/train.json 14 | data/**/val.json 15 | data/**/test.json 16 | data/camrest/CamRest676_v2.json 17 | data/multiwoz/annotated_user_da_with_span_full.json 18 | data/schema/dstc8-schema-guided-dialogue-master 19 | **/processed_data/* 20 | data/mdbt/data 21 | data/mdbt/models 22 | data/mdbt/word-vectors 23 | convlab2/nlg/sclstm/**/resource/* 24 | convlab2/nlg/sclstm/**/resource_usr/* 25 | convlab2/nlg/sclstm/**/sclstm.pt 26 | convlab2/nlg/sclstm/**/sclstm.res 27 | convlab2/nlg/sclstm/**/sclstm.log 28 | convlab2/nlg/sclstm/**/sclstm_usr.pt 29 | convlab2/nlg/sclstm/**/sclstm_usr.res 30 | convlab2/nlg/sclstm/**/sclstm_usr.log 31 | convlab2/nlu/jointBERT/**/output/ 32 | convlab2/dst/sumbt/multiwoz/output/ 33 | convlab2/nlg/sclstm/**/generated_sens_sys.json 34 | convlab2/nlg/template/**/generated_sens_sys.json 35 | # test script 36 | *_test.py 37 | 38 | # log 39 | **/log/** 40 | *.log 41 | 42 | # save 43 | **/save/** 44 | 45 | # .bak.py 46 | *.bak.py 47 | 48 | # compile files 49 | build 50 | dist 51 | convlab2.egg-info 52 | 53 | # configs 54 | 55 | 56 | .ipynb_checkpoints 57 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE.txt 2 | include README.md 3 | prune convlab2/*/__pycache__ 4 | -------------------------------------------------------------------------------- /PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | **Description:** 2 | 3 | 4 | 5 | **Reference Issues:** #XX (XX is the issue number you work on) 6 | -------------------------------------------------------------------------------- /convlab2/__init__.py: 
-------------------------------------------------------------------------------- 1 | from convlab2.nlu import NLU 2 | from convlab2.dst import DST 3 | from convlab2.policy import Policy 4 | from convlab2.nlg import NLG 5 | from convlab2.dialog_agent import Agent, PipelineAgent 6 | from convlab2.dialog_agent import Session, BiSession, DealornotSession 7 | 8 | from os.path import abspath, dirname 9 | 10 | 11 | def get_root_path(): 12 | return dirname(dirname(abspath(__file__))) 13 | -------------------------------------------------------------------------------- /convlab2/dialog_agent/__init__.py: -------------------------------------------------------------------------------- 1 | from convlab2.dialog_agent.agent import Agent, PipelineAgent 2 | from convlab2.dialog_agent.session import Session, BiSession, DealornotSession 3 | 4 | __all__ = ['Agent', 'PipelineAgent', 'Session', 'BiSession', 'DealornotSession'] -------------------------------------------------------------------------------- /convlab2/dialog_agent/env.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jul 17 14:27:34 2019 4 | 5 | @author: truthless 6 | """ 7 | 8 | class Environment(): 9 | 10 | def __init__(self, sys_nlg, usr, sys_nlu, sys_dst): 11 | self.sys_nlg = sys_nlg 12 | self.usr = usr 13 | self.sys_nlu = sys_nlu 14 | self.sys_dst = sys_dst 15 | 16 | def reset(self): 17 | self.usr.init_session() 18 | self.sys_dst.init_session() 19 | return self.sys_dst.state 20 | 21 | def step(self, action): 22 | model_response = self.sys_nlg.generate(action) if self.sys_nlg else action 23 | observation = self.usr.response(model_response) 24 | dialog_act = self.sys_nlu.predict(observation) if self.sys_nlu else observation 25 | state = self.sys_dst.update(dialog_act) 26 | 27 | reward = self.usr.get_reward() 28 | terminated = self.usr.is_terminated() 29 | 30 | return state, reward, terminated 31 | -------------------------------------------------------------------------------- /convlab2/dst/__init__.py: -------------------------------------------------------------------------------- 1 | from convlab2.dst.dst import DST 2 | -------------------------------------------------------------------------------- /convlab2/dst/dst.py: -------------------------------------------------------------------------------- 1 | """Dialog State Tracker Interface""" 2 | from convlab2.util.module import Module 3 | 4 | 5 | class DST(Module): 6 | """Base class for dialog state tracker models.""" 7 | 8 | def update(self, action): 9 | """ Update the internal dialog state variable. 10 | Updates state['user_action'] with the input action. 11 | 12 | Args: 13 | action (str or list of tuples): 14 | The type is str when the DST is word-level (such as NBT), and list of tuples when it is DA-level. 15 | Returns: 16 | new_state (dict): 17 | Updated dialog state, in the same form as the previous state.
18 | """ 19 | pass 20 | -------------------------------------------------------------------------------- /convlab2/dst/rule/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/dst/rule/__init__.py -------------------------------------------------------------------------------- /convlab2/dst/rule/crosswoz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/dst/rule/crosswoz/__init__.py -------------------------------------------------------------------------------- /convlab2/dst/rule/crosswoz/dst.py: -------------------------------------------------------------------------------- 1 | from convlab2.dst.dst import DST 2 | from convlab2.util.crosswoz.state import default_state 3 | from convlab2.util.crosswoz.dbquery import Database 4 | from copy import deepcopy 5 | from collections import Counter 6 | from pprint import pprint 7 | 8 | 9 | class RuleDST(DST): 10 | """Rule based DST which trivially updates new values from NLU result to states. 11 | 12 | Attributes: 13 | state(dict): 14 | Dialog state. Function ``convlab2.util.crosswoz.state.default_state`` returns a default state. 15 | """ 16 | def __init__(self): 17 | super().__init__() 18 | self.state = default_state() 19 | self.database = Database() 20 | 21 | def init_session(self, state=None): 22 | """Initialize ``self.state`` with a default state, which ``convlab2.util.crosswoz.state.default_state`` returns.""" 23 | self.state = default_state() if not state else deepcopy(state) 24 | 25 | def update(self, usr_da=None): 26 | """ 27 | update belief_state, cur_domain, request_slot 28 | :param usr_da: 29 | :return: 30 | """ 31 | self.state['user_action'] = usr_da 32 | sys_da = self.state['system_action'] 33 | 34 | select_domains = Counter([x[1] for x in usr_da if x[0] == 'Select']) 35 | request_domains = Counter([x[1] for x in usr_da if x[0] == 'Request']) 36 | inform_domains = Counter([x[1] for x in usr_da if x[0] == 'Inform']) 37 | sys_domains = Counter([x[1] for x in sys_da if x[0] in ['Inform', 'Recommend']]) 38 | if len(select_domains) > 0: 39 | self.state['cur_domain'] = select_domains.most_common(1)[0][0] 40 | elif len(request_domains) > 0: 41 | self.state['cur_domain'] = request_domains.most_common(1)[0][0] 42 | elif len(inform_domains) > 0: 43 | self.state['cur_domain'] = inform_domains.most_common(1)[0][0] 44 | elif len(sys_domains) > 0: 45 | self.state['cur_domain'] = sys_domains.most_common(1)[0][0] 46 | else: 47 | self.state['cur_domain'] = None 48 | 49 | # print('cur_domain', self.cur_domain) 50 | 51 | NoOffer = 'NoOffer' in [x[0] for x in sys_da] and 'Inform' not in [x[0] for x in sys_da] 52 | # DONE: clean cur domain constraints because nooffer 53 | 54 | if NoOffer: 55 | if self.state['cur_domain']: 56 | self.state['belief_state'][self.state['cur_domain']] = deepcopy(default_state()['belief_state'][self.state['cur_domain']]) 57 | 58 | # DONE: clean request slot 59 | for domain, slot in deepcopy(self.state['request_slots']): 60 | if [domain, slot] in [x[1:3] for x in sys_da if x[0] in ['Inform', 'Recommend']]: 61 | self.state['request_slots'].remove([domain, slot]) 62 | 63 | # DONE: domain switch 64 | for intent, domain, slot, value in usr_da: 65 | if intent == 'Select': 66 | from_domain = value 67 | name = 
self.state['belief_state'][from_domain]['名称'] 68 | if name: 69 | if domain == from_domain: 70 | self.state['belief_state'][domain] = deepcopy(default_state()['belief_state'][domain]) 71 | self.state['belief_state'][domain]['周边{}'.format(from_domain)] = name 72 | 73 | for intent, domain, slot, value in usr_da: 74 | if intent == 'Inform': 75 | if slot in ['名称', '游玩时间', '酒店类型', '出发地', '目的地', '评分', '门票', '价格', '人均消费']: 76 | self.state['belief_state'][domain][slot] = value 77 | elif slot == '推荐菜': 78 | if not self.state['belief_state'][domain][slot]: 79 | self.state['belief_state'][domain][slot] = value 80 | else: 81 | self.state['belief_state'][domain][slot] += ' ' + value 82 | elif '酒店设施' in slot: 83 | if value == '是': 84 | faci = slot.split('-')[1] 85 | if not self.state['belief_state'][domain]['酒店设施']: 86 | self.state['belief_state'][domain]['酒店设施'] = faci 87 | else: 88 | self.state['belief_state'][domain]['酒店设施'] += ' ' + faci 89 | elif intent == 'Request': 90 | self.state['request_slots'].append([domain, slot]) 91 | 92 | return self.state 93 | 94 | def query(self): 95 | return self.database.query(self.state['belief_state'], self.state['cur_domain']) 96 | 97 | 98 | if __name__ == '__main__': 99 | dst = RuleDST() 100 | dst.init_session() 101 | pprint(dst.state) 102 | dst.update([['Inform', '酒店', '评分', '4分以上'],['Request', '酒店', '地址', '']]) 103 | pprint(dst.state) 104 | # pprint(dst.query()) 105 | -------------------------------------------------------------------------------- /convlab2/dst/rule/crosswoz/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import zipfile 3 | from collections import Counter 4 | from pprint import pprint 5 | from convlab2.dst.rule.crosswoz.dst import RuleDST 6 | from convlab2.util.crosswoz.state import default_state 7 | from copy import deepcopy 8 | 9 | 10 | def calculateJointState(predict_golden): 11 | res = [] 12 | for item in predict_golden: 13 | predicts = item['predict'] 14 | labels = item['golden'] 15 | res.append(predicts==labels) 16 | return sum(res) / len(res) if len(res) else 0. 17 | 18 | 19 | def calculateSlotState(predict_golden): 20 | res = [] 21 | for item in predict_golden: 22 | predicts = item['predict'] 23 | labels = item['golden'] 24 | for x, y in zip(predicts, labels): 25 | for w, z in zip(predicts[x].values(),labels[y].values()): 26 | res.append(w==z) 27 | return sum(res) / len(res) if len(res) else 0. 
28 | 29 | 30 | def read_zipped_json(filepath, filename): 31 | archive = zipfile.ZipFile(filepath, 'r') 32 | return json.load(archive.open(filename)) 33 | 34 | 35 | def test_sys_state(data, goal_type): 36 | ruleDST = RuleDST() 37 | state_predict_golden = [] 38 | for task_id, item in data.items(): 39 | if goal_type and item['type']!=goal_type: 40 | continue 41 | ruleDST.init_session() 42 | for i, turn in enumerate(item['messages']): 43 | if turn['role'] == 'sys': 44 | usr_da = item['messages'][i - 1]['dialog_act'] 45 | if i > 2: 46 | for domain, svs in item['messages'][i - 2]['sys_state'].items(): 47 | for slot, value in svs.items(): 48 | if slot != 'selectedResults': 49 | ruleDST.state['belief_state'][domain][slot] = value 50 | ruleDST.update(usr_da) 51 | new_state = deepcopy(ruleDST.state['belief_state']) 52 | golden_state = deepcopy(turn['sys_state_init']) 53 | for x in golden_state: 54 | golden_state[x].pop('selectedResults') 55 | state_predict_golden.append({ 56 | 'predict': new_state, 57 | 'golden': golden_state 58 | }) 59 | print('joint state', calculateJointState(state_predict_golden)) 60 | print('slot state', calculateSlotState(state_predict_golden)) 61 | 62 | 63 | if __name__ == '__main__': 64 | test_data_path = '../../../../data/crosswoz/test.json.zip' 65 | test_data = read_zipped_json(test_data_path, 'test.json') 66 | for goal_type in ['单领域', '独立多领域', '独立多领域+交通', '不独立多领域', '不独立多领域+交通', None]: 67 | print(goal_type) 68 | test_sys_state(test_data, goal_type=goal_type) 69 | -------------------------------------------------------------------------------- /convlab2/dst/trade/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/dst/trade/__init__.py -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/EWC_train.py: -------------------------------------------------------------------------------- 1 | from convlab2.dst.trade.crosswoz.utils.config import * 2 | from convlab2.dst.trade.crosswoz.models.TRADE import * 3 | from torch import autograd 4 | from copy import deepcopy 5 | import pickle 6 | import os.path 7 | 8 | 9 | #### LOAD MODEL path 10 | except_domain = args['except_domain'] 11 | directory = args['path'].split("/") 12 | HDD = directory[2].split('HDD')[1].split('BSZ')[0] 13 | # decoder = directory[1].split('-')[0] 14 | BSZ = int(args['batch']) if args['batch'] else int(directory[2].split('BSZ')[1].split('DR')[0]) 15 | args["decoder"] = "TRADE" 16 | args["HDD"] = HDD 17 | 18 | if args['dataset']=='multiwoz': 19 | from convlab2.dst.trade.crosswoz.utils.utils_multiWOZ_DST import * 20 | else: 21 | print("You need to provide the --dataset information") 22 | 23 | 24 | filename_fisher = args['path']+"fisher{}".format(args["fisher_sample"]) 25 | 26 | if(os.path.isfile(filename_fisher) ): 27 | print("Load Fisher Matrix" + filename_fisher) 28 | [fisher,optpar] = pickle.load(open(filename_fisher,'rb')) 29 | else: 30 | train, dev, test, test_special, lang, SLOTS_LIST, gating_dict, max_word = prepare_data_seq(True, args['task'], False, batch_size=1) 31 | model = globals()[args["decoder"]]( 32 | int(HDD), 33 | lang=lang, 34 | path=args['path'], 35 | task=args["task"], 36 | lr=args["learn"], 37 | dropout=args["drop"], 38 | slots=SLOTS_LIST, 39 | gating_dict=gating_dict) 40 | print("Computing Fisher Matrix ") 41 | fisher = {} 42 | optpar = {} 43 | for n, p in model.named_parameters(): 44 | optpar[n] = 
torch.Tensor(p.cpu().data).cuda() 45 | p.data.zero_() 46 | fisher[n] = torch.Tensor(p.cpu().data).cuda() 47 | 48 | pbar = tqdm(enumerate(train),total=len(train)) 49 | for i, data_o in pbar: 50 | model.train_batch(data_o, int(args['clip']), SLOTS_LIST[1], reset=(i==0)) 51 | model.loss_ptr_to_bp.backward() 52 | for n, p in model.named_parameters(): 53 | if p.grad is not None: 54 | fisher[n].data += p.grad.data ** 2 55 | if(i == args["fisher_sample"]):break 56 | 57 | for name_f,_ in fisher.items():#range(len(fisher)): 58 | fisher[name_f] /= args["fisher_sample"] #len(train) 59 | print("Saving Fisher Matrix in ", filename_fisher) 60 | pickle.dump([fisher,optpar],open(filename_fisher,'wb')) 61 | exit(0) 62 | 63 | 64 | ### LOAD DATA 65 | train, dev, test, test_special, lang, SLOTS_LIST, gating_dict, max_word = prepare_data_seq(True, args['task'], False, batch_size=BSZ) 66 | 67 | args['only_domain'] = except_domain 68 | args['except_domain'] = '' 69 | args["fisher_sample"] = 0 70 | args["data_ratio"] = 1 71 | train_single, dev_single, test_single, _, _, SLOTS_LIST_single, _, _ = prepare_data_seq(True, args['task'], False, batch_size=BSZ) 72 | args['except_domain'] = except_domain 73 | 74 | 75 | #### LOAD MODEL 76 | model = globals()[args["decoder"]]( 77 | int(HDD), 78 | lang=lang, 79 | path=args['path'], 80 | task=args["task"], 81 | lr=args["learn"], 82 | dropout=args["drop"], 83 | slots=SLOTS_LIST, 84 | gating_dict=gating_dict) 85 | 86 | avg_best, cnt, acc = 0.0, 0, 0.0 87 | weights_best = deepcopy(model.state_dict()) 88 | try: 89 | for epoch in range(100): 90 | print("Epoch:{}".format(epoch)) 91 | # Run the train function 92 | pbar = tqdm(enumerate(train_single),total=len(train_single)) 93 | for i, data in pbar: 94 | model.train_batch(data, int(args['clip']), SLOTS_LIST_single[1], reset=(i==0)) 95 | 96 | ### EWC loss 97 | for i, (name,p) in enumerate(model.named_parameters()): 98 | if p.grad is not None: 99 | l = args['lambda_ewc'] * fisher[name].cuda() * (p - optpar[name].cuda()).pow(2) 100 | model.loss_grad += l.sum() 101 | model.optimize(args['clip']) 102 | pbar.set_description(model.print_loss()) 103 | 104 | 105 | if((epoch+1) % int(args['evalp']) == 0): 106 | acc = model.evaluate(dev_single, avg_best, SLOTS_LIST_single[2], args["earlyStop"]) 107 | model.scheduler.step(acc) 108 | if(acc >= avg_best): 109 | avg_best = acc 110 | cnt=0 111 | weights_best = deepcopy(model.state_dict()) 112 | else: 113 | cnt+=1 114 | if(cnt == 6 or (acc==1.0 and args["earlyStop"]==None)): 115 | print("Ran out of patient, early stop...") 116 | break 117 | except KeyboardInterrupt: 118 | pass 119 | 120 | 121 | model.load_state_dict({ name: weights_best[name] for name in weights_best }) 122 | model.eval() 123 | 124 | # After Fine tuning... 125 | print("[Info] After Fine Tune ...") 126 | print("[Info] Test Set on 4 domains...") 127 | acc_test_4d = model.evaluate(test_special, 1e7, SLOTS_LIST[2]) 128 | print("[Info] Test Set on 1 domain {} ...".format(except_domain)) 129 | acc_test = model.evaluate(test_single, 1e7, SLOTS_LIST[3]) 130 | 131 | 132 | -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/README.md: -------------------------------------------------------------------------------- 1 | # TRADE 2 | This is the implementation of TRADE (Transferable Dialogue State Generator) adopted from [jasonwu0731/trade-dst](https://github.com/jasonwu0731/trade-dst) 3 | on the CrossWOZ dataset. 
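The tracker can also be driven directly from Python. The following is a condensed sketch of the `demo.py` script shipped with this module (the checkpoint path is the one `demo.py` uses; the model and data are downloaded automatically at runtime):

```python
from convlab2.dst.trade.crosswoz.trade import CrossWOZTRADE

# Pretrained checkpoint referenced by demo.py (downloaded on first use).
model = CrossWOZTRADE('model/TRADE-multiwozdst/HDD100BSZ4DR0.2ACC-0.3228')

# TRADE is word-level: it consumes the (whitespace-tokenized) user utterance.
user_act = '你好 , 我想 找家 人均 消费 在 100 - 150 元 的 餐馆 吃 驴 杂汤 这 道菜 , 请 给 我 推荐 一家 餐馆 用餐 吧 。'
model.state['history'] = [['user', user_act]]
state = model.update(user_act)  # returns the updated dialog state
print(state)
```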
4 | 5 | 6 | ## Example usage 7 | To run an example, you can jump into convlab2/dst/trade/crosswoz, and run the following command: 8 | ```bash 9 | $ python demo.py 10 | ``` 11 | The path in the example is our proposed pre-trained model of TRADE, which will 12 | be downloaded automatically at runtime. 13 | The data required for running the model will also be downloaded at runtime. 14 | You can also run your own model by specifying the path parameter. 15 | 16 | ## Train 17 | To train a model from scratch, jump into convlab2/dst/trade/crosswoz, and run the following command: 18 | ```bash 19 | $ python train.py 20 | ``` 21 | Note that the training data will be downloaded automatically. 22 | 23 | ## Evaluation 24 | To evaluate the model on the test set of CrossWOZ, you can jump into convlab2/dst/trade/crosswoz, and then run the following command: 25 | ```bash 26 | $ python evaluate.py 27 | ``` 28 | The evaluation results, including Joint Accuracy, Turn Accuracy, and Joint F1 on the test set, will be shown. 29 | 30 | ## References 31 | ``` 32 | @InProceedings{WuTradeDST2019, 33 | author = "Wu, Chien-Sheng and Madotto, Andrea and Hosseini-Asl, Ehsan and Xiong, Caiming and Socher, Richard and Fung, Pascale", 34 | title = "Transferable Multi-Domain State Generator for Task-Oriented Dialogue Systems", 35 | booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 36 | year = "2019", 37 | publisher = "Association for Computational Linguistics" 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/cnembedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, dirname, abspath 3 | import random 4 | 5 | root_path = dirname(abspath(__file__)) 6 | 7 | class CNEmbedding: 8 | def __init__(self): 9 | vector_path = join(root_path, 'data', 'crosswoz', 'vector.txt') 10 | self.word2vec = {} 11 | with open(vector_path) as fin: 12 | lines = fin.readlines()[1:] 13 | for line in lines: 14 | line = line.strip() 15 | tokens = line.split(' ') 16 | word = tokens[0] 17 | vec = tokens[1:] 18 | vec = [float(item) for item in vec] 19 | self.word2vec[word] = vec 20 | self.embed_size = 100 21 | 22 | 23 | def emb(self, token, default='zero'): 24 | get_default = { 25 | 'none': lambda: None, 26 | 'zero': lambda: 0., 27 | 'random': lambda: random.uniform(-0.1, 0.1), 28 | }[default] 29 | vec = self.word2vec.get(token, None) 30 | if vec is None: 31 | vec = [get_default()] * self.embed_size 32 | return vec -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | 4 | from convlab2.dst.trade.crosswoz.utils.config import * 5 | from convlab2.dst.trade.crosswoz.trade import * 6 | 7 | ''' 8 | python demo.py 9 | ''' 10 | 11 | # specify model path 12 | args['path'] = 'model/TRADE-multiwozdst/HDD100BSZ4DR0.2ACC-0.3228' 13 | model = CrossWOZTRADE(args['path']) 14 | 15 | 16 | user_act = '你好 , 我想 找家 人均 消费 在 100 - 150 元 的 餐馆 吃 驴 杂汤 这 道菜 , 请 给 我 推荐 一家 餐馆 用餐 吧 。' 17 | model.state['history'] = [['user', user_act]] 18 | state = model.update(user_act) 19 | print(state) -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/evaluate.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Evaluate the TRADE DST model on the CrossWOZ test set 3 | Usage: python evaluate.py 4 | """ 5 | import random 6 | import numpy 7 | import torch 8 | from convlab2.dst.trade.crosswoz.trade import CrossWOZTRADE 9 | 10 | 11 | def format_history(context): 12 | history = [] 13 | for i in range(len(context)): 14 | history.append(['system' if i%2==1 else 'user', context[i]]) 15 | return history 16 | 17 | if __name__ == '__main__': 18 | seed = 2020 19 | random.seed(seed) 20 | numpy.random.seed(seed) 21 | torch.manual_seed(seed) 22 | 23 | model = CrossWOZTRADE('model/TRADE-multiwozdst/HDD100BSZ4DR0.2ACC-0.3228') 24 | model.evaluate() -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/fine_tune.py: -------------------------------------------------------------------------------- 1 | from convlab2.dst.trade.crosswoz.utils.config import * 2 | from convlab2.dst.trade.crosswoz.models.TRADE import * 3 | from copy import deepcopy 4 | 5 | 6 | except_domain = args['except_domain'] 7 | directory = args['path'].split("/") 8 | HDD = directory[2].split('HDD')[1].split('BSZ')[0] 9 | BSZ = int(args['batch']) if args['batch'] else int(directory[2].split('BSZ')[1].split('DR')[0]) 10 | args["decoder"] = "TRADE" 11 | args["HDD"] = HDD 12 | 13 | if args['dataset']=='multiwoz': 14 | from convlab2.dst.trade.crosswoz.utils.utils_multiWOZ_DST import * 15 | else: 16 | print("You need to provide the --dataset information") 17 | 18 | train, dev, test, test_special, lang, SLOTS_LIST, gating_dict, max_word = prepare_data_seq(True, args['task'], False, batch_size=BSZ) 19 | 20 | args['only_domain'] = except_domain 21 | args['except_domain'] = '' 22 | args["data_ratio"] = 1 23 | train_single, dev_single, test_single, _, _, SLOTS_LIST_single, _, _ = prepare_data_seq(True, args['task'], False, batch_size=BSZ) 24 | args['except_domain'] = except_domain 25 | 26 | model = globals()[args["decoder"]]( 27 | int(HDD), 28 | lang=lang, 29 | path=args['path'], 30 | task=args["task"], 31 | lr=args["learn"], 32 | dropout=args["drop"], 33 | slots=SLOTS_LIST, 34 | gating_dict=gating_dict) 35 | 36 | avg_best, cnt, acc = 0.0, 0, 0.0 37 | weights_best = deepcopy(model.state_dict()) 38 | 39 | try: 40 | for epoch in range(100): 41 | print("Epoch:{}".format(epoch)) 42 | # Run the train function 43 | pbar = tqdm(enumerate(train_single),total=len(train_single)) 44 | for i, data in pbar: 45 | 46 | model.train_batch(data, int(args['clip']), SLOTS_LIST_single[1], reset=(i==0)) 47 | model.optimize(args['clip']) 48 | pbar.set_description(model.print_loss()) 49 | 50 | if((epoch+1) % int(args['evalp']) == 0): 51 | acc = model.evaluate(dev_single, avg_best, SLOTS_LIST_single[2], args["earlyStop"]) 52 | model.scheduler.step(acc) 53 | if(acc > avg_best): 54 | avg_best = acc 55 | cnt=0 56 | weights_best = deepcopy(model.state_dict()) 57 | else: 58 | cnt+=1 59 | if(cnt == 6 or (acc==1.0 and args["earlyStop"]==None)): 60 | print("Ran out of patience, early stop...") 61 | break 62 | except KeyboardInterrupt: 63 | pass 64 | 65 | model.load_state_dict({ name: weights_best[name] for name in weights_best }) 66 | model.eval() 67 | 68 | # After Fine tuning...
69 | print("[Info] After Fine Tune ...") 70 | print("[Info] Test Set on 4 domains...") 71 | acc_test_4d = model.evaluate(test_special, 1e7, SLOTS_LIST[2]) 72 | print("[Info] Test Set on 1 domain {} ...".format(except_domain)) 73 | acc_test = model.evaluate(test_single, 1e7, SLOTS_LIST[3]) 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/train.py: -------------------------------------------------------------------------------- 1 | # specify cuda id 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 4 | 5 | from convlab2.dst.trade.crosswoz.utils.config import MODE 6 | from tqdm import tqdm 7 | import torch.nn as nn 8 | import shutil, zipfile 9 | from convlab2.util.file_util import cached_path 10 | 11 | from convlab2.dst.trade.crosswoz.utils.config import * 12 | from convlab2.dst.trade.crosswoz.models.TRADE import * 13 | 14 | ''' 15 | python train.py 16 | ''' 17 | 18 | 19 | def download_data(data_url="https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/trade_crosswoz_data.zip"): 20 | """Automatically download the pretrained model and necessary data.""" 21 | crosswoz_root = os.path.dirname(os.path.abspath(__file__)) 22 | if os.path.exists(os.path.join(crosswoz_root, 'data/crosswoz')) and \ 23 | os.path.exists(os.path.join(crosswoz_root, 'data/dev_dials.json')): 24 | return 25 | data_dir = os.path.join(crosswoz_root, 'data') 26 | if not os.path.exists(data_dir): 27 | os.mkdir(data_dir) 28 | zip_file_path = os.path.join(data_dir, 'trade_crosswoz_data.zip') 29 | if not os.path.exists(os.path.join(data_dir, 'trade_crosswoz_data.zip')): 30 | print('downloading crosswoz TRADE data files...') 31 | cached_path(data_url, data_dir) 32 | files = os.listdir(data_dir) 33 | target_file = '' 34 | for name in files: 35 | if name.endswith('.json'): 36 | target_file = name[:-5] 37 | try: 38 | assert target_file in files 39 | except Exception as e: 40 | print('allennlp download file error: TRADE Cross model download failed.') 41 | raise e 42 | shutil.copyfile(os.path.join(data_dir, target_file), zip_file_path) 43 | with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: 44 | print('unzipping data file ...') 45 | zip_ref.extractall(data_dir) 46 | 47 | early_stop = args['earlyStop'] 48 | 49 | if args['dataset']=='multiwoz': 50 | from convlab2.dst.trade.crosswoz.utils.utils_multiWOZ_DST import * 51 | early_stop = None 52 | else: 53 | print("You need to provide the --dataset information") 54 | exit(1) 55 | 56 | # specify model parameters 57 | args['decoder'] = 'TRADE' 58 | args['batch'] = 4 59 | args['drop'] = 0.2 60 | args['learn'] = 0.001 61 | args['load_embedding'] = 1 62 | 63 | # Configure models and load data 64 | avg_best, cnt, acc = 0.0, 0, 0.0 65 | download_data() 66 | train, dev, test, test_special, lang, SLOTS_LIST, gating_dict, max_word = prepare_data_seq_cn(True, args['task'], 67 | False, batch_size=int(args['batch'])) 68 | 69 | model = globals()[args['decoder']]( 70 | hidden_size=int(args['hidden']), 71 | lang=lang, 72 | path=args['path'], 73 | task=args['task'], 74 | lr=float(args['learn']), 75 | dropout=float(args['drop']), 76 | slots=SLOTS_LIST, 77 | gating_dict=gating_dict, 78 | nb_train_vocab=max_word, 79 | mode=MODE) 80 | 81 | # print("[Info] Slots include ", SLOTS_LIST) 82 | # print("[Info] Unpointable Slots include ", gating_dict) 83 | 84 | for epoch in range(200): 85 | print("Epoch:{}".format(epoch)) 86 | # Run the train function 87 | pbar = tqdm(enumerate(train),total=len(train)) 88 | for i, data in 
pbar: 89 | ## only part data to train 90 | # if MODE == 'cn' and i >= 1400: break 91 | model.train_batch(data, int(args['clip']), SLOTS_LIST[1], reset=(i==0)) 92 | model.optimize(args['clip']) 93 | pbar.set_description(model.print_loss()) 94 | # print(data) 95 | # exit(1) 96 | 97 | if((epoch+1) % int(args['evalp']) == 0): 98 | 99 | acc = model.evaluate(dev, avg_best, SLOTS_LIST[2], early_stop) 100 | model.scheduler.step(acc) 101 | 102 | if(acc >= avg_best): 103 | avg_best = acc 104 | cnt=0 105 | best_model = model 106 | else: 107 | cnt+=1 108 | 109 | if(cnt == args["patience"] or (acc==1.0 and early_stop==None)): 110 | print("Ran out of patient, early stop...") 111 | break 112 | 113 | -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import argparse 4 | from tqdm import tqdm 5 | import torch 6 | 7 | PAD_token = 1 8 | SOS_token = 3 9 | EOS_token = 2 10 | UNK_token = 0 11 | 12 | MODE = 'cn' 13 | data_version = 'init' # processed 14 | 15 | if torch.cuda.is_available(): 16 | USE_CUDA = True 17 | else: 18 | USE_CUDA = False 19 | 20 | MAX_LENGTH = 10 21 | 22 | parser = argparse.ArgumentParser(description='TRADE Multi-Domain DST') 23 | 24 | # Training Setting 25 | parser.add_argument('-ds','--dataset', help='dataset', required=False, default="multiwoz") 26 | parser.add_argument('-t','--task', help='Task Number', required=False, default="dst") 27 | parser.add_argument('-path','--path', help='path of the file to load', required=False) 28 | parser.add_argument('-sample','--sample', help='Number of Samples', required=False,default=None) 29 | parser.add_argument('-patience','--patience', help='', required=False, default=6, type=int) 30 | parser.add_argument('-es','--earlyStop', help='Early Stop Criteria, BLEU or ENTF1', required=False, default='BLEU') 31 | parser.add_argument('-all_vocab','--all_vocab', help='', required=False, default=1, type=int) 32 | parser.add_argument('-imbsamp','--imbalance_sampler', help='', required=False, default=0, type=int) 33 | parser.add_argument('-data_ratio','--data_ratio', help='', required=False, default=100, type=int) 34 | parser.add_argument('-um','--unk_mask', help='mask out input token to UNK', type=int, required=False, default=1) 35 | parser.add_argument('-bsz','--batch', help='Batch_size', required=False, type=int) 36 | 37 | # Testing Setting 38 | parser.add_argument('-rundev','--run_dev_testing', help='', required=False, default=0, type=int) 39 | parser.add_argument('-viz','--vizualization', help='vizualization', type=int, required=False, default=0) 40 | ## model predictions 41 | parser.add_argument('-gs','--genSample', help='Generate Sample', type=int, required=False, default=0) #### change this when testing 42 | parser.add_argument('-evalp','--evalp', help='evaluation period', required=False, default=1) 43 | parser.add_argument('-an','--addName', help='An add name for the model folder', required=False, default='') 44 | parser.add_argument('-eb','--eval_batch', help='Evaluation Batch_size', required=False, type=int, default=0) 45 | 46 | # Model architecture 47 | parser.add_argument('-gate','--use_gate', help='', required=False, default=1, type=int) 48 | parser.add_argument('-le','--load_embedding', help='', required=False, default=0, type=int) 49 | parser.add_argument('-femb','--fix_embedding', help='', required=False, default=0, type=int) 50 | 
parser.add_argument('-paral','--parallel_decode', help='', required=False, default=0, type=int) 51 | 52 | # Model Hyper-Parameters 53 | parser.add_argument('-dec','--decoder', help='decoder model', required=False) 54 | parser.add_argument('-hdd','--hidden', help='Hidden size', required=False, type=int, default=100) 55 | parser.add_argument('-lr','--learn', help='Learning Rate', required=False, type=float) 56 | parser.add_argument('-dr','--drop', help='Drop Out', required=False, type=float) 57 | parser.add_argument('-lm','--limit', help='Word Limit', required=False,default=-10000) 58 | parser.add_argument('-clip','--clip', help='gradient clipping', required=False, default=10, type=int) 59 | parser.add_argument('-tfr','--teacher_forcing_ratio', help='teacher_forcing_ratio', type=float, required=False, default=0.5) 60 | # parser.add_argument('-l','--layer', help='Layer Number', required=False) 61 | 62 | # Unseen Domain Setting 63 | parser.add_argument('-l_ewc','--lambda_ewc', help='regularization term for EWC loss', type=float, required=False, default=0.01) 64 | parser.add_argument('-fisher_sample','--fisher_sample', help='number of sample used to approximate fisher mat', type=int, required=False, default=0) 65 | parser.add_argument("--all_model", action="store_true") 66 | parser.add_argument("--domain_as_task", action="store_true") 67 | parser.add_argument('--run_except_4d', help='', required=False, default=1, type=int) 68 | parser.add_argument("--strict_domain", action="store_true") 69 | parser.add_argument('-exceptd','--except_domain', help='', required=False, default="", type=str) 70 | parser.add_argument('-onlyd','--only_domain', help='', required=False, default="", type=str) 71 | 72 | 73 | args = vars(parser.parse_known_args(args=[])[0]) 74 | if args["load_embedding"]: 75 | args["hidden"] = 100 76 | if args["fix_embedding"]: 77 | args["addName"] += "FixEmb" 78 | if args["except_domain"] != "": 79 | args["addName"] += "Except"+args["except_domain"] 80 | if args["only_domain"] != "": 81 | args["addName"] += "Only"+args["only_domain"] 82 | 83 | -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/utils/fix_label.py: -------------------------------------------------------------------------------- 1 | 2 | def fix_general_label_error(labels, type, slots): 3 | label_dict = dict([ (l[0], l[1]) for l in labels]) if type else dict([ (l["slots"][0][0], l["slots"][0][1]) for l in labels]) 4 | 5 | GENERAL_TYPO = { 6 | # type 7 | "guesthouse":"guest house", "guesthouses":"guest house", "guest":"guest house", "mutiple sports":"multiple sports", 8 | "sports":"multiple sports", "mutliple sports":"multiple sports","swimmingpool":"swimming pool", "concerthall":"concert hall", 9 | "concert":"concert hall", "pool":"swimming pool", "night club":"nightclub", "mus":"museum", "ol":"architecture", 10 | "colleges":"college", "coll":"college", "architectural":"architecture", "musuem":"museum", "churches":"church", 11 | # area 12 | "center":"centre", "center of town":"centre", "near city center":"centre", "in the north":"north", "cen":"centre", "east side":"east", 13 | "east area":"east", "west part of town":"west", "ce":"centre", "town center":"centre", "centre of cambridge":"centre", 14 | "city center":"centre", "the south":"south", "scentre":"centre", "town centre":"centre", "in town":"centre", "north part of town":"north", 15 | "centre of town":"centre", "cb30aq": "none", 16 | # price 17 | "mode":"moderate", "moderate -ly": "moderate", "mo":"moderate", 18 
| # day 19 | "next friday":"friday", "monda": "monday", 20 | # parking 21 | "free parking":"free", 22 | # internet 23 | "free internet":"yes", 24 | # star 25 | "4 star":"4", "4 stars":"4", "0 star rarting":"none", 26 | # others 27 | "y":"yes", "any":"dontcare", "n":"no", "does not care":"dontcare", "not men":"none", "not":"none", "not mentioned":"none", 28 | '':"none", "not mendtioned":"none", "3 .":"3", "does not":"no", "fun":"none", "art":"none", 29 | } 30 | 31 | for slot in slots: 32 | if slot in label_dict.keys(): 33 | # general typos 34 | if label_dict[slot] in GENERAL_TYPO.keys(): 35 | label_dict[slot] = label_dict[slot].replace(label_dict[slot], GENERAL_TYPO[label_dict[slot]]) 36 | 37 | # miss match slot and value 38 | if slot == "hotel-type" and label_dict[slot] in ["nigh", "moderate -ly priced", "bed and breakfast", "centre", "venetian", "intern", "a cheap -er hotel"] or \ 39 | slot == "hotel-internet" and label_dict[slot] == "4" or \ 40 | slot == "hotel-pricerange" and label_dict[slot] == "2" or \ 41 | slot == "attraction-type" and label_dict[slot] in ["gastropub", "la raza", "galleria", "gallery", "science", "m"] or \ 42 | "area" in slot and label_dict[slot] in ["moderate"] or \ 43 | "day" in slot and label_dict[slot] == "t": 44 | label_dict[slot] = "none" 45 | elif slot == "hotel-type" and label_dict[slot] in ["hotel with free parking and free wifi", "4", "3 star hotel"]: 46 | label_dict[slot] = "hotel" 47 | elif slot == "hotel-star" and label_dict[slot] == "3 star hotel": 48 | label_dict[slot] = "3" 49 | elif "area" in slot: 50 | if label_dict[slot] == "no": label_dict[slot] = "north" 51 | elif label_dict[slot] == "we": label_dict[slot] = "west" 52 | elif label_dict[slot] == "cent": label_dict[slot] = "centre" 53 | elif "day" in slot: 54 | if label_dict[slot] == "we": label_dict[slot] = "wednesday" 55 | elif label_dict[slot] == "no": label_dict[slot] = "none" 56 | elif "price" in slot and label_dict[slot] == "ch": 57 | label_dict[slot] = "cheap" 58 | elif "internet" in slot and label_dict[slot] == "free": 59 | label_dict[slot] = "yes" 60 | 61 | # some out-of-define classification slot values 62 | if slot == "restaurant-area" and label_dict[slot] in ["stansted airport", "cambridge", "silver street"] or \ 63 | slot == "attraction-area" and label_dict[slot] in ["norwich", "ely", "museum", "same area as hotel"]: 64 | label_dict[slot] = "none" 65 | 66 | return label_dict 67 | 68 | -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | import tensorflow as tf 3 | import numpy as np 4 | import scipy.misc 5 | try: 6 | from StringIO import StringIO # Python 2.7 7 | except ImportError: 8 | from io import BytesIO # Python 3.x 9 | 10 | 11 | class Logger(object): 12 | 13 | def __init__(self, log_dir): 14 | """Create a summary writer logging to log_dir.""" 15 | self.writer = tf.summary.FileWriter(log_dir) 16 | 17 | def scalar_summary(self, tag, value, step): 18 | """Log a scalar variable.""" 19 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 20 | self.writer.add_summary(summary, step) 21 | 22 | def image_summary(self, tag, images, step): 23 | """Log a list of images.""" 24 | 25 | img_summaries = [] 26 | for i, img in enumerate(images): 27 | # Write the image to a string 28 | try: 29 | s = StringIO() 30 | except: 31 | 
s = BytesIO() 32 | scipy.misc.toimage(img).save(s, format="png") 33 | 34 | # Create an Image object 35 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 36 | height=img.shape[0], 37 | width=img.shape[1]) 38 | # Create a Summary value 39 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 40 | 41 | # Create and write Summary 42 | summary = tf.Summary(value=img_summaries) 43 | self.writer.add_summary(summary, step) 44 | 45 | def histo_summary(self, tag, values, step, bins=1000): 46 | """Log a histogram of the tensor of values.""" 47 | 48 | # Create a histogram using numpy 49 | counts, bin_edges = np.histogram(values, bins=bins) 50 | 51 | # Fill the fields of the histogram proto 52 | hist = tf.HistogramProto() 53 | hist.min = float(np.min(values)) 54 | hist.max = float(np.max(values)) 55 | hist.num = int(np.prod(values.shape)) 56 | hist.sum = float(np.sum(values)) 57 | hist.sum_squares = float(np.sum(values**2)) 58 | 59 | # Drop the start of the first bin 60 | bin_edges = bin_edges[1:] 61 | 62 | # Add bin edges and counts 63 | for edge in bin_edges: 64 | hist.bucket_limit.append(edge) 65 | for c in counts: 66 | hist.bucket.append(c) 67 | 68 | # Create and write Summary 69 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 70 | self.writer.add_summary(summary, step) 71 | self.writer.flush() -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/utils/mapping.pair: -------------------------------------------------------------------------------- 1 | it's it is 2 | don't do not 3 | doesn't does not 4 | didn't did not 5 | you'd you would 6 | you're you are 7 | you'll you will 8 | i'm i am 9 | they're they are 10 | that's that is 11 | what's what is 12 | couldn't could not 13 | i've i have 14 | we've we have 15 | can't cannot 16 | i'd i would 17 | i'd i would 18 | aren't are not 19 | isn't is not 20 | wasn't was not 21 | weren't were not 22 | won't will not 23 | there's there is 24 | there're there are 25 | . . . 
26 | restaurants restaurant -s 27 | hotels hotel -s 28 | laptops laptop -s 29 | cheaper cheap -er 30 | dinners dinner -s 31 | lunches lunch -s 32 | breakfasts breakfast -s 33 | expensively expensive -ly 34 | moderately moderate -ly 35 | cheaply cheap -ly 36 | prices price -s 37 | places place -s 38 | venues venue -s 39 | ranges range -s 40 | meals meal -s 41 | locations location -s 42 | areas area -s 43 | policies policy -s 44 | children child -s 45 | kids kid -s 46 | kidfriendly kid friendly 47 | cards card -s 48 | upmarket expensive 49 | inpricey cheap 50 | inches inch -s 51 | uses use -s 52 | dimensions dimension -s 53 | driverange drive range 54 | includes include -s 55 | computers computer -s 56 | machines machine -s 57 | families family -s 58 | ratings rating -s 59 | constraints constraint -s 60 | pricerange price range 61 | batteryrating battery rating 62 | requirements requirement -s 63 | drives drive -s 64 | specifications specification -s 65 | weightrange weight range 66 | harddrive hard drive 67 | batterylife battery life 68 | businesses business -s 69 | hours hour -s 70 | one 1 71 | two 2 72 | three 3 73 | four 4 74 | five 5 75 | six 6 76 | seven 7 77 | eight 8 78 | nine 9 79 | ten 10 80 | eleven 11 81 | twelve 12 82 | anywhere any where 83 | good bye goodbye 84 | -------------------------------------------------------------------------------- /convlab2/dst/trade/crosswoz/utils/measures.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import numpy 6 | 7 | import os 8 | import re 9 | import subprocess 10 | import tempfile 11 | import numpy as np 12 | 13 | from six.moves import urllib 14 | 15 | def wer(r, h): 16 | """ 17 | This is a function that calculates the word error rate in ASR. 18 | You can use it like this: wer("what is it".split(), "what is".split()) 19 | """ 20 | #build the matrix 21 | d = numpy.zeros((len(r)+1)*(len(h)+1), dtype=numpy.uint8).reshape((len(r)+1, len(h)+1)) 22 | for i in range(len(r)+1): 23 | for j in range(len(h)+1): 24 | if i == 0: d[0][j] = j 25 | elif j == 0: d[i][0] = i 26 | for i in range(1,len(r)+1): 27 | for j in range(1, len(h)+1): 28 | if r[i-1] == h[j-1]: 29 | d[i][j] = d[i-1][j-1] 30 | else: 31 | substitute = d[i-1][j-1] + 1 32 | insert = d[i][j-1] + 1 33 | delete = d[i-1][j] + 1 34 | d[i][j] = min(substitute, insert, delete) 35 | result = float(d[len(r)][len(h)]) / len(r) * 100 36 | # result = str("%.2f" % result) + "%" 37 | return result 38 | 39 | # -*- coding: utf-8 -*- 40 | # Copyright 2017 Google Inc. 41 | # 42 | # Licensed under the Apache License, Version 2.0 (the "License"); 43 | # you may not use this file except in compliance with the License. 44 | # You may obtain a copy of the License at 45 | # 46 | # http://www.apache.org/licenses/LICENSE-2.0 47 | # 48 | # Unless required by applicable law or agreed to in writing, software 49 | # distributed under the License is distributed on an "AS IS" BASIS, 50 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 51 | # See the License for the specific language governing permissions and 52 | # limitations under the License. 53 | """BLEU metric implementation. 54 | """ 55 | 56 | 57 | def moses_multi_bleu(hypotheses, references, lowercase=False): 58 | """Calculate the bleu score for hypotheses and references 59 | using the MOSES multi-bleu.perl script.
60 | Args: 61 | hypotheses: A numpy array of strings where each string is a single example. 62 | references: A numpy array of strings where each string is a single example. 63 | lowercase: If true, pass the "-lc" flag to the multi-bleu script 64 | Returns: 65 | The BLEU score as a float32 value. 66 | """ 67 | 68 | if np.size(hypotheses) == 0: 69 | return np.float32(0.0) 70 | 71 | 72 | # Get MOSES multi-bleu script 73 | try: 74 | multi_bleu_path, _ = urllib.request.urlretrieve( 75 | "https://raw.githubusercontent.com/moses-smt/mosesdecoder/" 76 | "master/scripts/generic/multi-bleu.perl") 77 | os.chmod(multi_bleu_path, 0o755) 78 | except: #pylint: disable=W0702 79 | print("Unable to fetch multi-bleu.perl script, using local.") 80 | metrics_dir = os.path.dirname(os.path.realpath(__file__)) 81 | bin_dir = os.path.abspath(os.path.join(metrics_dir, "..", "..", "bin")) 82 | multi_bleu_path = os.path.join(bin_dir, "tools/multi-bleu.perl") 83 | 84 | 85 | # Dump hypotheses and references to tempfiles 86 | hypothesis_file = tempfile.NamedTemporaryFile() 87 | hypothesis_file.write("\n".join(hypotheses).encode("utf-8")) 88 | hypothesis_file.write(b"\n") 89 | hypothesis_file.flush() 90 | reference_file = tempfile.NamedTemporaryFile() 91 | reference_file.write("\n".join(references).encode("utf-8")) 92 | reference_file.write(b"\n") 93 | reference_file.flush() 94 | 95 | 96 | # Calculate BLEU using multi-bleu script 97 | with open(hypothesis_file.name, "r") as read_pred: 98 | bleu_cmd = [multi_bleu_path] 99 | if lowercase: 100 | bleu_cmd += ["-lc"] 101 | bleu_cmd += [reference_file.name] 102 | try: 103 | bleu_out = subprocess.check_output(bleu_cmd, stdin=read_pred, stderr=subprocess.STDOUT) 104 | bleu_out = bleu_out.decode("utf-8") 105 | bleu_score = re.search(r"BLEU = (.+?),", bleu_out).group(1) 106 | bleu_score = float(bleu_score) 107 | except subprocess.CalledProcessError as error: 108 | if error.output is not None: 109 | print("multi-bleu.perl script returned non-zero exit code") 110 | print(error.output) 111 | bleu_score = np.float32(0.0) 112 | 113 | # Close temp files 114 | hypothesis_file.close() 115 | reference_file.close() 116 | return bleu_score -------------------------------------------------------------------------------- /convlab2/dst/trade/trade.py: -------------------------------------------------------------------------------- 1 | from convlab2.dst.dst import DST 2 | 3 | class TRADE(DST): 4 | def update(self, act): 5 | pass -------------------------------------------------------------------------------- /convlab2/nlg/__init__.py: -------------------------------------------------------------------------------- 1 | from convlab2.nlg.nlg import NLG 2 | -------------------------------------------------------------------------------- /convlab2/nlg/nlg.py: -------------------------------------------------------------------------------- 1 | """Natural Language Generation Interface""" 2 | from convlab2.util.module import Module 3 | 4 | 5 | class NLG(Module): 6 | """Base class for NLG models.""" 7 | 8 | def generate(self, action):  9 | """Generate a natural language utterance conditioned on the dialog act. 10 | 11 | Args: 12 | action (list of list): 13 | The dialog action produced by the dialog policy module, in dialog act format. 14 | Returns: 15 | utterance (str): 16 | A natural language utterance.
17 | """ 18 | return '' 19 | -------------------------------------------------------------------------------- /convlab2/nlg/sclstm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/nlg/sclstm/__init__.py -------------------------------------------------------------------------------- /convlab2/nlg/sclstm/crosswoz/README.md: -------------------------------------------------------------------------------- 1 | # SCLSTM NLG on CrossWOZ 2 | 3 | Semantically-conditioned LSTM (SC-LSTM) is an NLG model that generates natural, linguistically varied responses based on a deep, semantically controlled LSTM architecture. 4 | 5 | - *Sentence planning* maps input semantic symbols (e.g. dialog acts) into an intermediary form representing the utterance. 6 | - *Surface realization* converts the intermediate structure into the final text. 7 | 8 | The code is derived from [github](https://github.com/andy194673/nlg-sclstm-multiwoz); we modified it to support user NLG. The original paper can be found at [ACL Anthology](https://aclweb.org/anthology/papers/D/D15/D15-1199/). 9 | 10 | ## Usage 11 | 12 | ### Prepare the data 13 | 14 | ```bash 15 | $ python generate_resources.py 16 | ``` 17 | 18 | This will generate two folders that contain the training data: ./resource/\*, ./resource_usr/\*. 19 | 20 | ### Train 21 | 22 | ```bash 23 | $ python train.py --mode=train --model_path=sclstm.pt --n_layer=1 --lr=0.005 > sclstm.log 24 | ``` 25 | 26 | Set *user* to train the user NLG, e.g. 27 | 28 | ```bash 29 | $ python train.py --mode=train --model_path=sclstm_usr.pt --n_layer=1 --lr=0.005 --user True > sclstm_usr.log 30 | ``` 31 | 32 | ### Test 33 | 34 | ```bash 35 | $ python train.py --mode=test --model_path=sclstm.pt --n_layer=1 --beam_size=10 > sclstm.res 36 | ``` 37 | 38 | ### Evaluate 39 | 40 | ```bash 41 | $ python evaluate.py [usr|sys] 42 | ``` 43 | 44 | ## Data 45 | 46 | We use CrossWOZ data (`data/crosswoz`). 47 | 48 | ## Performance on CrossWOZ 49 | 50 | `mode` determines the data we use: if mode=`usr`, user utterances are used for training; if mode=`sys`, system utterances are used. 51 | 52 | We evaluate BLEU4 on delexicalized utterances. The references for a generated sentence are all the golden sentences that share the same dialog act.
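To make the reference grouping concrete, here is a minimal sketch of this evaluation scheme (an illustration with toy data, not the repository's `evaluate.py`; it assumes NLTK is installed and that utterances are whitespace-tokenized):

```python
from collections import defaultdict
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

# Toy corpus: (serialized dialog act, delexicalized utterance) pairs.
golden = [
    ('Inform+餐馆+人均消费', '这家 餐馆 的 人均 消费 是 [人均消费]'),
    ('Inform+餐馆+人均消费', '人均 消费 为 [人均消费]'),
    ('Inform+景点+门票', '门票 是 [门票]'),
]
generated = [
    ('Inform+餐馆+人均消费', '餐馆 的 人均 消费 是 [人均消费]'),
    ('Inform+景点+门票', '门票 为 [门票]'),
]

# All golden sentences with the same dialog act serve as the reference set.
refs_by_da = defaultdict(list)
for da, sent in golden:
    refs_by_da[da].append(sent.split())

references = [refs_by_da[da] for da, _ in generated]
hypotheses = [sent.split() for _, sent in generated]

# Corpus-level BLEU-4; smoothing keeps short toy sentences from scoring zero.
score = corpus_bleu(references, hypotheses,
                    smoothing_function=SmoothingFunction().method1)
print('BLEU4: {:.4f}'.format(score))
```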
53 | 54 | | mode | usr | sys | 55 | | ----- | ------ | ------ | 56 | | BLEU4 | 0.7858 | 0.8595 | 57 | 58 | ## Reference 59 | 60 | ``` 61 | @inproceedings{wen2015semantically, 62 | title={Semantically Conditioned LSTM-based Natural Language Generation for Spoken Dialogue Systems}, 63 | author={Wen, Tsung-Hsien and Gasic, Milica and Mrk{\v{s}}i{\'c}, Nikola and Su, Pei-Hao and Vandyke, David and Young, Steve}, 64 | booktitle={Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing}, 65 | pages={1711--1721}, 66 | year={2015} 67 | } 68 | ``` 69 | 70 | -------------------------------------------------------------------------------- /convlab2/nlg/sclstm/crosswoz/__init__.py: -------------------------------------------------------------------------------- 1 | from convlab2.nlg.sclstm.crosswoz.sc_lstm import SCLSTM -------------------------------------------------------------------------------- /convlab2/nlg/sclstm/crosswoz/config/config.cfg: -------------------------------------------------------------------------------- 1 | [DATA] 2 | vocab_file = %(dir)s/resource/vocab.txt 3 | feat_file = %(dir)s/resource/feat.json 4 | text_file = %(dir)s/resource/text.json 5 | template_file = %(dir)s/resource/template.txt 6 | dataSplit_file = %(dir)s/resource/split.json 7 | batch_size = 256 8 | shuffle = true 9 | dir = 10 | 11 | [MODEL] 12 | dec_type = sclstm 13 | hidden_size = 100 14 | dropout = 0.25 15 | clip = 0.5 16 | learning_rate = 0.001 17 | 18 | [TRAINING] 19 | model_epoch = best 20 | n_epochs = 75 21 | -------------------------------------------------------------------------------- /convlab2/nlg/sclstm/crosswoz/config/config_usr.cfg: -------------------------------------------------------------------------------- 1 | [DATA] 2 | vocab_file = %(dir)s/resource_usr/vocab.txt 3 | feat_file = %(dir)s/resource_usr/feat.json 4 | text_file = %(dir)s/resource_usr/text.json 5 | template_file = %(dir)s/resource_usr/template.txt 6 | dataSplit_file= %(dir)s/resource_usr/split.json 7 | batch_size = 256 8 | shuffle = true 9 | dir = 10 | 11 | [MODEL] 12 | dec_type = sclstm 13 | hidden_size = 100 14 | dropout = 0.25 15 | clip = 0.5 16 | learning_rate = 0.001 17 | 18 | [TRAINING] 19 | model_epoch = best 20 | n_epochs = 75 21 | -------------------------------------------------------------------------------- /convlab2/nlg/sclstm/model/lm_deep.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | from convlab2.nlg.sclstm.model.layers.decoder_deep import DecoderDeep 8 | from convlab2.nlg.sclstm.model.masked_cross_entropy import masked_cross_entropy 9 | 10 | 11 | class LMDeep(nn.Module): 12 | def __init__(self, dec_type, input_size, output_size, hidden_size, d_size, n_layer=1, dropout=0.5, lr=0.001, use_cuda=False): 13 | super(LMDeep, self).__init__() 14 | self.dec_type = dec_type 15 | self.hidden_size = hidden_size 16 | print('Using deep version with {} layer'.format(n_layer)) 17 | print('Using deep version with {} layer'.format(n_layer), file=sys.stderr) 18 | self.USE_CUDA = use_cuda 19 | self.dec = DecoderDeep(dec_type, input_size, output_size, hidden_size, d_size=d_size, n_layer=n_layer, dropout=dropout, use_cuda=use_cuda) 20 | # if self.dec_type != 'sclstm': 21 | # self.feat2hidden = nn.Linear(d_size, hidden_size) 22 | 23 | self.set_solver(lr) 24 | 25 | def forward(self, input_var, dataset, feats_var, gen=False, beam_search=False, beam_size=1): 26 | 
batch_size = dataset.batch_size 27 | if self.dec_type == 'sclstm': 28 | init_hidden = Variable(torch.zeros(batch_size, self.hidden_size)) 29 | if self.USE_CUDA: 30 | init_hidden = init_hidden.cuda() 31 | ''' 32 | train/valid (gen=False, beam_search=False, beam_size=1) 33 | test w/o beam_search (gen=True, beam_search=False, beam_size=beam_size) 34 | test w/i beam_search (gen=True, beam_search=True, beam_size=beam_size) 35 | ''' 36 | if beam_search: 37 | assert gen 38 | decoded_words = self.dec.beam_search(input_var, dataset, init_hidden=init_hidden, init_feat=feats_var, \ 39 | gen=gen, beam_size=beam_size) 40 | return decoded_words # list (batch_size=1) of list (beam_size) with generated sentences 41 | 42 | # w/o beam_search 43 | sample_size = beam_size 44 | decoded_words = [ [] for _ in range(batch_size) ] 45 | for sample_idx in range(sample_size): # over generation 46 | self.output_prob, gens = self.dec(input_var, dataset, init_hidden=init_hidden, init_feat=feats_var, \ 47 | gen=gen, sample_size=sample_size) 48 | for batch_idx in range(batch_size): 49 | decoded_words[batch_idx].append(gens[batch_idx]) 50 | 51 | return decoded_words # list (batch_size) of list (sample_size) with generated sentences 52 | 53 | 54 | else: # TODO: vanilla lstm 55 | pass 56 | # last_hidden = self.feat2hidden(conds_batches) 57 | # self.output_prob, decoded_words = self.dec(input_seq, dataset, last_hidden=last_hidden, gen=gen, random_sample=self.random_sample) 58 | 59 | 60 | def generate(self, dataset, feats_var, beam_size=1): 61 | batch_size = dataset.batch_size 62 | init_hidden = Variable(torch.zeros(batch_size, self.hidden_size)) 63 | if self.USE_CUDA: 64 | init_hidden = init_hidden.cuda() 65 | decoded_words = self.dec.beam_search(None, dataset, init_hidden=init_hidden, init_feat=feats_var, \ 66 | gen=True, beam_size=beam_size) 67 | return decoded_words 68 | 69 | def set_solver(self, lr): 70 | if self.dec_type == 'sclstm': 71 | self.solver = torch.optim.Adam(self.dec.parameters(), lr=lr) 72 | else: 73 | self.solver = torch.optim.Adam([{'params': self.dec.parameters()}, {'params': self.feat2hidden.parameters()}], lr=lr) 74 | 75 | 76 | def get_loss(self, target_label, target_lengths): 77 | self.loss = masked_cross_entropy( 78 | self.output_prob.contiguous(), # -> batch x seq 79 | target_label.contiguous(), # -> batch x seq 80 | target_lengths) 81 | return self.loss 82 | 83 | 84 | def update(self, clip): 85 | # Back prop 86 | self.loss.backward() 87 | 88 | # Clip gradient norms 89 | _ = torch.nn.utils.clip_grad_norm(self.dec.parameters(), clip) 90 | 91 | # Update 92 | self.solver.step() 93 | 94 | # Zero grad 95 | self.solver.zero_grad() 96 | -------------------------------------------------------------------------------- /convlab2/nlg/sclstm/model/masked_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.nn import functional 4 | 5 | 6 | def sequence_mask(sequence_length, max_len=None): 7 | if max_len is None: 8 | max_len = sequence_length.data.max() 9 | batch_size = sequence_length.size(0) 10 | # seq_range = torch.range(0, max_len - 1).long() 11 | seq_range = torch.arange(0, max_len).long() # andy 12 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) 13 | seq_range_expand = Variable(seq_range_expand) 14 | if sequence_length.is_cuda: 15 | seq_range_expand = seq_range_expand.cuda() 16 | seq_length_expand = (sequence_length.unsqueeze(1) 17 | .expand_as(seq_range_expand)) 18 | 
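# broadcast comparison: True exactly where the position index is below the sequence length, yielding a (batch, max_len) mask of valid (non-padded) steps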
return seq_range_expand < seq_length_expand 19 | 20 | 21 | def masked_cross_entropy(logits, target, length): 22 | """ 23 | Args: 24 | logits: A Variable containing a FloatTensor of size 25 | (batch, max_len, num_classes) which contains the 26 | unnormalized probability for each class. 27 | target: A Variable containing a LongTensor of size 28 | (batch, max_len) which contains the index of the true 29 | class for each corresponding step. 30 | length: A Variable containing a LongTensor of size (batch,) 31 | which contains the length of each sequence in the batch. 32 | Returns: 33 | loss: An average loss value masked by the length. 34 | """ 35 | length = Variable(torch.LongTensor(length)).cuda() 36 | 37 | 38 | # logits_flat: (batch * max_len, num_classes) 39 | logits_flat = logits.view(-1, logits.size(-1)) 40 | # log_probs_flat: (batch * max_len, num_classes) 41 | log_probs_flat = functional.log_softmax(logits_flat, dim=1) 42 | # target_flat: (batch * max_len, 1) 43 | target_flat = target.view(-1, 1) 44 | # losses_flat: (batch * max_len, 1) 45 | losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat) 46 | # losses: (batch, max_len) 47 | losses = losses_flat.view(*target.size()) 48 | # mask: (batch, max_len) 49 | mask = sequence_mask(sequence_length=length, max_len=target.size(1)) 50 | losses = losses * mask.float() 51 | loss = losses.sum() / length.float().sum() # per word loss 52 | return loss 53 | -------------------------------------------------------------------------------- /convlab2/nlg/template/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/nlg/template/__init__.py -------------------------------------------------------------------------------- /convlab2/nlg/template/crosswoz/__init__.py: -------------------------------------------------------------------------------- 1 | from convlab2.nlg.template.crosswoz.nlg import TemplateNLG 2 | -------------------------------------------------------------------------------- /convlab2/nlu/__init__.py: -------------------------------------------------------------------------------- 1 | from convlab2.nlu.nlu import NLU 2 | -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/nlu/jointBERT/__init__.py -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/crosswoz/README.md: -------------------------------------------------------------------------------- 1 | # BERTNLU on CrossWOZ 2 | 3 | Built on pre-trained BERT, BERTNLU uses one linear layer for slot tagging and another linear layer for intent classification. Dialog acts are split into two groups, depending on whether the value appears in the utterance. 4 | 5 | - Dialog acts whose value appears in the utterance are translated to BIO tags. For example, for `"Find me a cheap hotel"` with dialog act `{"Hotel-Inform":[["Price", "cheap"]]}`, the translated tag sequence is `["O", "O", "O", "B-Hotel-Inform+Price", "O"]`. An MLP takes BERT word embeddings as input and classifies the tag label.
If you set `context=true` in the config file, the utterances of the last three turns are concatenated, and the `[CLS]` embedding of this context provides additional information for the classification. 6 | - For each of the other dialog acts, such as `(Hotel-Request, Address, ?)`, another MLP takes the `[CLS]` embedding of the current utterance as input and performs binary classification. Again, if you set `context=true` in the config file, the `[CLS]` embedding of the concatenated last three utterances provides additional context for the classification. 7 | 8 | We fine-tune the BERT parameters on CrossWOZ. 9 | 10 | ## Usage 11 | 12 | Determine which data you want to use: if **mode**='usr', train on user utterances; if **mode**='sys', train on system utterances; if **mode**='all', train on both user and system utterances. 13 | 14 | #### Preprocess data 15 | 16 | In the `jointBERT/crosswoz` dir: 17 | 18 | ```sh 19 | $ python preprocess.py [mode] 20 | ``` 21 | 22 | This outputs the processed data to the `data/[mode]_data/` dir. 23 | 24 | #### Train a model 25 | 26 | In the `jointBERT` dir: 27 | 28 | ```sh 29 | $ python train.py --config_path crosswoz/configs/[config_file] 30 | ``` 31 | 32 | The model will be saved under the `output_dir` given in the config file, and also zipped as `zipped_model_path`. 33 | 34 | #### Test a model 35 | 36 | In the `jointBERT` dir: 37 | 38 | ```sh 39 | $ python test.py --config_path crosswoz/configs/[config_file] 40 | ``` 41 | 42 | The result (`output.json`) will be saved under the `output_dir` of the config file. 43 | 44 | #### Predict 45 | 46 | See `nlu.py` for usage. 47 | 48 | #### Trained model 49 | 50 | We have trained two models on **all** utterances of the CrossWOZ dataset (`data/crosswoz/[train|val|test].json.zip`): one uses context information (the last 3 utterances; `configs/crosswoz_all_context.json`) and the other does not (`configs/crosswoz_all.json`). Performance: 51 | 52 | | | F1 | 53 | | --------------- | ----- | 54 | | without context | 91.85 | 55 | | with context | 95.53 | 56 | 57 | Models can be downloaded from: 58 | 59 | Without context: https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_crosswoz_all.zip 60 | 61 | With context: https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_crosswoz_all_context.zip 62 | 63 | 64 | 65 | ## Data 66 | 67 | We use the CrossWOZ data (`data/crosswoz/[train|val|test].json.zip`).
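As a concrete illustration of the BIO scheme described at the top of this README, here is a minimal, hypothetical sketch (the real conversion in `preprocess.py` works on BERT sub-tokens and also handles multi-token values):

```python
# Hedged illustration of translating a dialog act into BIO tags; `to_bio`
# is a toy helper, not part of this repo, and only covers single-token values.
def to_bio(tokens, dialog_acts):
    # dialog_acts: {"Intent-Domain": [[slot, value], ...]}
    tags = ["O"] * len(tokens)
    for intent_domain, slot_values in dialog_acts.items():
        for slot, value in slot_values:
            for i, token in enumerate(tokens):
                if token == value:
                    tags[i] = "B-{}+{}".format(intent_domain, slot)
    return tags


tokens = ["Find", "me", "a", "cheap", "hotel"]
acts = {"Hotel-Inform": [["Price", "cheap"]]}
print(to_bio(tokens, acts))
# ['O', 'O', 'O', 'B-Hotel-Inform+Price', 'O']
```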
68 | 69 | ## References 70 | 71 | ``` 72 | @inproceedings{devlin2019bert, 73 | title={BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding}, 74 | author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina}, 75 | booktitle={Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)}, 76 | pages={4171--4186}, 77 | year={2019} 78 | } 79 | ``` -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/crosswoz/__init__.py: -------------------------------------------------------------------------------- 1 | from convlab2.nlu.jointBERT.crosswoz.nlu import BERTNLU -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/crosswoz/analyse.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pprint import pprint 3 | import zipfile 4 | 5 | 6 | def read_zipped_json(filepath, filename): 7 | archive = zipfile.ZipFile(filepath, 'r') 8 | return json.load(archive.open(filename)) 9 | 10 | 11 | def get_goal_type(data, mode): 12 | goal_types = [] 13 | for no, sess in data.items(): 14 | goal_type = sess['type'] 15 | for i, turn in enumerate(sess['messages']): 16 | if mode == 'usr' and turn['role'] == 'sys': 17 | continue 18 | elif mode == 'sys' and turn['role'] == 'usr': 19 | continue 20 | goal_types.append(goal_type) 21 | return goal_types 22 | 23 | 24 | def calculateF1(predict_golden, goal_type=None, intent=None, domain=None, slot=None): 25 | if domain=='General': 26 | domain = None 27 | intent = 'General' 28 | TP, FP, FN = 0, 0, 0 29 | for item in predict_golden: 30 | predicts = item['predict'] 31 | labels = item['golden'] 32 | for quad in predicts: 33 | if intent and quad[0] != intent: 34 | continue 35 | if domain and quad[1] != domain: 36 | continue 37 | if quad in labels: 38 | TP += 1 39 | else: 40 | FP += 1 41 | for quad in labels: 42 | if intent and quad[0] != intent: 43 | continue 44 | if domain and quad[1] != domain: 45 | continue 46 | if quad not in predicts: 47 | FN += 1 48 | # print(TP, FP, FN) 49 | precision = 1.0 * TP / (TP + FP) 50 | recall = 1.0 * TP / (TP + FN) 51 | F1 = 2.0 * precision * recall / (precision + recall) 52 | return precision, recall, F1 53 | 54 | 55 | if __name__ == '__main__': 56 | predict_golden = json.load(open('output/all_context/output.json',encoding='utf-8')) 57 | print('all', calculateF1(predict_golden)) 58 | goal_types = get_goal_type(read_zipped_json('../../../../data/crosswoz/test.json.zip', 'test.json',),mode='all') 59 | type_predict_golden = {} 60 | for goal_type, d in zip(goal_types,predict_golden): 61 | type_predict_golden.setdefault(goal_type, []) 62 | type_predict_golden[goal_type].append(d) 63 | for goal_type in type_predict_golden: 64 | print(goal_type,len(type_predict_golden[goal_type])) 65 | print([float('%.2f' % (x*100)) for x in calculateF1(type_predict_golden[goal_type])]) 66 | intents = ['Inform', 'Request', 'General', 'Recommend', 'Select', 'NoOffer'] 67 | domains = ['景点', '酒店', '餐馆', '出租', '地铁', 'General'] 68 | intent_predict_golden = dict.fromkeys(intents) 69 | domain_predict_golden = dict.fromkeys(domains) 70 | for intent in intents: 71 | intent_predict_golden[intent] = [float('%.2f' % (x*100)) for x in calculateF1(predict_golden,intent=intent)] 72 | for domain in domains: 73 | domain_predict_golden[domain] = [float('%.2f' % (x*100)) 
for x in calculateF1(predict_golden,domain=domain)] 74 | pprint(intent_predict_golden) 75 | pprint(domain_predict_golden) 76 | -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/crosswoz/configs/crosswoz_all.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "crosswoz/data/all_data", 3 | "output_dir": "crosswoz/output/all", 4 | "zipped_model_path": "crosswoz/output/all/bert_crosswoz_all.zip", 5 | "log_dir": "crosswoz/log/all", 6 | "DEVICE": "cuda:0", 7 | "seed": 2019, 8 | "cut_sen_len": 60, 9 | "use_bert_tokenizer": false, 10 | "model": { 11 | "finetune": true, 12 | "context": false, 13 | "context_grad": false, 14 | "pretrained_weights": "hfl/chinese-bert-wwm-ext", 15 | "check_step": 1000, 16 | "max_step": 40000, 17 | "batch_size": 20, 18 | "learning_rate": 3e-5, 19 | "adam_epsilon": 1e-8, 20 | "warmup_steps": 0, 21 | "weight_decay": 0.0, 22 | "dropout": 0.1, 23 | "hidden_units": 768 24 | } 25 | } -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/crosswoz/configs/crosswoz_all_context.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "crosswoz/data/all_data", 3 | "output_dir": "crosswoz/output/all_context", 4 | "zipped_model_path": "crosswoz/output/all_context/bert_crosswoz_all_context.zip", 5 | "log_dir": "crosswoz/log/all_context", 6 | "DEVICE": "cuda:1", 7 | "seed": 2019, 8 | "cut_sen_len": 60, 9 | "use_bert_tokenizer": false, 10 | "model": { 11 | "finetune": true, 12 | "context": true, 13 | "context_grad": true, 14 | "pretrained_weights": "hfl/chinese-bert-wwm-ext", 15 | "check_step": 1000, 16 | "max_step": 40000, 17 | "batch_size": 20, 18 | "learning_rate": 3e-5, 19 | "adam_epsilon": 1e-8, 20 | "warmup_steps": 0, 21 | "weight_decay": 0.0, 22 | "dropout": 0.1, 23 | "hidden_units": 1536 24 | } 25 | } -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/crosswoz/configs/crosswoz_all_context_fr.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "crosswoz/data/all_data", 3 | "output_dir": "crosswoz/output/all_context_fr", 4 | "zipped_model_path": "crosswoz/output/all_context_fr/bert_crosswoz_all_context_fr.zip", 5 | "log_dir": "crosswoz/log/all_context_fr", 6 | "DEVICE": "cuda:2", 7 | "seed": 2019, 8 | "cut_sen_len": 60, 9 | "use_bert_tokenizer": false, 10 | "model": { 11 | "finetune": false, 12 | "context": true, 13 | "context_grad": false, 14 | "pretrained_weights": "hfl/chinese-bert-wwm-ext", 15 | "check_step": 1000, 16 | "max_step": 40000, 17 | "batch_size": 100, 18 | "learning_rate": 1e-3, 19 | "adam_epsilon": 1e-8, 20 | "warmup_steps": 0, 21 | "weight_decay": 0.0, 22 | "dropout": 0.1, 23 | "hidden_units": 1536 24 | } 25 | } -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/crosswoz/configs/crosswoz_all_fr.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "crosswoz/data/all_data", 3 | "output_dir": "crosswoz/output/all_fr", 4 | "zipped_model_path": "crosswoz/output/all_fr/bert_crosswoz_all_fr.zip", 5 | "log_dir": "crosswoz/log/all_fr", 6 | "DEVICE": "cuda:3", 7 | "seed": 2019, 8 | "cut_sen_len": 60, 9 | "use_bert_tokenizer": false, 10 | "model": { 11 | "finetune": false, 12 | "context": false, 13 | "context_grad": false, 14 | 
"pretrained_weights": "hfl/chinese-bert-wwm-ext", 15 | "check_step": 1000, 16 | "max_step": 40000, 17 | "batch_size": 100, 18 | "learning_rate": 1e-3, 19 | "adam_epsilon": 1e-8, 20 | "warmup_steps": 0, 21 | "weight_decay": 0.0, 22 | "dropout": 0.1, 23 | "hidden_units": 768 24 | } 25 | } -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/crosswoz/nlu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import json 4 | import torch 5 | 6 | from convlab2.util.file_util import cached_path 7 | from convlab2.nlu.nlu import NLU 8 | from convlab2.nlu.jointBERT.dataloader import Dataloader 9 | from convlab2.nlu.jointBERT.jointBERT import JointBERT 10 | from convlab2.nlu.jointBERT.crosswoz.postprocess import recover_intent 11 | from convlab2.nlu.jointBERT.crosswoz.preprocess import preprocess 12 | 13 | 14 | class BERTNLU(NLU): 15 | def __init__(self, mode='all', config_file='crosswoz_all_context.json', 16 | model_file='https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_crosswoz_all_context.zip'): 17 | assert mode == 'usr' or mode == 'sys' or mode == 'all' 18 | config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs/{}'.format(config_file)) 19 | config = json.load(open(config_file)) 20 | DEVICE = config['DEVICE'] 21 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 22 | data_dir = os.path.join(root_dir, config['data_dir']) 23 | output_dir = os.path.join(root_dir, config['output_dir']) 24 | 25 | if not os.path.exists(os.path.join(data_dir, 'intent_vocab.json')): 26 | preprocess(mode) 27 | 28 | intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json'))) 29 | tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json'))) 30 | dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab, 31 | pretrained_weights=config['model']['pretrained_weights']) 32 | 33 | print('intent num:', len(intent_vocab)) 34 | print('tag num:', len(tag_vocab)) 35 | 36 | best_model_path = os.path.join(output_dir, 'pytorch_model.bin') 37 | if not os.path.exists(best_model_path): 38 | if not os.path.exists(output_dir): 39 | os.makedirs(output_dir) 40 | print('Load from model_file param') 41 | archive_file = cached_path(model_file) 42 | archive = zipfile.ZipFile(archive_file, 'r') 43 | archive.extractall(root_dir) 44 | archive.close() 45 | print('Load from', best_model_path) 46 | model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim) 47 | model.load_state_dict(torch.load(os.path.join(output_dir, 'pytorch_model.bin'), DEVICE)) 48 | model.to(DEVICE) 49 | model.eval() 50 | 51 | self.model = model 52 | self.dataloader = dataloader 53 | print("BERTNLU loaded") 54 | 55 | def predict(self, utterance, context=list()): 56 | ori_word_seq = self.dataloader.tokenizer.tokenize(utterance) 57 | ori_tag_seq = ['O'] * len(ori_word_seq) 58 | context_seq = self.dataloader.tokenizer.encode('[CLS] ' + ' [SEP] '.join(context[-3:])) 59 | intents = [] 60 | da = {} 61 | 62 | word_seq, tag_seq, new2ori = ori_word_seq, ori_tag_seq, None 63 | batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq, 64 | new2ori, word_seq, self.dataloader.seq_tag2id(tag_seq), self.dataloader.seq_intent2id(intents)]] 65 | 66 | pad_batch = self.dataloader.pad_batch(batch_data) 67 | pad_batch = tuple(t.to(self.model.device) for t in pad_batch) 68 | word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, 
tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch 69 | slot_logits, intent_logits = self.model.forward(word_seq_tensor, word_mask_tensor, 70 | context_seq_tensor=context_seq_tensor, 71 | context_mask_tensor=context_mask_tensor) 72 | intent = recover_intent(self.dataloader, intent_logits[0], slot_logits[0], tag_mask_tensor[0], 73 | batch_data[0][0], batch_data[0][-4]) 74 | return intent 75 | 76 | 77 | if __name__ == '__main__': 78 | nlu = BERTNLU(mode='all', config_file='crosswoz_all_context.json', 79 | model_file='https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_crosswoz_all_context.zip') 80 | print(nlu.predict("北京布提克精品酒店酒店是什么类型,有健身房吗?", ['你好,给我推荐一个评分是5分,价格在100-200元的酒店。', '推荐您去北京布提克精品酒店。'])) 81 | -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/crosswoz/postprocess.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | 4 | 5 | def is_slot_da(da): 6 | tag_da = {'Inform', 'Recommend'} 7 | not_tag_slot = '酒店设施' 8 | if da[0] in tag_da and not_tag_slot not in da[2]: 9 | return True 10 | return False 11 | 12 | 13 | def calculateF1(predict_golden): 14 | TP, FP, FN = 0, 0, 0 15 | for item in predict_golden: 16 | predicts = item['predict'] 17 | labels = item['golden'] 18 | for ele in predicts: 19 | if ele in labels: 20 | TP += 1 21 | else: 22 | FP += 1 23 | for ele in labels: 24 | if ele not in predicts: 25 | FN += 1 26 | # print(TP, FP, FN) 27 | precision = 1.0 * TP / (TP + FP) if TP + FP else 0. 28 | recall = 1.0 * TP / (TP + FN) if TP + FN else 0. 29 | F1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0. 30 | return precision, recall, F1 31 | 32 | 33 | def tag2das(word_seq, tag_seq): 34 | assert len(word_seq)==len(tag_seq) 35 | das = [] 36 | i = 0 37 | while i < len(tag_seq): 38 | tag = tag_seq[i] 39 | if tag.startswith('B'): 40 | intent, domain, slot = tag[2:].split('+') 41 | value = word_seq[i] 42 | j = i + 1 43 | while j < len(tag_seq): 44 | if tag_seq[j].startswith('I') and tag_seq[j][2:] == tag[2:]: 45 | # tag_seq[j][2:].split('+')[-1]==slot or tag_seq[j][2:] == tag[2:] 46 | if word_seq[j].startswith('##'): 47 | value += word_seq[j][2:] 48 | else: 49 | value += word_seq[j] 50 | i += 1 51 | j += 1 52 | else: 53 | break 54 | das.append([intent, domain, slot, value]) 55 | i += 1 56 | return das 57 | 58 | 59 | def intent2das(intent_seq): 60 | triples = [] 61 | for intent in intent_seq: 62 | intent, domain, slot, value = re.split('\+', intent) 63 | triples.append([intent, domain, slot, value]) 64 | return triples 65 | 66 | 67 | def recover_intent(dataloader, intent_logits, tag_logits, tag_mask_tensor, ori_word_seq, new2ori): 68 | # tag_logits = [sequence_length, tag_dim] 69 | # intent_logits = [intent_dim] 70 | # tag_mask_tensor = [sequence_length] 71 | max_seq_len = tag_logits.size(0) 72 | das = [] 73 | for j in range(dataloader.intent_dim): 74 | if intent_logits[j] > 0: 75 | intent, domain, slot, value = re.split('\+', dataloader.id2intent[j]) 76 | das.append([intent, domain, slot, value]) 77 | tags = [] 78 | for j in range(1 , max_seq_len -1): 79 | if tag_mask_tensor[j] == 1: 80 | value, tag_id = torch.max(tag_logits[j], dim=-1) 81 | tags.append(dataloader.id2tag[tag_id.item()]) 82 | tag_intent = tag2das(ori_word_seq, tags) 83 | das += tag_intent 84 | return das 85 | -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/crosswoz/preprocess.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import zipfile 4 | import sys 5 | from collections import Counter 6 | from transformers import BertTokenizer 7 | 8 | 9 | def read_zipped_json(filepath, filename): 10 | archive = zipfile.ZipFile(filepath, 'r') 11 | return json.load(archive.open(filename)) 12 | 13 | 14 | def preprocess(mode): 15 | assert mode == 'all' or mode == 'usr' or mode == 'sys' 16 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 17 | data_dir = os.path.join(cur_dir, '../../../../data/crosswoz') 18 | processed_data_dir = os.path.join(cur_dir, 'data/{}_data'.format(mode)) 19 | if not os.path.exists(processed_data_dir): 20 | os.makedirs(processed_data_dir) 21 | data_key = ['train', 'val', 'test'] 22 | data = {} 23 | for key in data_key: 24 | data[key] = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json') 25 | print('load {}, size {}'.format(key, len(data[key]))) 26 | 27 | processed_data = {} 28 | all_intent = [] 29 | all_tag = [] 30 | 31 | context_size = 3 32 | 33 | tokenizer = BertTokenizer.from_pretrained("hfl/chinese-bert-wwm-ext") 34 | 35 | for key in data_key: 36 | processed_data[key] = [] 37 | for no, sess in data[key].items(): 38 | context = [] 39 | for i, turn in enumerate(sess['messages']): 40 | if mode == 'usr' and turn['role'] == 'sys': 41 | context.append(turn['content']) 42 | continue 43 | elif mode == 'sys' and turn['role'] == 'usr': 44 | context.append(turn['content']) 45 | continue 46 | utterance = turn['content'] 47 | # Notice: ## prefix, space remove 48 | tokens = tokenizer.tokenize(utterance) 49 | golden = [] 50 | 51 | span_info = [] 52 | intents = [] 53 | for intent, domain, slot, value in turn['dialog_act']: 54 | if intent in ['Inform', 'Recommend'] and '酒店设施' not in slot: 55 | if value in utterance: 56 | idx = utterance.index(value) 57 | idx = len(tokenizer.tokenize(utterance[:idx])) 58 | span_info.append(('+'.join([intent,domain,slot]),idx,idx+len(tokenizer.tokenize(value)), value)) 59 | token_v = ''.join(tokens[idx:idx+len(tokenizer.tokenize(value))]) 60 | # if token_v != value: 61 | # print(slot, token_v, value) 62 | token_v = token_v.replace('##', '') 63 | golden.append([intent, domain, slot, token_v]) 64 | else: 65 | golden.append([intent, domain, slot, value]) 66 | else: 67 | intents.append('+'.join([intent, domain, slot, value])) 68 | golden.append([intent, domain, slot, value]) 69 | 70 | tags = [] 71 | for j, _ in enumerate(tokens): 72 | for span in span_info: 73 | if j == span[1]: 74 | tag = "B+" + span[0] 75 | tags.append(tag) 76 | break 77 | if span[1] < j < span[2]: 78 | tag = "I+" + span[0] 79 | tags.append(tag) 80 | break 81 | else: 82 | tags.append("O") 83 | 84 | processed_data[key].append([tokens, tags, intents, golden, context[-context_size:]]) 85 | 86 | all_intent += intents 87 | all_tag += tags 88 | 89 | context.append(turn['content']) 90 | 91 | all_intent = [x[0] for x in dict(Counter(all_intent)).items()] 92 | all_tag = [x[0] for x in dict(Counter(all_tag)).items()] 93 | print('loaded {}, size {}'.format(key, len(processed_data[key]))) 94 | json.dump(processed_data[key], open(os.path.join(processed_data_dir, '{}_data.json'.format(key)), 'w', encoding='utf-8'), 95 | indent=2, ensure_ascii=False) 96 | 97 | print('sentence label num:', len(all_intent)) 98 | print('tag num:', len(all_tag)) 99 | print(all_intent) 100 | json.dump(all_intent, open(os.path.join(processed_data_dir, 'intent_vocab.json'), 'w', encoding='utf-8'), indent=2, 
ensure_ascii=False) 101 | json.dump(all_tag, open(os.path.join(processed_data_dir, 'tag_vocab.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) 102 | 103 | 104 | if __name__ == '__main__': 105 | mode = sys.argv[1] 106 | preprocess(mode) 107 | -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/jointBERT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import BertModel 4 | 5 | 6 | class JointBERT(nn.Module): 7 | def __init__(self, model_config, device, slot_dim, intent_dim, intent_weight=None): 8 | super(JointBERT, self).__init__() 9 | self.slot_num_labels = slot_dim 10 | self.intent_num_labels = intent_dim 11 | self.device = device 12 | self.intent_weight = intent_weight if intent_weight is not None else torch.tensor([1.]*intent_dim) 13 | 14 | print(model_config['pretrained_weights']) 15 | self.bert = BertModel.from_pretrained(model_config['pretrained_weights']) 16 | self.dropout = nn.Dropout(model_config['dropout']) 17 | self.context = model_config['context'] 18 | self.finetune = model_config['finetune'] 19 | self.context_grad = model_config['context_grad'] 20 | self.hidden_units = model_config['hidden_units'] 21 | if self.hidden_units > 0: 22 | if self.context: 23 | self.intent_classifier = nn.Linear(self.hidden_units, self.intent_num_labels) 24 | self.slot_classifier = nn.Linear(self.hidden_units, self.slot_num_labels) 25 | self.intent_hidden = nn.Linear(2 * self.bert.config.hidden_size, self.hidden_units) 26 | self.slot_hidden = nn.Linear(2 * self.bert.config.hidden_size, self.hidden_units) 27 | else: 28 | self.intent_classifier = nn.Linear(self.hidden_units, self.intent_num_labels) 29 | self.slot_classifier = nn.Linear(self.hidden_units, self.slot_num_labels) 30 | self.intent_hidden = nn.Linear(self.bert.config.hidden_size, self.hidden_units) 31 | self.slot_hidden = nn.Linear(self.bert.config.hidden_size, self.hidden_units) 32 | nn.init.xavier_uniform_(self.intent_hidden.weight) 33 | nn.init.xavier_uniform_(self.slot_hidden.weight) 34 | else: 35 | if self.context: 36 | self.intent_classifier = nn.Linear(2 * self.bert.config.hidden_size, self.intent_num_labels) 37 | self.slot_classifier = nn.Linear(2 * self.bert.config.hidden_size, self.slot_num_labels) 38 | else: 39 | self.intent_classifier = nn.Linear(self.bert.config.hidden_size, self.intent_num_labels) 40 | self.slot_classifier = nn.Linear(self.bert.config.hidden_size, self.slot_num_labels) 41 | nn.init.xavier_uniform_(self.intent_classifier.weight) 42 | nn.init.xavier_uniform_(self.slot_classifier.weight) 43 | 44 | self.intent_loss_fct = torch.nn.BCEWithLogitsLoss(pos_weight=self.intent_weight) 45 | self.slot_loss_fct = torch.nn.CrossEntropyLoss() 46 | 47 | def forward(self, word_seq_tensor, word_mask_tensor, tag_seq_tensor=None, tag_mask_tensor=None, 48 | intent_tensor=None, context_seq_tensor=None, context_mask_tensor=None): 49 | if not self.finetune: 50 | self.bert.eval() 51 | with torch.no_grad(): 52 | outputs = self.bert(input_ids=word_seq_tensor, 53 | attention_mask=word_mask_tensor) 54 | else: 55 | outputs = self.bert(input_ids=word_seq_tensor, 56 | attention_mask=word_mask_tensor) 57 | 58 | sequence_output = outputs[0] 59 | pooled_output = outputs[1] 60 | 61 | if self.context and (context_seq_tensor is not None): 62 | if not self.finetune or not self.context_grad: 63 | with torch.no_grad(): 64 | context_output = self.bert(input_ids=context_seq_tensor, 
attention_mask=context_mask_tensor)[1] 65 | else: 66 | context_output = self.bert(input_ids=context_seq_tensor, attention_mask=context_mask_tensor)[1] 67 | sequence_output = torch.cat( 68 | [context_output.unsqueeze(1).repeat(1, sequence_output.size(1), 1), 69 | sequence_output], dim=-1) 70 | pooled_output = torch.cat([context_output, pooled_output], dim=-1) 71 | 72 | if self.hidden_units > 0: 73 | sequence_output = nn.functional.relu(self.slot_hidden(self.dropout(sequence_output))) 74 | pooled_output = nn.functional.relu(self.intent_hidden(self.dropout(pooled_output))) 75 | 76 | sequence_output = self.dropout(sequence_output) 77 | slot_logits = self.slot_classifier(sequence_output) 78 | outputs = (slot_logits,) 79 | 80 | pooled_output = self.dropout(pooled_output) 81 | intent_logits = self.intent_classifier(pooled_output) 82 | outputs = outputs + (intent_logits,) 83 | 84 | if tag_seq_tensor is not None: 85 | active_tag_loss = tag_mask_tensor.view(-1) == 1 86 | active_tag_logits = slot_logits.view(-1, self.slot_num_labels)[active_tag_loss] 87 | active_tag_labels = tag_seq_tensor.view(-1)[active_tag_loss] 88 | slot_loss = self.slot_loss_fct(active_tag_logits, active_tag_labels) 89 | 90 | outputs = outputs + (slot_loss,) 91 | 92 | if intent_tensor is not None: 93 | intent_loss = self.intent_loss_fct(intent_logits, intent_tensor) 94 | outputs = outputs + (intent_loss,) 95 | 96 | return outputs # slot_logits, intent_logits, (slot_loss), (intent_loss), 97 | -------------------------------------------------------------------------------- /convlab2/nlu/jointBERT/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import random 5 | import numpy as np 6 | import torch 7 | from convlab2.nlu.jointBERT.dataloader import Dataloader 8 | from convlab2.nlu.jointBERT.jointBERT import JointBERT 9 | 10 | 11 | def set_seed(seed): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Test a model.") 18 | parser.add_argument('--config_path', 19 | help='path to config file') 20 | 21 | 22 | if __name__ == '__main__': 23 | args = parser.parse_args() 24 | config = json.load(open(args.config_path)) 25 | data_dir = config['data_dir'] 26 | output_dir = config['output_dir'] 27 | log_dir = config['log_dir'] 28 | DEVICE = config['DEVICE'] 29 | 30 | set_seed(config['seed']) 31 | 32 | if 'multiwoz' in data_dir: 33 | print('-'*20 + 'dataset:multiwoz' + '-'*20) 34 | from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent 35 | elif 'camrest' in data_dir: 36 | print('-' * 20 + 'dataset:camrest' + '-' * 20) 37 | from convlab2.nlu.jointBERT.camrest.postprocess import is_slot_da, calculateF1, recover_intent 38 | elif 'crosswoz' in data_dir: 39 | print('-' * 20 + 'dataset:crosswoz' + '-' * 20) 40 | from convlab2.nlu.jointBERT.crosswoz.postprocess import is_slot_da, calculateF1, recover_intent 41 | 42 | intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json'))) 43 | tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json'))) 44 | dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab, 45 | pretrained_weights=config['model']['pretrained_weights']) 46 | print('intent num:', len(intent_vocab)) 47 | print('tag num:', len(tag_vocab)) 48 | for data_key in ['val', 'test']: 49 | dataloader.load_data(json.load(open(os.path.join(data_dir, '{}_data.json'.format(data_key)))), data_key, 50 
| cut_sen_len=0, use_bert_tokenizer=config['use_bert_tokenizer']) 51 | print('{} set size: {}'.format(data_key, len(dataloader.data[data_key]))) 52 | 53 | if not os.path.exists(output_dir): 54 | os.makedirs(output_dir) 55 | if not os.path.exists(log_dir): 56 | os.makedirs(log_dir) 57 | 58 | model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim) 59 | model.load_state_dict(torch.load(os.path.join(output_dir, 'pytorch_model.bin'), DEVICE)) 60 | model.to(DEVICE) 61 | model.eval() 62 | 63 | batch_size = config['model']['batch_size'] 64 | 65 | data_key = 'test' 66 | predict_golden = {'intent': [], 'slot': [], 'overall': []} 67 | slot_loss, intent_loss = 0, 0 68 | for pad_batch, ori_batch, real_batch_size in dataloader.yield_batches(batch_size, data_key=data_key): 69 | pad_batch = tuple(t.to(DEVICE) for t in pad_batch) 70 | word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch 71 | if not config['model']['context']: 72 | context_seq_tensor, context_mask_tensor = None, None 73 | 74 | with torch.no_grad(): 75 | slot_logits, intent_logits, batch_slot_loss, batch_intent_loss = model.forward(word_seq_tensor, 76 | word_mask_tensor, 77 | tag_seq_tensor, 78 | tag_mask_tensor, 79 | intent_tensor, 80 | context_seq_tensor, 81 | context_mask_tensor) 82 | slot_loss += batch_slot_loss.item() * real_batch_size 83 | intent_loss += batch_intent_loss.item() * real_batch_size 84 | for j in range(real_batch_size): 85 | predicts = recover_intent(dataloader, intent_logits[j], slot_logits[j], tag_mask_tensor[j], 86 | ori_batch[j][0], ori_batch[j][-4]) 87 | labels = ori_batch[j][3] 88 | 89 | predict_golden['overall'].append({ 90 | 'predict': predicts, 91 | 'golden': labels 92 | }) 93 | predict_golden['slot'].append({ 94 | 'predict': [x for x in predicts if is_slot_da(x)], 95 | 'golden': [x for x in labels if is_slot_da(x)] 96 | }) 97 | predict_golden['intent'].append({ 98 | 'predict': [x for x in predicts if not is_slot_da(x)], 99 | 'golden': [x for x in labels if not is_slot_da(x)] 100 | }) 101 | print('[%d|%d] samples' % (len(predict_golden['overall']), len(dataloader.data[data_key]))) 102 | 103 | total = len(dataloader.data[data_key]) 104 | slot_loss /= total 105 | intent_loss /= total 106 | print('%d samples %s' % (total, data_key)) 107 | print('\t slot loss:', slot_loss) 108 | print('\t intent loss:', intent_loss) 109 | 110 | for x in ['intent', 'slot', 'overall']: 111 | precision, recall, F1 = calculateF1(predict_golden[x]) 112 | print('-' * 20 + x + '-' * 20) 113 | print('\t Precision: %.2f' % (100 * precision)) 114 | print('\t Recall: %.2f' % (100 * recall)) 115 | print('\t F1: %.2f' % (100 * F1)) 116 | 117 | output_file = os.path.join(output_dir, 'output.json') 118 | json.dump(predict_golden['overall'], open(output_file, 'w', encoding='utf-8'), indent=2, ensure_ascii=False) 119 | -------------------------------------------------------------------------------- /convlab2/nlu/nlu.py: -------------------------------------------------------------------------------- 1 | """Natural language understanding interface.""" 2 | from convlab2.util.module import Module 3 | 4 | 5 | class NLU(Module): 6 | """NLU module interface.""" 7 | 8 | def predict(self, utterance, context=list()): 9 | """Predict the dialog act of a natural language utterance. 10 | 11 | Args: 12 | utterance (string): 13 | A natural language utterance. 14 | context (list of string): 15 | Previous utterances. 
16 | 17 | Returns: 18 | action (list of list): 19 | The dialog act of the utterance. 20 | """ 21 | return [] 22 | -------------------------------------------------------------------------------- /convlab2/policy/README.md: -------------------------------------------------------------------------------- 1 | # Dialog Policy 2 | 3 | In the pipeline task-oriented dialog framework, the dialog policy module 4 | takes as input the dialog state, and chooses the system action based on 5 | it. 6 | 7 | This directory contains the interface definition of the dialog policy 8 | module for both the system side and the user simulator side, as well as some 9 | implementations under different sub-directories. 10 | 11 | ## Interface 12 | 13 | The interfaces for dialog policy are defined in policy.Policy: 14 | 15 | - **predict** takes as input the agent state (often the state tracked by the DST) 16 | and outputs the next system action. 17 | 18 | - **init_session** resets the model variables for a new dialog session. 19 | -------------------------------------------------------------------------------- /convlab2/policy/__init__.py: -------------------------------------------------------------------------------- 1 | from convlab2.policy.policy import Policy -------------------------------------------------------------------------------- /convlab2/policy/mle/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/policy/mle/__init__.py -------------------------------------------------------------------------------- /convlab2/policy/mle/crosswoz/README.md: -------------------------------------------------------------------------------- 1 | # Imitation on CrossWOZ 2 | 3 | The vanilla MLE policy casts action prediction as multi-label classification over a set of compositional actions, trained by imitation learning (behavioral cloning); a compositional action consists of a set of dialog act items. 4 | 5 | ## Train 6 | 7 | ``` 8 | python train.py 9 | ``` 10 | 11 | You can modify *config.json* to change the settings.
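In essence, training reduces to behavioral cloning with a multi-label loss over the delexicalized system dialog acts. A minimal sketch of one update step (the network shape and random tensors are illustrative; the real pieces are `MultiDiscretePolicy` in `rlmodule.py` and `MLE_Trainer_Abstract` in `policy/mle/train.py`):

```python
# Hedged sketch of the MLE (behavioral cloning) objective; sizes follow the
# comments in vector_crosswoz.py (155 system acts, 328-dim state) but are
# illustrative here, as are the random tensors standing in for real data.
import torch
import torch.nn as nn

state_dim, h_dim, act_dim = 328, 100, 155
policy_net = nn.Sequential(
    nn.Linear(state_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, act_dim))
criterion = nn.MultiLabelSoftMarginLoss()  # one logit per delexicalized act
optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-3)

state_vec = torch.rand(32, state_dim)                    # vectorized states
action_vec = torch.randint(0, 2, (32, act_dim)).float()  # multi-hot actions

loss = criterion(policy_net(state_vec), action_vec)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```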
12 | 13 | ## Data 14 | 15 | data/crosswoz/[train/val/test].json 16 | -------------------------------------------------------------------------------- /convlab2/policy/mle/crosswoz/__init__.py: -------------------------------------------------------------------------------- 1 | from convlab2.policy.mle.crosswoz.mle import MLE -------------------------------------------------------------------------------- /convlab2/policy/mle/crosswoz/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "batchsz": 32, 3 | "epoch": 20, 4 | "lr": 0.001, 5 | "save_dir": "save", 6 | "log_dir": "log", 7 | "print_per_batch": 400, 8 | "save_per_epoch": 5, 9 | "h_dim": 100, 10 | "load": "save/best" 11 | } -------------------------------------------------------------------------------- /convlab2/policy/mle/crosswoz/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import zipfile 5 | import torch 6 | import torch.utils.data as data 7 | from convlab2.util.crosswoz.state import default_state 8 | from convlab2.dst.rule.crosswoz.dst import RuleDST 9 | from convlab2.policy.vector.vector_crosswoz import CrossWozVector 10 | from copy import deepcopy 11 | 12 | 13 | class PolicyDataLoaderCrossWoz(): 14 | 15 | def __init__(self): 16 | root_dir = os.path.dirname( 17 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) 18 | voc_file = os.path.join(root_dir, 'data/crosswoz/sys_da_voc.json') 19 | voc_opp_file = os.path.join(root_dir, 'data/crosswoz/usr_da_voc.json') 20 | self.vector = CrossWozVector(sys_da_voc_json=voc_file, usr_da_voc_json=voc_opp_file) 21 | 22 | processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'processed_data') 23 | if os.path.exists(processed_dir): 24 | print('Load processed data file') 25 | self._load_data(processed_dir) 26 | else: 27 | print('Start preprocessing the dataset') 28 | self._build_data(root_dir, processed_dir) 29 | 30 | def _build_data(self, root_dir, processed_dir): 31 | raw_data = {} 32 | for part in ['train', 'val', 'test']: 33 | archive = zipfile.ZipFile(os.path.join(root_dir, 'data/crosswoz/{}.json.zip'.format(part)), 'r') 34 | with archive.open('{}.json'.format(part), 'r') as f: 35 | raw_data[part] = json.load(f) 36 | 37 | self.data = {} 38 | # for cur domain update 39 | dst = RuleDST() 40 | for part in ['train', 'val', 'test']: 41 | self.data[part] = [] 42 | 43 | for key in raw_data[part]: 44 | sess = raw_data[part][key]['messages'] 45 | dst.init_session() 46 | for i, turn in enumerate(sess): 47 | if turn['role'] == 'usr': 48 | dst.update(usr_da=turn['dialog_act']) 49 | if i + 2 == len(sess): 50 | dst.state['terminated'] = True 51 | else: 52 | for domain, svs in turn['sys_state'].items(): 53 | for slot, value in svs.items(): 54 | if slot != 'selectedResults': 55 | dst.state['belief_state'][domain][slot] = value 56 | action = turn['dialog_act'] 57 | self.data[part].append([self.vector.state_vectorize(deepcopy(dst.state)), 58 | self.vector.action_vectorize(action)]) 59 | dst.state['system_action'] = turn['dialog_act'] 60 | 61 | os.makedirs(processed_dir) 62 | for part in ['train', 'val', 'test']: 63 | with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'wb') as f: 64 | pickle.dump(self.data[part], f) 65 | 66 | def _load_data(self, processed_dir): 67 | self.data = {} 68 | for part in ['train', 'val', 'test']: 69 | with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 
'rb') as f: 70 | self.data[part] = pickle.load(f) 71 | 72 | def create_dataset(self, part, batchsz): 73 | print('Start creating {} dataset'.format(part)) 74 | s = [] 75 | a = [] 76 | for item in self.data[part]: 77 | s.append(torch.Tensor(item[0])) 78 | a.append(torch.Tensor(item[1])) 79 | s = torch.stack(s) 80 | a = torch.stack(a) 81 | dataset = Dataset(s, a) 82 | dataloader = data.DataLoader(dataset, batchsz, True) 83 | print('Finish creating {} dataset'.format(part)) 84 | return dataloader 85 | 86 | 87 | class Dataset(data.Dataset): 88 | def __init__(self, s_s, a_s): 89 | self.s_s = s_s 90 | self.a_s = a_s 91 | self.num_total = len(s_s) 92 | 93 | def __getitem__(self, index): 94 | s = self.s_s[index] 95 | a = self.a_s[index] 96 | return s, a 97 | 98 | def __len__(self): 99 | return self.num_total 100 | 101 | -------------------------------------------------------------------------------- /convlab2/policy/mle/crosswoz/mle.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import os 4 | import json 5 | import zipfile 6 | from convlab2.util.file_util import cached_path 7 | from convlab2.policy.mle.mle import MLEAbstract 8 | from convlab2.policy.rlmodule import MultiDiscretePolicy 9 | from convlab2.policy.vector.vector_crosswoz import CrossWozVector 10 | 11 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models") 14 | DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "mle_policy_crosswoz.zip") 15 | 16 | 17 | class MLE(MLEAbstract): 18 | 19 | def __init__(self, 20 | archive_file=DEFAULT_ARCHIVE_FILE, 21 | model_file='https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/mle_policy_crosswoz.zip'): 22 | root_dir = os.path.dirname( 23 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) 24 | 25 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f: 26 | cfg = json.load(f) 27 | 28 | voc_file = os.path.join(root_dir, 'data/crosswoz/sys_da_voc.json') 29 | voc_opp_file = os.path.join(root_dir, 'data/crosswoz/usr_da_voc.json') 30 | self.vector = CrossWozVector(sys_da_voc_json=voc_file, usr_da_voc_json=voc_opp_file) 31 | 32 | self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.sys_da_dim).to(device=DEVICE) 33 | 34 | if not os.path.isfile(archive_file): 35 | if not model_file: 36 | raise Exception("No model for MLE Policy is specified!") 37 | archive_file = cached_path(model_file) 38 | model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'save') 39 | if not os.path.exists(model_dir): 40 | os.mkdir(model_dir) 41 | if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')): 42 | archive = zipfile.ZipFile(archive_file, 'r') 43 | archive.extractall(model_dir) 44 | self.load(archive_file, model_file, cfg['load']) 45 | -------------------------------------------------------------------------------- /convlab2/policy/mle/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | import torch.utils.data as data 5 | from convlab2.util.multiwoz.state import default_state 6 | from convlab2.policy.vector.dataset import ActDataset 7 | from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader 8 | from convlab2.util.dataloader.module_dataloader import ActPolicyDataloader 9 | 
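# Builds (state_vector, action_vector) training pairs from annotated MultiWOZ dialogs for supervised (MLE) policy learning.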
10 | class ActMLEPolicyDataLoader(): 11 | 12 | def __init__(self): 13 | self.vector = None 14 | 15 | def _build_data(self, root_dir, processed_dir): 16 | self.data = {} 17 | data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader()) 18 | for part in ['train', 'val', 'test']: 19 | self.data[part] = [] 20 | raw_data = data_loader.load_data(data_key=part, role='system')[part] 21 | 22 | for belief_state, context_dialog_act, terminated, dialog_act in \ 23 | zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'], raw_data['dialog_act']): 24 | state = default_state() 25 | state['belief_state'] = belief_state 26 | state['user_action'] = context_dialog_act[-1] 27 | state['system_action'] = context_dialog_act[-2] if len(context_dialog_act) > 1 else {} 28 | state['terminated'] = terminated 29 | action = dialog_act 30 | self.data[part].append([self.vector.state_vectorize(state), 31 | self.vector.action_vectorize(action)]) 32 | 33 | os.makedirs(processed_dir) 34 | for part in ['train', 'val', 'test']: 35 | with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'wb') as f: 36 | pickle.dump(self.data[part], f) 37 | 38 | def _load_data(self, processed_dir): 39 | self.data = {} 40 | for part in ['train', 'val', 'test']: 41 | with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'rb') as f: 42 | self.data[part] = pickle.load(f) 43 | 44 | def create_dataset(self, part, batchsz): 45 | print('Start creating {} dataset'.format(part)) 46 | s = [] 47 | a = [] 48 | for item in self.data[part]: 49 | s.append(torch.Tensor(item[0])) 50 | a.append(torch.Tensor(item[1])) 51 | s = torch.stack(s) 52 | a = torch.stack(a) 53 | dataset = ActDataset(s, a) 54 | dataloader = data.DataLoader(dataset, batchsz, True) 55 | print('Finish creating {} dataset'.format(part)) 56 | return dataloader 57 | -------------------------------------------------------------------------------- /convlab2/policy/mle/mle.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import os 4 | import zipfile 5 | from convlab2.policy.policy import Policy 6 | from convlab2.util.file_util import cached_path 7 | 8 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class MLEAbstract(Policy): 12 | 13 | def __init__(self, archive_file, model_file): 14 | self.vector = None 15 | self.policy = None 16 | 17 | def predict(self, state): 18 | """ 19 | Predict a system action given the dialog state. 20 | Args: 21 | state (dict): Dialog state.
Please refer to util/state.py 22 | Returns: 23 | action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...}) 24 | """ 25 | s_vec = torch.Tensor(self.vector.state_vectorize(state)) 26 | a = self.policy.select_action(s_vec.to(device=DEVICE), False).cpu() 27 | action = self.vector.action_devectorize(a.numpy()) 28 | state['system_action'] = action 29 | return action 30 | 31 | def init_session(self): 32 | """ 33 | Reset the model variables for a new dialog session 34 | """ 35 | pass 36 | 37 | def load(self, archive_file, model_file, filename): 38 | if not os.path.isfile(archive_file): 39 | if not model_file: 40 | raise Exception("No model for MLE Policy is specified!") 41 | archive_file = cached_path(model_file) 42 | model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'save') 43 | if not os.path.exists(model_dir): 44 | os.mkdir(model_dir) 45 | if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')): 46 | archive = zipfile.ZipFile(archive_file, 'r') 47 | archive.extractall(model_dir) 48 | 49 | policy_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_mle.pol.mdl') 50 | if os.path.exists(policy_mdl): 51 | self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE)) 52 | print('<> loaded checkpoint from file: {}'.format(policy_mdl)) 53 | -------------------------------------------------------------------------------- /convlab2/policy/mle/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import torch.nn as nn 5 | 6 | from convlab2.util.train_util import to_device 7 | 8 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | class MLE_Trainer_Abstract(): 11 | def __init__(self, manager, cfg): 12 | self._init_data(manager, cfg) 13 | self.policy = None 14 | self.policy_optim = None 15 | 16 | def _init_data(self, manager, cfg): 17 | self.data_train = manager.create_dataset('train', cfg['batchsz']) 18 | self.data_valid = manager.create_dataset('val', cfg['batchsz']) 19 | self.data_test = manager.create_dataset('test', cfg['batchsz']) 20 | self.save_dir = cfg['save_dir'] 21 | self.print_per_batch = cfg['print_per_batch'] 22 | self.save_per_epoch = cfg['save_per_epoch'] 23 | self.multi_entropy_loss = nn.MultiLabelSoftMarginLoss() 24 | 25 | def policy_loop(self, data): 26 | s, target_a = to_device(data) 27 | a_weights = self.policy(s) 28 | 29 | loss_a = self.multi_entropy_loss(a_weights, target_a) 30 | return loss_a 31 | 32 | def imitating(self, epoch): 33 | """ 34 | pretrain the policy by simple imitation learning (behavioral cloning) 35 | """ 36 | self.policy.train() 37 | a_loss = 0. 38 | for i, data in enumerate(self.data_train): 39 | self.policy_optim.zero_grad() 40 | loss_a = self.policy_loop(data) 41 | a_loss += loss_a.item() 42 | loss_a.backward() 43 | self.policy_optim.step() 44 | 45 | if (i+1) % self.print_per_batch == 0: 46 | a_loss /= self.print_per_batch 47 | logging.debug('<> epoch {}, iter {}, loss_a:{}'.format(epoch, i, a_loss)) 48 | a_loss = 0. 49 | 50 | if (epoch+1) % self.save_per_epoch == 0: 51 | self.save(self.save_dir, epoch) 52 | self.policy.eval() 53 | 54 | def imit_test(self, epoch, best): 55 | """ 56 | provide an unbiased evaluation of the policy fitted on the training set, using the validation set; the test loss is also reported 57 | """ 58 | a_loss = 0.
59 | for i, data in enumerate(self.data_valid): 60 | loss_a = self.policy_loop(data) 61 | a_loss += loss_a.item() 62 | 63 | a_loss /= len(self.data_valid) 64 | logging.debug('<> validation, epoch {}, loss_a:{}'.format(epoch, a_loss)) 65 | if a_loss < best: 66 | logging.info('<> best model saved') 67 | best = a_loss 68 | self.save(self.save_dir, 'best') 69 | 70 | a_loss = 0. 71 | for i, data in enumerate(self.data_test): 72 | loss_a = self.policy_loop(data) 73 | a_loss += loss_a.item() 74 | 75 | a_loss /= len(self.data_test) 76 | logging.debug('<> test, epoch {}, loss_a:{}'.format(epoch, a_loss)) 77 | return best 78 | 79 | def test(self): 80 | def f1(a, target): 81 | TP, FP, FN = 0, 0, 0 82 | real = target.nonzero().tolist() 83 | predict = a.nonzero().tolist() 84 | for item in real: 85 | if item in predict: 86 | TP += 1 87 | else: 88 | FN += 1 89 | for item in predict: 90 | if item not in real: 91 | FP += 1 92 | return TP, FP, FN 93 | 94 | a_TP, a_FP, a_FN = 0, 0, 0 95 | for i, data in enumerate(self.data_test): 96 | s, target_a = to_device(data) 97 | a_weights = self.policy(s) 98 | a = a_weights.ge(0) 99 | TP, FP, FN = f1(a, target_a) 100 | a_TP += TP 101 | a_FP += FP 102 | a_FN += FN 103 | 104 | prec = a_TP / (a_TP + a_FP) 105 | rec = a_TP / (a_TP + a_FN) 106 | F1 = 2 * prec * rec / (prec + rec) 107 | print(a_TP, a_FP, a_FN, F1) 108 | 109 | def save(self, directory, epoch): 110 | if not os.path.exists(directory): 111 | os.makedirs(directory) 112 | 113 | torch.save(self.policy.state_dict(), directory + '/' + str(epoch) + '_mle.pol.mdl') 114 | 115 | logging.info('<> epoch {}: saved network to mdl'.format(epoch)) 116 | 117 | -------------------------------------------------------------------------------- /convlab2/policy/policy.py: -------------------------------------------------------------------------------- 1 | """Policy Interface""" 2 | from convlab2.util.module import Module 3 | 4 | 5 | class Policy(Module): 6 | """Base class for policy model.""" 7 | 8 | def predict(self, state): 9 | """Predict the next agent action given the dialog state, 10 | and update state['system_action'] with the predicted system action. 11 | 12 | Args: 13 | state (tuple or dict): 14 | when the DST and policy modules are separated, state is a tuple; 15 | when they are aggregated together, state is a dict (dialog act). 16 | Returns: 17 | action (list of list): 18 | The next dialog action.
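For CrossWOZ, each inner list is a dialog act quadruple [intent, domain, slot, value].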
19 | """ 20 | return [] 21 | -------------------------------------------------------------------------------- /convlab2/policy/rule/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/policy/rule/__init__.py -------------------------------------------------------------------------------- /convlab2/policy/rule/crosswoz/__init__.py: -------------------------------------------------------------------------------- 1 | from convlab2.policy.rule.crosswoz.rule_simulator import Simulator 2 | -------------------------------------------------------------------------------- /convlab2/policy/vec.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Vector Interface""" 3 | 4 | 5 | class Vector(): 6 | 7 | def __init__(self): 8 | pass 9 | 10 | def generate_dict(self): 11 | """init the dict for mapping state/action into vector""" 12 | 13 | def state_vectorize(self, state): 14 | """vectorize a state 15 | 16 | Args: 17 | state (tuple): 18 | Dialog state 19 | Returns: 20 | state_vec (np.array): 21 | Dialog state vector 22 | """ 23 | raise NotImplementedError 24 | 25 | def action_devectorize(self, action_vec): 26 | """recover an action 27 | 28 | Args: 29 | action_vec (np.array): 30 | Dialog act vector 31 | Returns: 32 | action (tuple): 33 | Dialog act 34 | """ 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /convlab2/policy/vector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/policy/vector/__init__.py -------------------------------------------------------------------------------- /convlab2/policy/vector/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | class ActDataset(data.Dataset): 4 | def __init__(self, s_s, a_s): 5 | self.s_s = s_s 6 | self.a_s = a_s 7 | self.num_total = len(s_s) 8 | 9 | def __getitem__(self, index): 10 | s = self.s_s[index] 11 | a = self.a_s[index] 12 | return s, a 13 | 14 | def __len__(self): 15 | return self.num_total 16 | 17 | class ActStateDataset(data.Dataset): 18 | def __init__(self, s_s, a_s, next_s): 19 | self.s_s = s_s 20 | self.a_s = a_s 21 | self.next_s = next_s 22 | self.num_total = len(s_s) 23 | 24 | def __getitem__(self, index): 25 | s = self.s_s[index] 26 | a = self.a_s[index] 27 | next_s = self.next_s[index] 28 | return s, a, next_s 29 | 30 | def __len__(self): 31 | return self.num_total -------------------------------------------------------------------------------- /convlab2/policy/vector/vector_crosswoz.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from convlab2.policy.vec import Vector 5 | from convlab2.util.crosswoz.state import default_state 6 | from convlab2.util.crosswoz.lexicalize import delexicalize_da, lexicalize_da 7 | from convlab2.util.crosswoz.dbquery import Database 8 | 9 | 10 | class CrossWozVector(Vector): 11 | def __init__(self, sys_da_voc_json, usr_da_voc_json): 12 | self.sys_da_voc = json.load(open(sys_da_voc_json)) 13 | self.usr_da_voc = json.load(open(usr_da_voc_json)) 14 | self.database = Database() 15 | 16 | self.generate_dict() 17 | 18 | def 
generate_dict(self): 19 | self.sys_da2id = dict((a, i) for i, a in enumerate(self.sys_da_voc)) 20 | self.id2sys_da = dict((i, a) for i, a in enumerate(self.sys_da_voc)) 21 | 22 | # 155 23 | self.sys_da_dim = len(self.sys_da_voc) 24 | 25 | 26 | self.usr_da2id = dict((a, i) for i, a in enumerate(self.usr_da_voc)) 27 | self.id2usr_da = dict((i, a) for i, a in enumerate(self.usr_da_voc)) 28 | 29 | # 142 30 | self.usr_da_dim = len(self.usr_da_voc) 31 | 32 | # 26 33 | self.belief_state_dim = 0 34 | for domain, svs in default_state()['belief_state'].items(): 35 | self.belief_state_dim += len(svs) 36 | 37 | self.db_res_dim = 4 38 | 39 | self.state_dim = self.sys_da_dim + self.usr_da_dim + self.belief_state_dim + self.db_res_dim + 1 # terminated 40 | 41 | def state_vectorize(self, state): 42 | self.belief_state = state['belief_state'] 43 | self.cur_domain = state['cur_domain'] 44 | 45 | da = state['user_action'] 46 | da = delexicalize_da(da) 47 | usr_act_vec = np.zeros(self.usr_da_dim) 48 | for a in da: 49 | if a in self.usr_da2id: 50 | usr_act_vec[self.usr_da2id[a]] = 1. 51 | 52 | da = state['system_action'] 53 | da = delexicalize_da(da) 54 | sys_act_vec = np.zeros(self.sys_da_dim) 55 | for a in da: 56 | if a in self.sys_da2id: 57 | sys_act_vec[self.sys_da2id[a]] = 1. 58 | 59 | belief_state_vec = np.zeros(self.belief_state_dim) 60 | i = 0 61 | for domain, svs in state['belief_state'].items(): 62 | for slot, value in svs.items(): 63 | if value: 64 | belief_state_vec[i] = 1. 65 | i += 1 66 | 67 | self.db_res = self.database.query(state['belief_state'], state['cur_domain']) 68 | db_res_num = len(self.db_res) 69 | db_res_vec = np.zeros(4) 70 | if db_res_num == 0: 71 | db_res_vec[0] = 1. 72 | elif db_res_num == 1: 73 | db_res_vec[1] = 1. 74 | elif 1 < db_res_num < 5: 75 | db_res_vec[2] = 1. 76 | else: 77 | db_res_vec[3] = 1. 78 | 79 | terminated = 1. if state['terminated'] else 0. 80 | 81 | # print('state dim', self.state_dim) 82 | state_vec = np.r_[usr_act_vec, sys_act_vec, belief_state_vec, db_res_vec, terminated] 83 | # print('actual state vec dim', len(state_vec)) 84 | return state_vec 85 | 86 | def action_devectorize(self, action_vec): 87 | """ 88 | must call state_vectorize func before 89 | :param action_vec: 90 | :return: 91 | """ 92 | da = [] 93 | for i, idx in enumerate(action_vec): 94 | if idx == 1: 95 | da.append(self.id2sys_da[i]) 96 | lexicalized_da = lexicalize_da(da=da, cur_domain=self.cur_domain, entities=self.db_res) 97 | return lexicalized_da 98 | 99 | def action_vectorize(self, da): 100 | da = delexicalize_da(da) 101 | sys_act_vec = np.zeros(self.sys_da_dim) 102 | for a in da: 103 | if a in self.sys_da2id: 104 | sys_act_vec[self.sys_da2id[a]] = 1. 
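        # The result is a multi-hot vector over the delexicalized system
        # dialog-act vocabulary (sys_da_dim, 155 acts for CrossWOZ per the
        # comment in generate_dict): entry i is 1. iff the i-th act of
        # sys_da_voc occurs in `da`. This is the inverse of action_devectorize
        # above, which maps the 1-entries back through id2sys_da and then
        # re-lexicalizes them with the domain and db results cached by
        # state_vectorize.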
105 | return sys_act_vec 106 | 107 | 108 | if __name__ == '__main__': 109 | vec = CrossWozVector('../../../data/crosswoz/sys_da_voc.json','../../../data/crosswoz/usr_da_voc.json') 110 | print(vec.sys_da_dim, vec.usr_da_dim, vec.belief_state_dim, vec.db_res_dim, vec.state_dim) 111 | -------------------------------------------------------------------------------- /convlab2/task/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/task/__init__.py -------------------------------------------------------------------------------- /convlab2/task/crosswoz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/task/crosswoz/__init__.py -------------------------------------------------------------------------------- /convlab2/task/crosswoz/attraction_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import random 4 | from collections import Counter 5 | from copy import deepcopy 6 | 7 | import numpy as np 8 | 9 | 10 | class AttractionGenerator: 11 | def __init__(self, database): 12 | self.database = database.values() 13 | self.constraints2prob = { 14 | "名称": 0.1, 15 | "门票": 0.5, 16 | "游玩时间": 0.5, 17 | "评分": 0.5 18 | } 19 | self.constraints2weight = { 20 | "名称": dict.fromkeys([x['名称'] for x in self.database],1), 21 | "门票": {'免费': 10, '不免费': 1, '20元以下': 2, '20-50元': 4, '50-100元': 3, '100-150元': 3, '150-200元': 3, '200元以上': 2}, 22 | "游玩时间": dict(Counter([x['游玩时间'] for x in self.database])), 23 | "评分": {'4分以上': 0.2, '4.5分以上': 0.6, '5分': 0.2} 24 | } 25 | self.min_constraints = 1 26 | self.max_constraints = 3 27 | self.min_require = 1 28 | self.max_require = 3 29 | self.all_attrs = ['名称', '地址', '电话', '门票', '游玩时间', '评分', 30 | '周边景点', '周边餐馆', '周边酒店', 31 | # '官网', '介绍', '开放时间' 32 | ] 33 | 34 | def generate(self, goal_num=0, exist_goal=None, random_seed=None): 35 | name_flag = False 36 | if random_seed: 37 | random.seed(random_seed) 38 | np.random.seed(random_seed) 39 | goal = { 40 | "领域": "景点", 41 | "id": goal_num, 42 | "约束条件": [], 43 | "需求信息": [], 44 | "生成方式": "" 45 | } 46 | # generate method 47 | if exist_goal: 48 | goal['生成方式'] = 'id={}的周边{}'.format(exist_goal["id"], "景点") 49 | goal['约束条件'].append(['名称', '出现在id={}的周边{}里'.format(exist_goal["id"], "景点")]) 50 | name_flag = True 51 | else: 52 | goal['生成方式'] = '单领域生成' 53 | 54 | # generate constraints 55 | random_req = deepcopy(self.all_attrs) 56 | random_req.remove('名称') 57 | # if constraint == name ? 
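        # Constraint sampling: with probability constraints2prob['名称'] the
        # goal names a concrete attraction (drawn by the per-name weights);
        # otherwise 1-3 of the remaining constraint slots are filled, each
        # value drawn by random.choices() using constraints2weight, and every
        # slot used as a constraint is removed from the requestable pool below.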
58 | if not exist_goal and random.random() < self.constraints2prob['名称']: 59 | v = self.constraints2weight['名称'] 60 | goal['约束条件'] = [['名称', random.choices(list(v.keys()),list(v.values()))[0]]] 61 | name_flag = True 62 | else: 63 | rest_constraints = list(self.constraints2prob.keys()) 64 | rest_constraints.remove('名称') 65 | random.shuffle(rest_constraints) 66 | # cons_num = random.randint(self.min_constraints, self.max_constraints) 67 | cons_num = random.choices([1, 2, 3], [20, 60, 20])[0] 68 | for k in rest_constraints: 69 | if cons_num > 0: 70 | v = self.constraints2weight[k] 71 | value = random.choices(list(v.keys()), list(v.values()))[0] 72 | goal['约束条件'].append([k, value]) 73 | random_req.remove(k) 74 | cons_num -= 1 75 | else: 76 | break 77 | 78 | # generate required information 79 | if not name_flag: 80 | goal['需求信息'].append(['名称', ""]) 81 | 82 | random.shuffle(random_req) 83 | # req_num = random.randint(self.min_require, self.max_require) 84 | req_num = random.choices([1, 2], [30, 70])[0] 85 | for k in random_req: 86 | if req_num > 0: 87 | goal['需求信息'].append([k,""]) 88 | req_num -= 1 89 | if k == '名称': 90 | name_flag = True 91 | else: 92 | break 93 | 94 | return goal 95 | -------------------------------------------------------------------------------- /convlab2/task/crosswoz/metro_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | 4 | import numpy as np 5 | 6 | 7 | class MetroGenerator: 8 | def generate(self, goal_list, goal_num, random_seed=None): 9 | if random_seed: 10 | random.seed(random_seed) 11 | np.random.seed(random_seed) 12 | goal = { 13 | "领域": "地铁", 14 | "id": goal_num, 15 | "约束条件": [], 16 | "需求信息": [], 17 | "生成方式": "" 18 | } 19 | goal1, goal2 = random.sample(goal_list, k=2) 20 | goal["约束条件"].append(["出发地", "id=%d" % goal1["id"]]) 21 | goal["约束条件"].append(["目的地", "id=%d" % goal2["id"]]) 22 | goal["需求信息"].append(["出发地附近地铁站", ""]) 23 | goal["需求信息"].append(["目的地附近地铁站", ""]) 24 | 25 | if goal1["领域"] == goal2["领域"]: 26 | goal["生成方式"] = "同领域" 27 | else: 28 | goal["生成方式"] = "不同领域" 29 | 30 | return goal 31 | -------------------------------------------------------------------------------- /convlab2/task/crosswoz/reorder.py: -------------------------------------------------------------------------------- 1 | """ 2 | reorder generated goals 3 | """ 4 | from copy import deepcopy 5 | import pprint 6 | import json 7 | import re 8 | import random 9 | 10 | 11 | def goals_reorder(goal_list): 12 | # pprint.pprint(goal_list) 13 | id_old2new = {} 14 | single_ids = [] 15 | cross_ids = {} 16 | move_ids = [] 17 | for goal in goal_list: 18 | if goal['生成方式'] == '单领域生成': 19 | single_ids.append(goal['id']) 20 | # print('单领域生成:',goal['id']) 21 | elif '周边' in goal['生成方式']: 22 | searchObj = re.search(r'id=(\d)', goal['生成方式']) 23 | src = int(searchObj.group(1)) 24 | cross_ids[goal['id']] = src 25 | # print('跨领域生成:', goal['id']) 26 | elif goal['领域'] == '地铁' or goal['领域'] == '出租': 27 | start_id, end_id = 0, 0 28 | for slot, value in goal['约束条件']: 29 | if slot == '出发地': 30 | start_id = int(value[-1]) 31 | elif slot == '目的地': 32 | end_id = int(value[-1]) 33 | else: 34 | assert 0 35 | move_ids.append((goal['id'],start_id,end_id)) 36 | # print('地铁/出租', goal['id']) 37 | else: 38 | assert 0 39 | 40 | # pprint.pprint(goal_list) 41 | 42 | order = [] 43 | for x in single_ids: 44 | order.append(x) 45 | for k, s, e in move_ids[:]: 46 | if s in order and e in order: 47 | order.append(k) 48 | move_ids.remove((k, s, e)) 49 
|         for tar, src in list(cross_ids.items())[:]:
50 |             if src==x:
51 |                 order.append(tar)
52 |                 cross_ids.pop(tar)
53 |                 for k, s, e in move_ids[:]:
54 |                     if s in order and e in order:
55 |                         order.append(k)
56 |                         move_ids.remove((k, s, e))
57 |     # print(order)
58 |     assert len(order) == len(goal_list)
59 |     id_old2new = dict([(j, i+1) for i, j in enumerate(order)])
60 |     # print(id_old2new)
61 |     for goal in goal_list:
62 |         goal['id'] = id_old2new[goal['id']]
63 |         if '周边' in goal['生成方式']:
64 |             searchObj = re.search(r'id=(\d)', goal['生成方式'])
65 |             src = int(searchObj.group(1))
66 |             goal['生成方式'] = re.sub(r'\d', str(id_old2new[src]), goal['生成方式'])
67 |             for i in range(len(goal['约束条件'])):
68 |                 if goal['约束条件'][i][0] == '名称':
69 |                     assert 'id' in goal['约束条件'][i][1]
70 |                     goal['约束条件'][i][1] = re.sub(r'\d', str(id_old2new[src]), goal['约束条件'][i][1])
71 | 
72 |         elif goal['领域'] == '地铁' or goal['领域'] == '出租':
73 |             start_id = id_old2new[int(goal['约束条件'][0][1][-1])]
74 |             end_id = id_old2new[int(goal['约束条件'][1][1][-1])]
75 |             start_id, end_id = min(start_id, end_id), max(start_id, end_id)
76 |             goal['约束条件'][0][1] = goal['约束条件'][0][1][:-1] + str(start_id)
77 |             goal['约束条件'][1][1] = goal['约束条件'][1][1][:-1] + str(end_id)  # fixed: was rebuilt from 约束条件[0][1], which breaks for multi-digit ids
78 |     goal_list = sorted(goal_list, key=lambda x: x['id'])
79 |     return goal_list
80 | 
81 | 
82 | if __name__ == '__main__':
83 |     goals = json.load(open('result/goal_4.json', encoding='utf-8'))
84 |     for goal in goals:
85 |         if goal['timestamp'] == "2019-05-20 11:23:35.702161":
86 |             # print(goal)
87 |             goals_reorder(goal['goals'])
--------------------------------------------------------------------------------
/convlab2/task/crosswoz/restaurant_generator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | import random
4 | from copy import deepcopy
5 | from collections import Counter
6 | 
7 | import numpy as np
8 | 
9 | 
10 | class RestaurantGenerator:
11 |     def __init__(self, database):
12 |         self.database = database.values()
13 |         self.constraints2prob = {
14 |             "名称": 0.1,
15 |             "推荐菜": 0.6,
16 |             "人均消费": 0.5,
17 |             "评分": 0.5
18 |         }
19 |         self.constraints2weight = {
20 |             "名称": dict.fromkeys([x['名称'] for x in self.database], 1),
21 |             "推荐菜": {1: 5, 2: 1},
22 |             "人均消费": {"50元以下": 1,
23 |                      "50-100元": 15,
24 |                      "100-150元": 15,
25 |                      "150-500元": 5,
26 |                      "500-1000元": 2,
27 |                      "1000元以上": 1
28 |                      },
29 |             "评分": {'4分以上': 0.2, '4.5分以上': 0.6, '5分': 0.2}
30 |         }
31 |         self.min_constraints = 1
32 |         self.max_constraints = 3
33 |         self.min_require = 1
34 |         self.max_require = 3
35 |         self.order_prob = 0.1
36 |         self.twodish_prob = 0.15
37 |         self.all_attrs = ['名称', '地址', '电话', '营业时间', '推荐菜', '人均消费', '评分',
38 |                           '周边景点', '周边餐馆', '周边酒店',
39 |                           # '交通', '介绍',
40 |                           ]
41 |         self.cooccur = {}  # dish -> set of dishes that co-occur with it in some restaurant
42 |         for res in self.database:
43 |             for dish in res['推荐菜']:
44 |                 self.cooccur[dish] = self.cooccur.get(dish, set()).union(set(res['推荐菜']))
45 |                 self.cooccur[dish].remove(dish)
46 |         all_dish = [dish for res in self.database for dish in res['推荐菜']]
47 |         all_dish = Counter(all_dish)
48 |         for k, v in all_dish.items():
49 |             if v == 1:
50 |                 del self.cooccur[k]
51 |         self.time2weight = {}
52 |         for hour in range(0, 23):
53 |             for minute in [':00', ':30']:
54 |                 timePoint = str(hour) + minute
55 |                 if hour in [11, 12, 17, 18]:  # meal times
56 |                     self.time2weight[timePoint] = 20
57 |                 elif hour in list(range(0, 7)):  # late night / early morning
58 |                     self.time2weight[timePoint] = 1
59 |                 else:  # daytime, outside meal times
60 |                     self.time2weight[timePoint] = 5
61 | 
62 |     def generate(self, goal_num=0, exist_goal=None, random_seed=None):
63 |         name_flag =
False 64 | if random_seed: 65 | random.seed(random_seed) 66 | np.random.seed(random_seed) 67 | goal = { 68 | "领域": "餐馆", 69 | "id": goal_num, 70 | "约束条件": [], 71 | "需求信息": [], 72 | '预订信息': [], 73 | "生成方式": "" 74 | } 75 | # generate method 76 | if exist_goal: 77 | goal['生成方式'] = 'id={}的周边{}'.format(exist_goal["id"], "餐馆") 78 | goal['约束条件'].append(['名称', '出现在id={}的周边{}里'.format(exist_goal["id"], "餐馆")]) 79 | name_flag = True 80 | else: 81 | goal['生成方式'] = '单领域生成' 82 | # generate constraints 83 | random_req = deepcopy(self.all_attrs) 84 | random_req.remove('名称') 85 | # if constraint == name ? 86 | if not exist_goal and random.random() < self.constraints2prob['名称']: 87 | v = self.constraints2weight['名称'] 88 | goal['约束条件'] = [['名称', random.choices(list(v.keys()), list(v.values()))[0]]] 89 | name_flag = True 90 | 91 | else: 92 | rest_constraints = list(self.constraints2prob.keys()) 93 | rest_constraints.remove('名称') 94 | random.shuffle(rest_constraints) 95 | # cons_num = random.randint(self.min_constraints, self.max_constraints) 96 | cons_num = random.choices([1, 2, 3], [20, 60, 20])[0] 97 | for k in rest_constraints: 98 | if cons_num > 0: 99 | v = self.constraints2weight[k] 100 | if k == '推荐菜': 101 | value = random.choices(list(self.cooccur.keys())) 102 | if random.random() < self.twodish_prob and self.cooccur[value[0]]: 103 | value.append(random.choice(list(self.cooccur[value[0]]))) 104 | else: 105 | value = random.choices(list(v.keys()), list(v.values()))[0] 106 | goal['约束条件'].append([k, value]) 107 | random_req.remove(k) 108 | cons_num -= 1 109 | else: 110 | break 111 | 112 | # generate required information 113 | if not name_flag: 114 | goal['需求信息'].append(['名称', ""]) 115 | 116 | random.shuffle(random_req) 117 | req_num = random.choices([1, 2], [30, 70])[0] 118 | for k in random_req: 119 | if req_num > 0: 120 | goal['需求信息'].append([k, ""]) 121 | req_num -= 1 122 | if k == '名称': 123 | name_flag = True 124 | else: 125 | break 126 | 127 | 128 | 129 | # if random.random() < self.order_prob: 130 | # people_num = random.randint(1, 9) 131 | # week_day = random.choice(['周日', '周一', '周二', '周三', '周四', '周五', '周六', ]) 132 | # book_time = random.choices(list(self.time2weight.keys()), list(self.time2weight.values()))[0] 133 | # goal['预订信息'] = [["人数", people_num], ["日期", week_day], ["时间", book_time]] 134 | # goal['需求信息'].append(["预订订单号", ""]) 135 | 136 | return goal 137 | -------------------------------------------------------------------------------- /convlab2/task/crosswoz/sentence_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | 4 | import numpy as np 5 | 6 | 7 | class SentenceGenerator: 8 | def generate(self, goals, random_seed=None): 9 | sens = [] 10 | if random_seed: 11 | random.seed(random_seed) 12 | np.random.seed(random_seed) 13 | # print(goals) 14 | # for ls in goals: 15 | for goal in goals: 16 | # print(goal) 17 | sen = '' 18 | # if "周边" in goal["生成方式"]: 19 | # sen += goal["生成方式"] + "。" + "通过它的周边推荐," 20 | domain = goal["领域"] 21 | if domain == "酒店": 22 | for constraint in goal["约束条件"]: 23 | if constraint[0] == "名称": 24 | if '周边' in constraint[1]: 25 | origin_id = int(constraint[1].split('id=')[1][0]) 26 | sen += ('你要去id=%d附近的酒店(id=%d)住宿。' % (origin_id, goal['id'])) 27 | else: 28 | sen += ('你要去名叫%s的酒店(id=%d)住宿。' % (constraint[1], goal['id'])) 29 | if sen == '': 30 | sen += "你要去一个酒店(id=%d)住宿。" % goal['id'] 31 | 32 | for constraint in goal["约束条件"]: 33 | if constraint[0] == "酒店类型": 34 | sen += ('你希望酒店是%s的。' 
% constraint[1]) 35 | elif "酒店设施" in constraint[0]: 36 | sen += ('你希望酒店提供%s。' % constraint[0].split('-')[1]) 37 | elif constraint[0] == "价格": 38 | sen += ('你希望酒店的最低价格是%s的。' % constraint[1]) 39 | elif constraint[0] == "评分": 40 | sen += ('你希望酒店的评分是%s。' % constraint[1]) 41 | elif constraint[0] == "预订信息": 42 | sen += "" 43 | # if goal["预订信息"]: 44 | # sen += "你希望预订在%s入住,共%s人,住%s天。" % (goal["预订信息"][1][1], goal["预订信息"][0][1], goal["预订信息"][2][1]) 45 | elif domain == "景点": 46 | 47 | for constraint in goal["约束条件"]: 48 | if constraint[0] == "名称": 49 | if '周边' in constraint[1]: 50 | origin_id = int(constraint[1].split('id=')[1][0]) 51 | sen += ('你要去id=%d附近的景点(id=%d)游玩。' % (origin_id, goal['id'])) 52 | else: 53 | sen += ('你要去名叫%s的景点(id=%d)游玩。' % (constraint[1], goal['id'])) 54 | if sen == '': 55 | sen += "你要去一个景点(id=%d)游玩。" % goal['id'] 56 | 57 | for constraint in goal["约束条件"]: 58 | if constraint[0] == "门票": 59 | sen += ('你希望景点的票价是%s的。' % constraint[1]) 60 | elif constraint[0] == "游玩时间": 61 | sen += ('你希望游玩的时长是%s。' % constraint[1]) 62 | elif constraint[0] == "评分": 63 | sen += ('你希望景点的评分是%s。' % constraint[1]) 64 | elif domain == "餐馆": 65 | for constraint in goal["约束条件"]: 66 | if constraint[0] == "名称": 67 | if '周边' in constraint[1]: 68 | origin_id = int(constraint[1].split('id=')[1][0]) 69 | sen += ('你要去id=%d附近的餐馆(id=%d)用餐。' % (origin_id, goal['id'])) 70 | else: 71 | sen += ('你要去名叫%s的餐馆(id=%d)用餐。' % (constraint[1], goal['id'])) 72 | if sen == '': 73 | sen += "你要去一个餐馆(id=%d)用餐。" % goal['id'] 74 | 75 | for constraint in goal["约束条件"]: 76 | if constraint[0] == "推荐菜": 77 | sen += ('你想吃的菜肴是%s。' % '、'.join(constraint[1])) 78 | elif constraint[0] == "人均消费": 79 | sen += ('你希望餐馆的人均消费是%s的。' % constraint[1]) 80 | elif constraint[0] == "评分": 81 | sen += ('你希望餐馆的评分是%s。' % constraint[1]) 82 | # if goal["预订信息"]: 83 | # sen += "你希望预订在%s%s共%s人一起用餐。" % (goal["预订信息"][1][1], goal["预订信息"][2][1], goal["预订信息"][0][1]) 84 | elif domain == "出租": 85 | sen += '你想呼叫从%s到%s的出租车。' % (goal["约束条件"][0][1], goal["约束条件"][1][1]) 86 | elif domain == "地铁": 87 | sen += '你想乘坐从%s到%s的地铁。' % (goal["约束条件"][0][1], goal["约束条件"][1][1]) 88 | sen += '你想知道这个%s的%s。' % (domain, '、'.join(["酒店设施是否包含%s" % item[0].split('-')[1] 89 | if "酒店设施" in item[0] 90 | else item[0] 91 | for item in goal['需求信息']])) 92 | sens.append(sen) 93 | return sens 94 | -------------------------------------------------------------------------------- /convlab2/task/crosswoz/taxi_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | 4 | import numpy as np 5 | 6 | 7 | class TaxiGenerator: 8 | def generate(self, goal_list, goal_num, random_seed=None): 9 | if random_seed: 10 | random.seed(random_seed) 11 | np.random.seed(random_seed) 12 | goal = { 13 | "领域": "出租", 14 | "id": goal_num, 15 | "约束条件": [], 16 | "需求信息": [], 17 | "生成方式": "" 18 | } 19 | goal1, goal2 = random.sample(goal_list, k=2) 20 | goal["约束条件"].append(["出发地", "id=%d" % goal1["id"]]) 21 | goal["约束条件"].append(["目的地", "id=%d" % goal2["id"]]) 22 | goal["需求信息"].append(["车型", ""]) 23 | goal["需求信息"].append(["车牌", ""]) 24 | # goal["需求信息"].append(["距离", ""]) 25 | # goal["需求信息"].append(["电话", ""]) 26 | 27 | if goal1["领域"] == goal2["领域"]: 28 | goal["生成方式"] = "同领域" 29 | else: 30 | goal["生成方式"] = "不同领域" 31 | 32 | return goal 33 | -------------------------------------------------------------------------------- /convlab2/util/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/util/__init__.py -------------------------------------------------------------------------------- /convlab2/util/crosswoz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/convlab2/util/crosswoz/__init__.py -------------------------------------------------------------------------------- /convlab2/util/crosswoz/lexicalize.py: -------------------------------------------------------------------------------- 1 | def delexicalize_da(da): 2 | delexicalized_da = [] 3 | counter = {} 4 | for intent, domain, slot, value in da: 5 | if intent in ['Inform', 'Recommend']: 6 | key = '+'.join([intent, domain, slot]) 7 | counter.setdefault(key,0) 8 | counter[key] += 1 9 | delexicalized_da.append(key+'+'+str(counter[key])) 10 | else: 11 | delexicalized_da.append('+'.join([intent, domain, slot, value])) 12 | 13 | return delexicalized_da 14 | 15 | 16 | def lexicalize_da(da, cur_domain, entities): 17 | not_dish = {'当地口味', '老字号', '其他', '美食林风味', '特色小吃', '美食林臻选', '深夜营业', '名人光顾', '四合院'} 18 | lexicalized_da = [] 19 | for a in da: 20 | intent, domain, slot, value = a.split('+') 21 | if intent in ['General', 'NoOffer']: 22 | lexicalized_da.append([intent, domain, slot, value]) 23 | elif domain==cur_domain: 24 | value = int(value)-1 25 | if domain == '出租': 26 | assert intent=='Inform' 27 | assert slot in ['车型', '车牌'] 28 | assert value == 0 29 | value = entities[0][1][slot] 30 | lexicalized_da.append([intent, domain, slot, value]) 31 | elif domain == '地铁': 32 | assert intent=='Inform' 33 | assert slot in ['出发地附近地铁站', '目的地附近地铁站'] 34 | assert value == 0 35 | if slot == '出发地附近地铁站': 36 | candidates = [v for n, v in entities if '起点' in n] 37 | if candidates: 38 | value = candidates[0]['地铁'] 39 | else: 40 | value = '无' 41 | else: 42 | candidates = [v for n, v in entities if '终点' in n] 43 | if candidates: 44 | value = candidates[0]['地铁'] 45 | else: 46 | value = '无' 47 | lexicalized_da.append([intent, domain, slot, value]) 48 | else: 49 | if intent=='Recommend': 50 | assert slot=='名称' 51 | if len(entities)>value: 52 | value = entities[value][0] 53 | lexicalized_da.append([intent, domain, slot, value]) 54 | else: 55 | assert intent=='Inform' 56 | if len(entities)>value: 57 | entity = entities[0][1] 58 | if '周边' in slot: 59 | assert isinstance(entity[slot], list) 60 | if value < len(entity[slot]): 61 | value = entity[slot][value] 62 | lexicalized_da.append([intent, domain, slot, value]) 63 | elif slot=='推荐菜': 64 | assert isinstance(entity[slot], list) 65 | dishes = [x for x in entity[slot] if x not in not_dish] 66 | if len(dishes)>value: 67 | value = dishes[value] 68 | lexicalized_da.append([intent, domain, slot, value]) 69 | elif '酒店设施' in slot: 70 | assert value == 0 71 | slot, value = slot.split('-') 72 | assert isinstance(entity[slot], list) 73 | if value in entity[slot]: 74 | lexicalized_da.append([intent, domain, '-'.join([slot, value]), '是']) 75 | else: 76 | lexicalized_da.append([intent, domain, '-'.join([slot, value]), '否']) 77 | elif slot in ['门票', '价格', '人均消费']: 78 | assert value == 0 79 | value = entity[slot] 80 | lexicalized_da.append([intent, domain, slot, '{}元'.format(value)]) 81 | elif slot == '评分': 82 | assert value == 0 83 | value = entity[slot] 84 | lexicalized_da.append([intent, domain, slot, '{}分'.format(value)]) 85 | else: 86 | assert value == 0 87 | value = entity[slot] 
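                        # Generic informable slot (名称, 地址, 电话, 营业时间, ...):
                        # the delexicalized placeholder index must be 0, since such
                        # slots carry a single value, and it is replaced with the
                        # top-ranked entity's attribute looked up above.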
88 | lexicalized_da.append([intent, domain, slot, value]) 89 | return lexicalized_da 90 | -------------------------------------------------------------------------------- /convlab2/util/crosswoz/state.py: -------------------------------------------------------------------------------- 1 | def default_state(): 2 | state = dict(user_action=[], 3 | system_action=[], 4 | belief_state={}, 5 | cur_domain=None, 6 | request_slots=[], 7 | terminated=False, 8 | history=[]) 9 | state['belief_state'] = { 10 | '景点': { 11 | '名称': '', 12 | '门票': '', 13 | '游玩时间': '', 14 | '评分': '', 15 | '周边景点': '', 16 | '周边餐馆': '', 17 | '周边酒店': '', 18 | }, 19 | '餐馆': { 20 | '名称': '', 21 | '推荐菜': '', 22 | '人均消费': '', 23 | '评分': '', 24 | '周边景点': '', 25 | '周边餐馆': '', 26 | '周边酒店': '', 27 | }, 28 | '酒店': { 29 | '名称': '', 30 | '酒店类型': '', 31 | '酒店设施': '', 32 | '价格': '', 33 | '评分': '', 34 | '周边景点': '', 35 | '周边餐馆': '', 36 | '周边酒店': '', 37 | }, 38 | '地铁': { 39 | '出发地': '', 40 | '目的地': '', 41 | }, 42 | '出租': { 43 | '出发地': '', 44 | '目的地': '', 45 | } 46 | } 47 | return state 48 | -------------------------------------------------------------------------------- /convlab2/util/file_util.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import zipfile 3 | import json 4 | from convlab2.util.allennlp_file_utils import cached_path as allennlp_cached_path 5 | 6 | 7 | def cached_path(file_path, cached_dir=None): 8 | if not cached_dir: 9 | cached_dir = str(Path(Path.home() / '.convlab2') / "cache") 10 | 11 | return allennlp_cached_path(file_path, cached_dir) 12 | 13 | 14 | def read_zipped_json(zip_path, filepath): 15 | archive = zipfile.ZipFile(zip_path, 'r') 16 | return json.load(archive.open(filepath)) 17 | 18 | 19 | def dump_json(content, filepath): 20 | json.dump(content, open(filepath, 'w', encoding='utf-8'), indent=2, ensure_ascii=False) 21 | 22 | 23 | def write_zipped_json(zip_path, filepath): 24 | with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf: 25 | zf.write(filepath) 26 | -------------------------------------------------------------------------------- /convlab2/util/module.py: -------------------------------------------------------------------------------- 1 | """module interface.""" 2 | from abc import ABC 3 | 4 | 5 | class Module(ABC): 6 | 7 | def train(self, *args, **kwargs): 8 | """Model training entry point""" 9 | pass 10 | 11 | def test(self, *args, **kwargs): 12 | """Model testing entry point""" 13 | pass 14 | 15 | def from_cache(self, *args, **kwargs): 16 | """restore internal state for multi-turn dialog""" 17 | return None 18 | 19 | def to_cache(self, *args, **kwargs): 20 | """save internal state for multi-turn dialog""" 21 | return None 22 | 23 | def init_session(self): 24 | """Init the class variables for a new session.""" 25 | pass 26 | -------------------------------------------------------------------------------- /convlab2/util/train_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import torch 5 | 6 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | 9 | def init_logging_handler(log_dir, extra=''): 10 | if not os.path.exists(log_dir): 11 | os.makedirs(log_dir) 12 | current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 13 | 14 | stderr_handler = logging.StreamHandler() 15 | file_handler = logging.FileHandler('{}/log_{}.txt'.format(log_dir, current_time + extra)) 16 | 
logging.basicConfig(handlers=[stderr_handler, file_handler])
17 |     logger = logging.getLogger()
18 |     logger.setLevel(logging.DEBUG)
19 | 
20 | 
21 | def to_device(data):
22 |     if isinstance(data, dict):
23 |         for k, v in data.items():
24 |             data[k] = v.to(device=DEVICE)
25 |     else:
26 |         for idx, item in enumerate(data):
27 |             data[idx] = item.to(device=DEVICE)
28 |     return data
29 | 
--------------------------------------------------------------------------------
/data/crosswoz/README.md:
--------------------------------------------------------------------------------
1 | ### Data format
2 | 
3 | - task_id: dialog ID.
4 | - sys-usr: system annotator ID and user annotator ID.
5 | - goal: list of tuples, each containing:
6 |   - sub-goal id
7 |   - domain name
8 |   - slot name
9 |   - constraint if filled, else requirement
10 |   - whether it has been mentioned in previous turns
11 | - message: dialog turns. Each turn contains:
12 |   - content: utterance
13 |   - role: user or system side
14 |   - dialog_act: list of dialog act tuples, each containing:
15 |     - domain
16 |     - intent
17 |     - slot
18 |     - value
19 |   - user_state: same format as "goal"; can be viewed as a dynamic goal
20 |   - sys_state_init: the **first** db query emitted; it records the user constraints faithfully. If the system finds no matching result, the annotator may manually relax the constraints and query the db several times.
21 |     - domain: slot-value pairs
22 |     - selectedResults: the db search results to be used in this turn
23 |   - sys_state: the **final** db query emitted; it records the db results actually used by the system in this turn. Note that these may not satisfy all user constraints.
24 | - final_goal: the user state/goal at the end of the dialog
25 | - task description: natural language description of the user goal.
26 | - type: dialog type.
--------------------------------------------------------------------------------
/data/crosswoz/database/database.md:
--------------------------------------------------------------------------------
1 | # database
2 | 
3 | - Missing values are always None; they are exported to JSON as null and read back from JSON as None.
4 | - The "nearby xx" relations are symmetric: if A is in B's nearby list, then B is also in A's. These lists can be long, so only the first five entries may be shown.
5 | - 门票 (ticket price), 评分 (rating), 人均消费 (cost per person) and 价格 (price) are queried by range: "less than x" or a closed interval "x-y" (endpoints included).
6 | - *: fields that can be queried. string fields use string matching, list-of-string fields are matched item by item, and int/float fields use numeric comparison. **推荐菜** (recommended dishes) and **酒店设施** (hotel facilities) support multiple space-separated conditions, e.g. "东北杀猪菜 锅包肉" requires both dishes to appear in the dish list.
7 | - The taxi database is a template that is never queried; all its values are placeholders.
8 | 
9 | ### 景点 (attraction)
10 | 
11 | - 领域: "景点"
12 | - 名称*: string
13 | - 地址: string
14 | - 地铁: string
15 | - 电话: string
16 | - 门票*: int (None if missing)
17 | - 游玩时间*: string
18 | - 评分*: float (None if missing)
19 | - 周边景点*: list of string
20 | - 周边餐馆*: list of string
21 | - 周边酒店*: list of string
22 | 
23 | 
24 | 
25 | ### 餐馆 (restaurant)
26 | 
27 | - 领域: "餐馆"
28 | - 名称*: string
29 | - 地址: string
30 | - 地铁: string
31 | - 电话: string
32 | - 营业时间: string
33 | - 推荐菜*: list of string
34 | - 人均消费*: int
35 | - 评分*: float (None if missing)
36 | - 周边景点*: list of string
37 | - 周边餐馆*: list of string
38 | - 周边酒店*: list of string
39 | 
40 | 
41 | 
42 | ### 酒店 (hotel)
43 | 
44 | - 名称*: string
45 | - 酒店类型*: string
46 | - 地址: string
47 | - 地铁: string
48 | - 电话: string
49 | - 酒店设施*: list of string
50 | - 价格*: int (None if missing)
51 | - 评分*: float
52 | - 周边景点*: list of string
53 | - 周边餐馆*: list of string
54 | - 周边酒店*: list of string
55 | 
56 | 
57 | 
58 | 
59 | ### 地铁 (metro)
60 | 
61 | - 名称*: string
62 | - 地铁*: string (None if missing)
63 | 
64 | 
65 | 
66 | ### 出租 (taxi)
67 | 
68 | - 车型: "#CX"
69 | - 车牌: "#CP"
--------------------------------------------------------------------------------
/data/crosswoz/database/taxi_db.json:
--------------------------------------------------------------------------------
1 | [
2 |     [
3 |         "出租 ($出发地 - $目的地)",
4 |         {
5 |             "领域":
"出租", 6 | "车型": "#CX", 7 | "车牌": "#CP" 8 | } 9 | ] 10 | ] -------------------------------------------------------------------------------- /data/crosswoz/gen_da_voc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import zipfile 3 | import os 4 | from convlab2.util.crosswoz.lexicalize import delexicalize_da 5 | 6 | 7 | def read_zipped_json(filepath, filename): 8 | archive = zipfile.ZipFile(filepath, 'r') 9 | return json.load(archive.open(filename)) 10 | 11 | 12 | def gen_da_voc(data): 13 | usr_da_voc, sys_da_voc = {}, {} 14 | for task_id, item in data.items(): 15 | for i, turn in enumerate(item['messages']): 16 | if turn['role'] == 'usr': 17 | da_voc = usr_da_voc 18 | else: 19 | da_voc = sys_da_voc 20 | for da in delexicalize_da(turn['dialog_act']): 21 | da_voc[da] = 0 22 | return sorted(usr_da_voc.keys()), sorted(sys_da_voc.keys()) 23 | 24 | 25 | if __name__ == '__main__': 26 | data = read_zipped_json('train.json.zip','train.json') 27 | usr_da_voc, sys_da_voc = gen_da_voc(data) 28 | json.dump(usr_da_voc, open('usr_da_voc.json', 'w', encoding='utf-8'), indent=4, ensure_ascii=False) 29 | json.dump(sys_da_voc, open('sys_da_voc.json', 'w', encoding='utf-8'), indent=4, ensure_ascii=False) 30 | 31 | -------------------------------------------------------------------------------- /data/crosswoz/sys_da_voc.json: -------------------------------------------------------------------------------- 1 | [ 2 | "General+bye+none+none", 3 | "General+greet+none+none", 4 | "General+reqmore+none+none", 5 | "General+thank+none+none", 6 | "General+welcome+none+none", 7 | "Inform+出租+车型+1", 8 | "Inform+出租+车牌+1", 9 | "Inform+地铁+出发地附近地铁站+1", 10 | "Inform+地铁+目的地附近地铁站+1", 11 | "Inform+景点+名称+1", 12 | "Inform+景点+周边景点+1", 13 | "Inform+景点+周边景点+2", 14 | "Inform+景点+周边景点+3", 15 | "Inform+景点+周边景点+4", 16 | "Inform+景点+周边景点+5", 17 | "Inform+景点+周边酒店+1", 18 | "Inform+景点+周边酒店+2", 19 | "Inform+景点+周边酒店+3", 20 | "Inform+景点+周边酒店+4", 21 | "Inform+景点+周边酒店+5", 22 | "Inform+景点+周边餐馆+1", 23 | "Inform+景点+周边餐馆+2", 24 | "Inform+景点+周边餐馆+3", 25 | "Inform+景点+周边餐馆+4", 26 | "Inform+景点+周边餐馆+5", 27 | "Inform+景点+周边餐馆+6", 28 | "Inform+景点+地址+1", 29 | "Inform+景点+游玩时间+1", 30 | "Inform+景点+电话+1", 31 | "Inform+景点+评分+1", 32 | "Inform+景点+门票+1", 33 | "Inform+酒店+价格+1", 34 | "Inform+酒店+名称+1", 35 | "Inform+酒店+周边景点+1", 36 | "Inform+酒店+周边景点+2", 37 | "Inform+酒店+周边景点+3", 38 | "Inform+酒店+周边景点+4", 39 | "Inform+酒店+周边景点+5", 40 | "Inform+酒店+周边餐馆+1", 41 | "Inform+酒店+周边餐馆+2", 42 | "Inform+酒店+周边餐馆+3", 43 | "Inform+酒店+周边餐馆+4", 44 | "Inform+酒店+周边餐馆+5", 45 | "Inform+酒店+地址+1", 46 | "Inform+酒店+电话+1", 47 | "Inform+酒店+评分+1", 48 | "Inform+酒店+酒店类型+1", 49 | "Inform+酒店+酒店设施-24小时热水+1", 50 | "Inform+酒店+酒店设施-SPA+1", 51 | "Inform+酒店+酒店设施-中式餐厅+1", 52 | "Inform+酒店+酒店设施-会议室+1", 53 | "Inform+酒店+酒店设施-健身房+1", 54 | "Inform+酒店+酒店设施-免费国内长途电话+1", 55 | "Inform+酒店+酒店设施-免费市内电话+1", 56 | "Inform+酒店+酒店设施-公共区域和部分房间提供wifi+1", 57 | "Inform+酒店+酒店设施-公共区域提供wifi+1", 58 | "Inform+酒店+酒店设施-叫醒服务+1", 59 | "Inform+酒店+酒店设施-吹风机+1", 60 | "Inform+酒店+酒店设施-商务中心+1", 61 | "Inform+酒店+酒店设施-国际长途电话+1", 62 | "Inform+酒店+酒店设施-室内游泳池+1", 63 | "Inform+酒店+酒店设施-室外游泳池+1", 64 | "Inform+酒店+酒店设施-宽带上网+1", 65 | "Inform+酒店+酒店设施-所有房间提供wifi+1", 66 | "Inform+酒店+酒店设施-接待外宾+1", 67 | "Inform+酒店+酒店设施-接机服务+1", 68 | "Inform+酒店+酒店设施-接站服务+1", 69 | "Inform+酒店+酒店设施-收费停车位+1", 70 | "Inform+酒店+酒店设施-无烟房+1", 71 | "Inform+酒店+酒店设施-早餐服务+1", 72 | "Inform+酒店+酒店设施-早餐服务免费+1", 73 | "Inform+酒店+酒店设施-暖气+1", 74 | "Inform+酒店+酒店设施-桑拿+1", 75 | "Inform+酒店+酒店设施-棋牌室+1", 76 | "Inform+酒店+酒店设施-残疾人设施+1", 77 | 
"Inform+酒店+酒店设施-洗衣服务+1", 78 | "Inform+酒店+酒店设施-温泉+1", 79 | "Inform+酒店+酒店设施-看护小孩服务+1", 80 | "Inform+酒店+酒店设施-租车+1", 81 | "Inform+酒店+酒店设施-行李寄存+1", 82 | "Inform+酒店+酒店设施-西式餐厅+1", 83 | "Inform+酒店+酒店设施-部分房间提供wifi+1", 84 | "Inform+酒店+酒店设施-酒吧+1", 85 | "Inform+酒店+酒店设施-酒店各处提供wifi+1", 86 | "Inform+餐馆+人均消费+1", 87 | "Inform+餐馆+名称+1", 88 | "Inform+餐馆+周边景点+1", 89 | "Inform+餐馆+周边景点+2", 90 | "Inform+餐馆+周边景点+3", 91 | "Inform+餐馆+周边景点+4", 92 | "Inform+餐馆+周边景点+5", 93 | "Inform+餐馆+周边酒店+1", 94 | "Inform+餐馆+周边酒店+2", 95 | "Inform+餐馆+周边酒店+3", 96 | "Inform+餐馆+周边酒店+4", 97 | "Inform+餐馆+周边酒店+5", 98 | "Inform+餐馆+周边餐馆+1", 99 | "Inform+餐馆+周边餐馆+2", 100 | "Inform+餐馆+周边餐馆+3", 101 | "Inform+餐馆+周边餐馆+4", 102 | "Inform+餐馆+周边餐馆+5", 103 | "Inform+餐馆+地址+1", 104 | "Inform+餐馆+推荐菜+1", 105 | "Inform+餐馆+推荐菜+10", 106 | "Inform+餐馆+推荐菜+11", 107 | "Inform+餐馆+推荐菜+12", 108 | "Inform+餐馆+推荐菜+13", 109 | "Inform+餐馆+推荐菜+14", 110 | "Inform+餐馆+推荐菜+15", 111 | "Inform+餐馆+推荐菜+16", 112 | "Inform+餐馆+推荐菜+17", 113 | "Inform+餐馆+推荐菜+18", 114 | "Inform+餐馆+推荐菜+19", 115 | "Inform+餐馆+推荐菜+2", 116 | "Inform+餐馆+推荐菜+20", 117 | "Inform+餐馆+推荐菜+21", 118 | "Inform+餐馆+推荐菜+22", 119 | "Inform+餐馆+推荐菜+23", 120 | "Inform+餐馆+推荐菜+24", 121 | "Inform+餐馆+推荐菜+25", 122 | "Inform+餐馆+推荐菜+26", 123 | "Inform+餐馆+推荐菜+27", 124 | "Inform+餐馆+推荐菜+28", 125 | "Inform+餐馆+推荐菜+29", 126 | "Inform+餐馆+推荐菜+3", 127 | "Inform+餐馆+推荐菜+30", 128 | "Inform+餐馆+推荐菜+31", 129 | "Inform+餐馆+推荐菜+32", 130 | "Inform+餐馆+推荐菜+4", 131 | "Inform+餐馆+推荐菜+5", 132 | "Inform+餐馆+推荐菜+6", 133 | "Inform+餐馆+推荐菜+7", 134 | "Inform+餐馆+推荐菜+8", 135 | "Inform+餐馆+推荐菜+9", 136 | "Inform+餐馆+电话+1", 137 | "Inform+餐馆+营业时间+1", 138 | "Inform+餐馆+评分+1", 139 | "NoOffer+景点+none+none", 140 | "NoOffer+酒店+none+none", 141 | "NoOffer+餐馆+none+none", 142 | "Recommend+景点+名称+1", 143 | "Recommend+景点+名称+2", 144 | "Recommend+景点+名称+3", 145 | "Recommend+景点+名称+4", 146 | "Recommend+景点+名称+5", 147 | "Recommend+酒店+名称+1", 148 | "Recommend+酒店+名称+2", 149 | "Recommend+酒店+名称+3", 150 | "Recommend+酒店+名称+4", 151 | "Recommend+酒店+名称+5", 152 | "Recommend+餐馆+名称+1", 153 | "Recommend+餐馆+名称+2", 154 | "Recommend+餐馆+名称+3", 155 | "Recommend+餐馆+名称+4", 156 | "Recommend+餐馆+名称+5" 157 | ] -------------------------------------------------------------------------------- /data/crosswoz/test.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/data/crosswoz/test.json.zip -------------------------------------------------------------------------------- /data/crosswoz/train.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/data/crosswoz/train.json.zip -------------------------------------------------------------------------------- /data/crosswoz/usr_da_voc.json: -------------------------------------------------------------------------------- 1 | [ 2 | "General+bye+none+none", 3 | "General+greet+none+none", 4 | "General+thank+none+none", 5 | "Inform+出租+出发地+1", 6 | "Inform+出租+目的地+1", 7 | "Inform+出租+目的地+2", 8 | "Inform+地铁+出发地+1", 9 | "Inform+地铁+目的地+1", 10 | "Inform+地铁+目的地+2", 11 | "Inform+景点+名称+1", 12 | "Inform+景点+名称+2", 13 | "Inform+景点+名称+3", 14 | "Inform+景点+名称+4", 15 | "Inform+景点+游玩时间+1", 16 | "Inform+景点+评分+1", 17 | "Inform+景点+门票+1", 18 | "Inform+酒店+价格+1", 19 | "Inform+酒店+名称+1", 20 | "Inform+酒店+名称+2", 21 | "Inform+酒店+名称+3", 22 | "Inform+酒店+评分+1", 23 | "Inform+酒店+酒店类型+1", 24 | "Inform+酒店+酒店设施-24小时热水+1", 25 | "Inform+酒店+酒店设施-SPA+1", 26 | 
"Inform+酒店+酒店设施-中式餐厅+1", 27 | "Inform+酒店+酒店设施-会议室+1", 28 | "Inform+酒店+酒店设施-健身房+1", 29 | "Inform+酒店+酒店设施-免费国内长途电话+1", 30 | "Inform+酒店+酒店设施-免费市内电话+1", 31 | "Inform+酒店+酒店设施-公共区域和部分房间提供wifi+1", 32 | "Inform+酒店+酒店设施-公共区域提供wifi+1", 33 | "Inform+酒店+酒店设施-叫醒服务+1", 34 | "Inform+酒店+酒店设施-吹风机+1", 35 | "Inform+酒店+酒店设施-商务中心+1", 36 | "Inform+酒店+酒店设施-国际长途电话+1", 37 | "Inform+酒店+酒店设施-室内游泳池+1", 38 | "Inform+酒店+酒店设施-室外游泳池+1", 39 | "Inform+酒店+酒店设施-宽带上网+1", 40 | "Inform+酒店+酒店设施-所有房间提供wifi+1", 41 | "Inform+酒店+酒店设施-接待外宾+1", 42 | "Inform+酒店+酒店设施-接机服务+1", 43 | "Inform+酒店+酒店设施-接站服务+1", 44 | "Inform+酒店+酒店设施-收费停车位+1", 45 | "Inform+酒店+酒店设施-无烟房+1", 46 | "Inform+酒店+酒店设施-早餐服务+1", 47 | "Inform+酒店+酒店设施-早餐服务免费+1", 48 | "Inform+酒店+酒店设施-暖气+1", 49 | "Inform+酒店+酒店设施-桑拿+1", 50 | "Inform+酒店+酒店设施-棋牌室+1", 51 | "Inform+酒店+酒店设施-残疾人设施+1", 52 | "Inform+酒店+酒店设施-洗衣服务+1", 53 | "Inform+酒店+酒店设施-温泉+1", 54 | "Inform+酒店+酒店设施-看护小孩服务+1", 55 | "Inform+酒店+酒店设施-租车+1", 56 | "Inform+酒店+酒店设施-行李寄存+1", 57 | "Inform+酒店+酒店设施-西式餐厅+1", 58 | "Inform+酒店+酒店设施-部分房间提供wifi+1", 59 | "Inform+酒店+酒店设施-酒吧+1", 60 | "Inform+酒店+酒店设施-酒店各处提供wifi+1", 61 | "Inform+餐馆+人均消费+1", 62 | "Inform+餐馆+名称+1", 63 | "Inform+餐馆+名称+2", 64 | "Inform+餐馆+推荐菜+1", 65 | "Inform+餐馆+推荐菜+2", 66 | "Inform+餐馆+推荐菜+3", 67 | "Inform+餐馆+评分+1", 68 | "Request+出租+车型+", 69 | "Request+出租+车牌+", 70 | "Request+地铁+出发地附近地铁站+", 71 | "Request+地铁+目的地附近地铁站+", 72 | "Request+景点+名称+", 73 | "Request+景点+周边景点+", 74 | "Request+景点+周边酒店+", 75 | "Request+景点+周边餐馆+", 76 | "Request+景点+地址+", 77 | "Request+景点+游玩时间+", 78 | "Request+景点+电话+", 79 | "Request+景点+评分+", 80 | "Request+景点+门票+", 81 | "Request+酒店+价格+", 82 | "Request+酒店+名称+", 83 | "Request+酒店+周边景点+", 84 | "Request+酒店+周边餐馆+", 85 | "Request+酒店+地址+", 86 | "Request+酒店+电话+", 87 | "Request+酒店+评分+", 88 | "Request+酒店+酒店类型+", 89 | "Request+酒店+酒店设施-24小时热水+", 90 | "Request+酒店+酒店设施-SPA+", 91 | "Request+酒店+酒店设施-中式餐厅+", 92 | "Request+酒店+酒店设施-会议室+", 93 | "Request+酒店+酒店设施-健身房+", 94 | "Request+酒店+酒店设施-免费国内长途电话+", 95 | "Request+酒店+酒店设施-免费市内电话+", 96 | "Request+酒店+酒店设施-公共区域和部分房间提供wifi+", 97 | "Request+酒店+酒店设施-公共区域提供wifi+", 98 | "Request+酒店+酒店设施-叫醒服务+", 99 | "Request+酒店+酒店设施-吹风机+", 100 | "Request+酒店+酒店设施-商务中心+", 101 | "Request+酒店+酒店设施-国际长途电话+", 102 | "Request+酒店+酒店设施-室内游泳池+", 103 | "Request+酒店+酒店设施-室外游泳池+", 104 | "Request+酒店+酒店设施-宽带上网+", 105 | "Request+酒店+酒店设施-所有房间提供wifi+", 106 | "Request+酒店+酒店设施-接待外宾+", 107 | "Request+酒店+酒店设施-接机服务+", 108 | "Request+酒店+酒店设施-接站服务+", 109 | "Request+酒店+酒店设施-收费停车位+", 110 | "Request+酒店+酒店设施-无烟房+", 111 | "Request+酒店+酒店设施-早餐服务+", 112 | "Request+酒店+酒店设施-早餐服务免费+", 113 | "Request+酒店+酒店设施-暖气+", 114 | "Request+酒店+酒店设施-桑拿+", 115 | "Request+酒店+酒店设施-棋牌室+", 116 | "Request+酒店+酒店设施-残疾人设施+", 117 | "Request+酒店+酒店设施-洗衣服务+", 118 | "Request+酒店+酒店设施-温泉+", 119 | "Request+酒店+酒店设施-看护小孩服务+", 120 | "Request+酒店+酒店设施-租车+", 121 | "Request+酒店+酒店设施-行李寄存+", 122 | "Request+酒店+酒店设施-西式餐厅+", 123 | "Request+酒店+酒店设施-部分房间提供wifi+", 124 | "Request+酒店+酒店设施-酒吧+", 125 | "Request+酒店+酒店设施-酒店各处提供wifi+", 126 | "Request+餐馆+人均消费+", 127 | "Request+餐馆+名称+", 128 | "Request+餐馆+周边景点+", 129 | "Request+餐馆+周边酒店+", 130 | "Request+餐馆+周边餐馆+", 131 | "Request+餐馆+地址+", 132 | "Request+餐馆+推荐菜+", 133 | "Request+餐馆+电话+", 134 | "Request+餐馆+营业时间+", 135 | "Request+餐馆+评分+", 136 | "Select+景点+源领域+景点", 137 | "Select+景点+源领域+酒店", 138 | "Select+景点+源领域+餐馆", 139 | "Select+酒店+源领域+景点", 140 | "Select+酒店+源领域+餐馆", 141 | "Select+餐馆+源领域+景点", 142 | "Select+餐馆+源领域+酒店", 143 | "Select+餐馆+源领域+餐馆" 144 | ] -------------------------------------------------------------------------------- /data/crosswoz/val.json.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/data/crosswoz/val.json.zip -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/example.png -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | python-coveralls 2 | pytest-dependency 3 | pytest-mock 4 | requests-mock 5 | pytest>=3.6.0 6 | pytest-cov==2.4.0 7 | checksumdir 8 | bs4 9 | lxml 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk>=3.4 2 | tqdm>=4.30 3 | checksumdir>=1.1 4 | visdom 5 | Pillow 6 | future 7 | torch 8 | numpy>=1.15.0 9 | scipy 10 | scikit-learn==0.20.3 11 | pytorch-pretrained-bert>=0.6.1 12 | transformers>=2.3.0 13 | tensorflow==1.14 14 | tensorboard>=1.14.0 15 | tensorboardX==1.7 16 | allennlp 17 | requests 18 | simplejson 19 | unidecode 20 | jieba 21 | embeddings 22 | quadprog 23 | -------------------------------------------------------------------------------- /result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/result.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | release = egg_info -Db '' 3 | 4 | [egg_info] 5 | tag_build = .dev 6 | tag_date = 1 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | ''' 2 | setup.py for ConvLab-2 3 | ''' 4 | import sys 5 | import os 6 | from typing import List 7 | from setuptools import setup, find_packages 8 | from setuptools.command.test import test as TestCommand 9 | 10 | 11 | class LibTest(TestCommand): 12 | 13 | def run_tests(self): 14 | # import here, cause outside the eggs aren't loaded 15 | ret = os.system("pytest --cov=ConvLab-2 tests/ --cov-report term-missing") 16 | sys.exit(ret >> 8) 17 | 18 | 19 | def read_requirements(require_file: str = './requirements.txt') -> List[str]: 20 | """read the dependency requirements from the file 21 | 22 | Returns: 23 | List[str]: the list of dependency packages 24 | """ 25 | if not os.path.exists(require_file): 26 | raise FileNotFoundError(f'{require_file} file not found') 27 | with open(require_file, 'r+') as f: 28 | return list(f.readlines()) 29 | 30 | setup( 31 | name='ConvLab-2', 32 | version='0.0.1', 33 | packages=find_packages(exclude=[]), 34 | license='Apache', 35 | description='Task-oriented Dialog System Toolkits', 36 | long_description=open('README.md', encoding='UTF-8').read(), 37 | long_description_content_type="text/markdown", 38 | classifiers=[ 39 | 'Development Status :: 2 - Pre-Alpha', 40 | 'License :: OSI Approved :: Apache Software License', 41 | 'Programming Language :: Python :: 3.5', 42 | 'Programming Language :: Python :: 3.6', 43 | 'Intended Audience :: Science/Research', 44 | 'Topic :: 
Scientific/Engineering :: Artificial Intelligence',
45 |     ],
46 |     install_requires=read_requirements(),
47 |     extras_require={
48 |         'develop': read_requirements('./requirements-dev.txt')
49 |     },
50 |     cmdclass={'test': LibTest},
51 |     entry_points={
52 |         'console_scripts': [
53 |             "ConvLab-2-report=convlab2.scripts:report"
54 |         ]
55 |     },
56 |     include_package_data=True,
57 |     url='https://github.com/thu-coai/ConvLab-2',
58 |     author='thu-coai',
59 |     author_email='thu-coai-developer@googlegroups.com',
60 |     python_requires='>=3.5',
61 |     zip_safe=False
62 | )
63 | 
--------------------------------------------------------------------------------
/web/.editorconfig:
--------------------------------------------------------------------------------
1 | [*]
2 | charset = utf-8
3 | end_of_line = lf
4 | insert_final_newline = true
5 | indent_style = space
6 | indent_size = 4
7 | 
8 | [{*.sht,*.html,*.shtm,*.shtml,*.ng,*.htm,.babelrc,.stylelintrc,.eslintrc,jest.config,*.bowerrc,*.jsb3,*.jsb2,*.json,*.js.map}]
9 | indent_style = space
10 | indent_size = 2
--------------------------------------------------------------------------------
/web/.env.example:
--------------------------------------------------------------------------------
1 | FLASK_ENV = development
2 | FLASK_DEBUG = 1
3 | 
4 | DATABASE_PATH = ./app.db
5 | REDIS_URL = redis://localhost:6379/0
--------------------------------------------------------------------------------
/web/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | 
3 | *.pyc
4 | *.pyo
5 | __pycache__
6 | 
7 | venv
8 | .env
9 | .flaskenv
10 | 
11 | .idea
12 | .vscode
13 | 
14 | *.db
15 | 
16 | data_labelling/settings.py
17 | data_labelling/results/**/*.json
18 | data_labelling/results/**/all.zip
--------------------------------------------------------------------------------
/web/README.md:
--------------------------------------------------------------------------------
1 | # A Data Labelling System for CoAI Lab, May 2019
2 | 
3 | ## Requirements
4 | 
5 | * python3.6
6 | 
7 | ## Development Startup
8 | 
9 | Copy `.env.example` to `.env` and `settings.example.py` to `settings.py`, then update them with your local config.
10 | 
11 | Run the following commands and visit http://localhost:5000.
12 | 
13 | ```bash
14 | # install dependencies
15 | pip install -r requirements.txt
16 | 
17 | # initialize the database
18 | python resetdb.py
19 | 
20 | # start the server
21 | python run.py
22 | ```
23 | 
24 | ## Deployment Instructions
25 | 
26 | You can use exactly the same steps as in the development startup; a sketch of a more production-oriented setup is shown below.
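For a long-running deployment, the following is a minimal sketch of a common Flask-SocketIO setup (assumptions: it presumes `run.py` exposes the Flask `app` object and that gunicorn/eventlet fit your environment — neither is confirmed by this repo, which also ships an `ecosystem.config.js` for supervising the server with pm2):

```bash
# hypothetical commands — adjust the module:object path to the real entry point
pip install gunicorn eventlet
gunicorn --worker-class eventlet -w 1 run:app
```

Flask-SocketIO supports only a single worker process unless a message queue is configured, hence `-w 1`.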
27 | 
28 | ## Operation Guide
29 | 
30 | After starting the server, you can log in with the default administrator account.
31 | 
32 | Username: root
33 | 
34 | Password: root
35 | 
36 | To enter the regular-user interface, click the registration link and sign up with the invitation code.
37 | 
38 | The current invitation code is `959592`; it can be changed via `invitation_code` in `data_labelling/app.py`.
39 | 
40 | After logging in, a regular user enters the dialog matching page. Matching requires **at least one person on the system side and one on the user side**; the system then completes the match automatically and both enter the dialog page.
41 | 
42 | Tip: when testing locally, you can use a Chrome incognito window to log in as two users at the same time.
43 | 
44 | On the dialog page, the user side starts the conversation, which then proceeds turn by turn. After one side sends a message the other side sees it immediately, but the latter can only send a message once the former has finished the required annotation.
45 | 
46 | See the files under the `templates` directory for the dialog-page logic.
47 | 
48 | Administrators can import predefined tasks and export dialog data in the admin backend.
49 | 
50 | To import tasks: under the Result Files tab, enter the inputs directory and upload the task definition file [tasks.json](example_goal.json); then go back to the admin home page and click the import button. On success, the system reports the number of imported tasks. You can also check the details under the Task tab.
51 | 
52 | To export data, simply click "download all" on the admin home page.
53 | 
54 | 
55 | 
--------------------------------------------------------------------------------
/web/data_labelling/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import timedelta
3 | from flask_socketio import SocketIO
4 | import flask_socketio
5 | 
6 | from data_labelling.redis import rd
7 | from .app import app
8 | from .admin import admin
9 | from .models import *
10 | from . import routes
11 | 
12 | 
13 | def run():
14 |     running_rooms = Room.select().where(Room.status_code == Room.Status.RUNNING.value)
15 |     for room in running_rooms:
16 |         room.task.finished = False
17 |         room.task.save()
18 |         room.status = Room.Status.ABORTED
19 |         room.save()
20 | 
21 |     admin.init_app(app)
22 |     rd.flushall()
23 | 
24 |     socket_io = SocketIO(app)
25 |     socket_io.run(app)
26 | 
--------------------------------------------------------------------------------
/web/data_labelling/admin.py:
--------------------------------------------------------------------------------
1 | from dotenv import load_dotenv
2 | from flask_admin import Admin, AdminIndexView, expose
3 | from flask_admin.contrib.peewee import ModelView
4 | from flask_admin.contrib.fileadmin import FileAdmin
5 | from flask import g, redirect, url_for
6 | import os
7 | 
8 | from .models import *
9 | 
10 | 
11 | class MyIndexView(AdminIndexView):
12 |     def is_accessible(self):
13 |         return g.me and g.me.is_admin
14 | 
15 |     def inaccessible_callback(self, name, **kwargs):
16 |         return redirect(url_for('misc_bp.login'))
17 | 
18 | 
19 | class MyFileAdmin(FileAdmin):
20 |     def is_accessible(self):
21 |         return g.me and g.me.is_admin
22 | 
23 |     def inaccessible_callback(self, name, **kwargs):
24 |         return redirect(url_for('misc_bp.login'))
25 | 
26 | 
27 | class MyModelView(ModelView):
28 |     def is_accessible(self):
29 |         return g.me and g.me.is_admin
30 | 
31 |     def inaccessible_callback(self, name, **kwargs):
32 |         return redirect(url_for('misc_bp.login'))
33 | 
34 | 
35 | class UserView(MyModelView):
36 |     column_exclude_list = ['password_hash', ]
37 | 
38 | 
39 | class TaskView(MyModelView):
40 |     pass
41 | 
42 | 
43 | class RoomView(MyModelView):
44 |     pass
45 | 
46 | 
47 | load_dotenv()
48 | 
49 | admin = Admin(
50 |     name='任务导向对话系统 · 管理界面',
51 |     template_mode='bootstrap3',
52 |     index_view=MyIndexView()
53 | )
54 | 
55 | admin.add_views(
56 |     UserView(User),
57 |     TaskView(Task),
58 |     RoomView(Room)
59 | )
60 | 
61 | results_dir = os.path.join(os.path.dirname(__file__), 'results')
62 | admin.add_view(MyFileAdmin(results_dir, name='Result Files'))
--------------------------------------------------------------------------------
/web/data_labelling/app.py:
--------------------------------------------------------------------------------
1 | import string
2 | import random
3 | 
4 | from flask import Flask, session, g
5 | 
6 | from .models import *
7 | from
.routes import * 8 | 9 | app = Flask(__name__) 10 | 11 | app.register_blueprint(misc.bp, url_prefix='') 12 | app.register_blueprint(room.bp, url_prefix='/room') 13 | app.register_blueprint(services.bp, url_prefix='/services') 14 | app.register_blueprint(match.bp, url_prefix='/match') 15 | 16 | app.config.from_pyfile('settings.py') 17 | 18 | invitation_code = '959592' 19 | 20 | @app.before_request 21 | def before_request(): 22 | g.invitation_code = invitation_code 23 | 24 | user_id = session.get('user_id') 25 | try: 26 | g.me = User.get(User.id == user_id) 27 | except: 28 | g.me = None 29 | -------------------------------------------------------------------------------- /web/data_labelling/match_making/__init__.py: -------------------------------------------------------------------------------- 1 | from .match import update, get_status, num_waiting, add_user, free_user, Status, leave_room 2 | -------------------------------------------------------------------------------- /web/data_labelling/match_making/helpers.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from peewee import * 3 | 4 | from ..models import * 5 | from ..redis import rd, with_lock 6 | 7 | 8 | class Status(Enum): 9 | FREE = 0 10 | IN_QUEUE = 1 11 | MATCHED = 2 12 | 13 | 14 | class RedisQueue: 15 | def __init__(self, name): 16 | self.name = name 17 | 18 | @with_lock("matching") 19 | def __len__(self): 20 | return rd.llen('queue:{}'.format(self.name)) 21 | 22 | @with_lock("matching") 23 | def push(self, v): 24 | rd.rpush('queue:{}'.format(self.name), v) 25 | 26 | @with_lock("matching") 27 | def pop(self): 28 | return int(rd.lpop('queue:{}'.format(self.name))) 29 | 30 | @with_lock("matching") 31 | def remove(self, v): 32 | rd.lrem('queue:{}'.format(self.name), 0, v) 33 | 34 | 35 | def create_room(system, client): 36 | user0 = User.get(User.id == system) 37 | user1 = User.get(User.id == client) 38 | task = Task.select().where(Task.finished == False).order_by(fn.Random()).get() 39 | print(task.content) 40 | task.finished = True 41 | task.save() 42 | room = Room.create( 43 | task=task, 44 | user0=user0, 45 | user1=user1, 46 | status=Room.Status.RUNNING 47 | ) 48 | 49 | 50 | @with_lock("matching") 51 | def get_status(uid): 52 | s = rd.get('user_status:{}'.format(uid)) or 0 53 | return Status(int(s)) 54 | 55 | 56 | @with_lock("matching") 57 | def set_status(uid, status: Status): 58 | s = status.value 59 | rd.set('user_status:{}'.format(uid), s) 60 | -------------------------------------------------------------------------------- /web/data_labelling/match_making/match.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from .helpers import * 4 | 5 | systems = RedisQueue("systems") 6 | clients = RedisQueue("clients") 7 | 8 | 9 | def add_user(uid, role): 10 | if get_status(uid) != Status.FREE: 11 | return False 12 | 13 | set_status(uid, Status.IN_QUEUE) 14 | if role == 0: 15 | systems.push(uid) 16 | else: 17 | clients.push(uid) 18 | 19 | return True 20 | 21 | 22 | def free_user(uid): 23 | if get_status(uid) == Status.IN_QUEUE: 24 | set_status(uid, Status.FREE) 25 | try: 26 | systems.remove(uid) 27 | except ValueError: 28 | pass 29 | try: 30 | clients.remove(uid) 31 | except ValueError: 32 | pass 33 | 34 | 35 | def leave_room(uid): 36 | if get_status(uid) == Status.MATCHED: 37 | set_status(uid, Status.FREE) 38 | 39 | 40 | last_update_time = 0.0 41 | UPDATE_MINIMAL_INTERVAL = 1.0 42 | 43 | 44 | def update(): 45 | 
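    # Rate-limited matchmaking pass: at most once per UPDATE_MINIMAL_INTERVAL
    # seconds, repeatedly pop one waiting system-side and one client-side user
    # from the Redis queues, mark both MATCHED, and open a Room for them,
    # until either queue runs empty.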
global last_update_time
46 |     now = time.time()
47 |     if now - last_update_time >= UPDATE_MINIMAL_INTERVAL:
48 |         last_update_time = now
49 | 
50 |         while len(systems) > 0 and len(clients) > 0:
51 |             system = systems.pop()
52 |             client = clients.pop()
53 | 
54 |             set_status(system, Status.MATCHED)
55 |             set_status(client, Status.MATCHED)
56 |             print('matched {} and {}'.format(system, client))
57 |             create_room(system, client)
58 | 
59 | 
60 | def num_waiting():
61 |     return len(systems), len(clients)
--------------------------------------------------------------------------------
/web/data_labelling/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | 
4 | from enum import Enum
5 | from dotenv import load_dotenv
6 | from playhouse.sqlite_ext import *
7 | 
8 | __all__ = ['db', 'User', 'Task', 'Room', 'Message']
9 | 
10 | load_dotenv()
11 | 
12 | db = SqliteExtDatabase(os.getenv('DATABASE_PATH'))
13 | 
14 | SQLITE_EXT_JSON1_PATH = os.getenv('SQLITE_EXT_JSON1_PATH')
15 | if SQLITE_EXT_JSON1_PATH:
16 |     db.load_extension(SQLITE_EXT_JSON1_PATH)
17 | 
18 | 
19 | class BaseModel(Model):
20 |     class Meta:
21 |         database = db
22 | 
23 | 
24 | class User(BaseModel):
25 |     username = CharField(unique=True)
26 |     password_hash = CharField()
27 | 
28 |     is_admin = BooleanField(default=False)
29 | 
30 |     tasks_done = IntegerField(default=0)
31 | 
32 |     created_at = DateTimeField(default=datetime.datetime.now)
33 | 
34 |     def updateTasksCount(self):
35 |         self.tasks_done = Room.select().where((Room.status_code == Room.Status.SUCCESS.value) & ((Room.user0 == self) | (Room.user1 == self))).count()
36 |         self.save()
37 | 
38 | 
39 | class Task(BaseModel):
40 |     content = JSONField()
41 | 
42 |     created_at = DateTimeField(default=datetime.datetime.now)
43 | 
44 |     finished = BooleanField(default=False)
45 | 
46 | 
47 | class Room(BaseModel):
48 |     task = ForeignKeyField(Task, backref='rooms')
49 | 
50 |     user0 = ForeignKeyField(User)
51 |     user1 = ForeignKeyField(User)
52 | 
53 |     status_code = IntegerField()
54 | 
55 |     created_at = DateTimeField(default=datetime.datetime.now)
56 | 
57 |     class Status(Enum):
58 |         RUNNING = 0
59 |         SUCCESS = 1
60 |         ABORTED = 2
61 | 
62 |     @property
63 |     def status(self):
64 |         return Room.Status(self.status_code)
65 | 
66 |     @status.setter
67 |     def status(self, status):
68 |         assert isinstance(status, Room.Status)
69 |         self.status_code = status.value
70 | 
71 | 
72 | class Message(BaseModel):
73 |     room = ForeignKeyField(Room, backref='messages')
74 |     role = IntegerField()
75 | 
76 |     content = TextField()
77 |     payload = JSONField()
78 | 
79 |     created_at = DateTimeField(default=datetime.datetime.now)
--------------------------------------------------------------------------------
/web/data_labelling/redis.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | import redis
3 | from redis_lock import Lock
4 | from dotenv import load_dotenv
5 | from os import getenv
6 | 
7 | load_dotenv()
8 | rd = redis.from_url(getenv('REDIS_URL', 'redis://localhost:6379/0'))
9 | 
10 | 
11 | def with_lock(name):
12 |     def decorator(f):
13 |         @wraps(f)  # fixed: wraps must decorate the wrapper, not be called and discarded
14 |         def decorated(*args, **kwargs):
15 |             with Lock(rd, name):
16 |                 res = f(*args, **kwargs)
17 |             return res
18 | 
19 |         return decorated
20 | 
21 |     return decorator
--------------------------------------------------------------------------------
/web/data_labelling/results/input/.gitkeep:
--------------------------------------------------------------------------------
/web/data_labelling/results/input/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/results/input/.gitkeep
--------------------------------------------------------------------------------
/web/data_labelling/routes/__init__.py:
--------------------------------------------------------------------------------
__all__ = ['room', 'misc', 'services', 'match']
--------------------------------------------------------------------------------
/web/data_labelling/routes/match.py:
--------------------------------------------------------------------------------
from flask import Blueprint, render_template, jsonify
from flask_socketio import emit

from ..utils import *
from ..models import *
from .. import match_making

bp = Blueprint('match_bp', __name__)


@bp.route('/')
@login_required
def index():
    return render_template('match.html')


@bp.route('/enter-queue/<int:role>', methods=['POST'])
@login_required
def enter_queue(role):
    match_making.update()
    if match_making.add_user(g.me.id, role):
        return 'OK'
    else:
        return '用户已经在队列中', 400  # user is already in the queue


@bp.route('/quit-queue', methods=['POST'])
@login_required
def quit_queue():
    match_making.update()

    uid = g.me.id
    if match_making.get_status(uid) == match_making.Status.IN_QUEUE:
        match_making.free_user(uid)
        return 'OK'
    else:
        return '用户不在队列中', 400  # user is not in the queue


@bp.route('/num-waiting')
@login_required
def num_waiting():
    match_making.update()
    return jsonify(match_making.num_waiting())


@bp.route('/get-room')
@login_required
def get_room():
    match_making.update()

    uid = g.me.id
    if match_making.get_status(uid) == match_making.Status.MATCHED:
        try:
            room = Room.get((Room.status_code == Room.Status.RUNNING.value) &
                            ((Room.user0 == g.me) | (Room.user1 == g.me)))
            role = 0
            if room.user1 == g.me:
                role = 1
            return jsonify({
                'room_id': room.id,
                'role': role
            })
        except Room.DoesNotExist:
            return '房间不存在', 400  # room does not exist

    return '204', 204
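
Taken together, these endpoints form a poll-based matching flow. A hypothetical client session, sketched with `requests` (the real front end drives this from browser JS; the address and credentials are assumptions):

    import time

    import requests

    BASE = 'http://localhost:5000'  # assumed development address

    s = requests.Session()
    s.post(BASE + '/services/login', json={'username': 'root', 'password': 'root'})
    s.post(BASE + '/match/enter-queue/1')   # queue up on the user side (role 1)
    while True:
        r = s.get(BASE + '/match/get-room')
        if r.status_code == 200:            # 204 means "still waiting"
            print('room', r.json()['room_id'], 'as role', r.json()['role'])
            break
        time.sleep(1)
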
--------------------------------------------------------------------------------
/web/data_labelling/routes/misc.py:
--------------------------------------------------------------------------------
import os
import json
from peewee import prefetch, fn
from datetime import datetime, date
from flask import Blueprint, render_template, request, url_for, jsonify

from ..utils import *

bp = Blueprint('misc_bp', __name__)


@bp.route('/login')
def login():
    return render_template('login.html') if not g.me else redirect('/')


@bp.route('/register')
def register():
    return render_template('register.html') if not g.me else redirect('/')


@bp.route('/')
@login_required
def index():
    return redirect('/match') if not g.me.is_admin else redirect('/admin')


@bp.route('/num-tasks-unfinished')
@login_required
def num_tasks_unfinished():
    num = Task.select().where(Task.finished == False).count()
    return jsonify(num)


@bp.route('/import-all', methods=['POST'])
@admin_required
def import_all():
    basepath = 'data_labelling/results/input'
    count = 0

    def import_one(path):
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        def normalize(options):
            # flatten every goal's constraint / request / booking slots into
            # one items list of [goal id, domain, slot, value]
            items = []
            for goal in options['goals']:
                field = goal['领域']
                _id = goal['id']
                kv = goal.get('约束条件', []) + goal.get('需求信息', []) + goal.get('预订信息', [])
                for k, v in kv:
                    items.append([_id, field, k, v])
            options['items'] = items
            return options

        n = 0
        for x in data:
            Task.create(content=normalize(x))
            n += 1
        return n

    for filename in os.listdir(basepath):
        if not filename.endswith('.json'):
            continue
        fullpath = os.path.join(basepath, filename)
        try:
            count += import_one(fullpath)
        except Exception as e:
            print(repr(e))
        try:
            os.remove(fullpath)
        except OSError:
            pass
    return str(count)


@bp.route('/export-all', methods=['POST'])
@admin_required
def export_all():
    basepath = 'data_labelling/results/output'
    os.makedirs(basepath, exist_ok=True)  # the output folder is not tracked in the repo
    n = Task.select(fn.Max(Task.id)).scalar() or 0
    step = 100

    class JSONEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, datetime):
                return obj.strftime('%Y-%m-%d %H:%M:%S')
            elif isinstance(obj, date):
                return obj.strftime('%Y-%m-%d')
            else:
                return json.JSONEncoder.default(self, obj)

    count = 0
    files = []
    for i in range(0, n, step):
        j = min(i + step, n)
        messages = Message.select().order_by(Message.id)
        rooms = Room.select().where((Room.status_code == Room.Status.SUCCESS.value) & (Room.task > i) & (Room.task <= j)).order_by(Room.task)
        rooms = prefetch(rooms, messages)
        data = []
        for room in rooms:
            data.append({
                'task': room.task_id,
                'user': [room.user0_id, room.user1_id],
                'messages': list(map(lambda msg: dict({
                    'role': msg.role,
                    'content': msg.content,
                    'payload': msg.payload,
                    'created_at': msg.created_at
                }), room.messages)),
                'created_at': room.created_at
            })
            count += 1
        fullpath = os.path.join(basepath, '%s.json' % j)
        with open(fullpath, 'w', encoding='utf-8') as f:
            json.dump(data, f, cls=JSONEncoder, indent=4, ensure_ascii=False)
        files.append(fullpath)
    zip_file = os.path.join(basepath, 'all.zip')
    try:
        os.remove(zip_file)
    except OSError:
        pass
    os.system('zip -j %s %s' % (zip_file, ' '.join(files)))
    return str(count)


@bp.route('/remove-waiting-tasks', methods=['POST'])
@admin_required
def remove_waiting_tasks():
    # delete every task that no room has ever claimed
    Task.delete().where(Task.id.not_in(Room.select(Room.task))).execute()
    return 'OK'
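
/import-all consumes JSON files dropped into results/input. A sample of the expected shape, inferred from the transform above (all values are invented; the keys 领域 / 约束条件 / 需求信息 / 预订信息 mean domain / constraints / requested info / booking info):

    import json

    sample_tasks = [{
        'description': ['预订一家评分 4.5 分以上的酒店'],  # shown to the user side
        'goals': [{
            'id': 1,
            '领域': '酒店',
            '约束条件': [['评分', '4.5 分以上']],
            '需求信息': [['名称', ''], ['地址', '']],
            '预订信息': [],
        }],
    }]

    with open('data_labelling/results/input/demo.json', 'w', encoding='utf-8') as f:
        json.dump(sample_tasks, f, ensure_ascii=False, indent=2)
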
--------------------------------------------------------------------------------
/web/data_labelling/routes/room.py:
--------------------------------------------------------------------------------
import json

from flask import Blueprint, render_template, redirect, request, jsonify, g
from flask_socketio import emit

from ..utils import *
from ..models import *
from ..match_making import leave_room

bp = Blueprint('room_bp', __name__)


@bp.route('/<int:room_id>/messages')
@room_guard
def get_messages(room):
    def wrap(message):
        # expose only whether a payload exists, not its content
        message['payload'] = bool(message['payload'])
        return message
    return jsonify([wrap(message) for message in room.messages.select(Message.id, Message.role, Message.content, Message.payload).dicts()])


@bp.route('/<int:room_id>/<int:role>/message/content', methods=['POST'])
@room_guard
def post_message_content(room, role):
    try:
        data = json.loads(request.get_data())
        content = data['content']
    except (ValueError, KeyError, TypeError):
        return '格式错误', 400  # malformed request body
    lastmsg = Message.select().where(Message.room == room).order_by(-Message.id).first()
    if lastmsg is not None:  # the first message of a room needs no checks
        if lastmsg.role == role:
            return '不得连续两次发送消息', 403  # no two messages in a row
        elif not lastmsg.payload:
            return '对方尚未提交表单', 403  # peer has not submitted the form yet
    Message.create(room=room, role=role, content=content, payload={})
    emit('update', namespace='/room/%s' % room.id, broadcast=True)
    return 'OK'


@bp.route('/<int:room_id>/<int:role>/message/payload', methods=['POST'])
@room_guard
def post_message_payload(room, role):
    try:
        data = json.loads(request.get_data())
        payload = data['payload']
    except (ValueError, KeyError, TypeError):
        return '格式错误', 400  # malformed request body
    lastmsg = Message.select().where(Message.room == room).order_by(-Message.id).first()
    if lastmsg is None or lastmsg.role != role:
        return '上一次不是己方发送消息', 403  # last message is not ours
    elif lastmsg.payload:
        return '已经提交表单', 403  # form already submitted
    lastmsg.payload = payload
    lastmsg.save()
    emit('update', namespace='/room/%s' % room.id, broadcast=True)
    return 'OK'


@bp.route('/<int:room_id>/<int:role>')
@room_guard
def room(room, role):
    return render_template('room.html', room=room, role=role)


@bp.route('/<int:room_id>/leave')
@room_guard
def leave(room):
    if not room.user1 == g.me:
        return '只有用户可以结束任务', 400  # only the user side may finish a task
    lastmsg = Message.select().where((Message.room == room) & (Message.role == 1)).order_by(-Message.id).first()
    if not lastmsg:
        return '还没消息', 403  # no message yet
    try:
        for item in lastmsg.payload:
            if not item[3]:  # item = [goal id, domain, slot, value]
                return '表单没填完', 403  # the form is incomplete
    except (TypeError, IndexError):
        pass  # malformed payload: fall through and allow closing

    room.status = Room.Status.SUCCESS
    room.save()

    close_room(room)

    emit('finished', namespace='/room/%s' % room.id, broadcast=True)
    return 'OK'


@bp.route('/<int:room_id>/abort')
@room_guard
def abort(room):
    task = room.task
    task.finished = False  # release the task so it can be claimed again
    task.save()

    room.status = Room.Status.ABORTED
    room.save()

    close_room(room)

    emit('finished', namespace='/room/%s' % room.id, broadcast=True)
    return 'OK'


def close_room(room):
    leave_room(room.user0.id)
    leave_room(room.user1.id)
    room.user0.updateTasksCount()
    room.user1.updateTasksCount()
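
The two POST endpoints enforce a strict turn protocol: a side first sends an utterance, then attaches its annotation payload to that same utterance, and only then may the other side speak. A hypothetical user-side turn, sketched with `requests` (address and payload values are invented; the payload's item shape follows the checks in leave()):

    import requests

    BASE = 'http://localhost:5000/room'  # assumed development address
    room_id, role = 1, 1                 # the user side of a matched room

    s = requests.Session()               # assumed to carry a logged-in session
    s.post('%s/%d/%d/message/content' % (BASE, room_id, role),
           json={'content': '你好,我想订一家评分高的酒店。'})
    s.post('%s/%d/%d/message/payload' % (BASE, room_id, role),
           json={'payload': [[1, '酒店', '评分', '4.5 分以上']]})
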
--------------------------------------------------------------------------------
/web/data_labelling/routes/services.py:
--------------------------------------------------------------------------------
import re
import json
import bcrypt

from flask import Blueprint, render_template, redirect, request, session

from ..utils import *
from ..models import *

bp = Blueprint('services_bp', __name__)


@bp.route('/resetdb', methods=['POST'])
@admin_required
def service_resetdb():
    resetdb()

    return 'OK'


@bp.route('/register', methods=['POST'])
def service_register():
    try:
        data = json.loads(request.get_data())
        username = data['username']
        password = data['password']
        invitation_code = data['invitationCode']
    except (ValueError, KeyError, TypeError):
        return '格式错误', 400  # malformed request body

    if invitation_code != g.invitation_code:
        return '邀请码不正确', 403  # wrong invitation code

    if not re.fullmatch(r'\w+', username):
        return '用户名仅含字母数字下划线', 400  # letters, digits, underscore only
    if not (2 <= len(username) <= 16):
        return '用户名的长度范围 [2, 16]', 400  # username length in [2, 16]
    if not (8 <= len(password) <= 32):
        return '密码的长度范围 [8, 32]', 400  # password length in [8, 32]
    if User.select().where(User.username == username):
        return '用户名已被使用', 422  # username taken

    User.create(username=username, password_hash=bcrypt.hashpw(str.encode(password), bcrypt.gensalt()))

    return 'OK'


@bp.route('/login', methods=['POST'])
def services_login():
    try:
        data = json.loads(request.get_data())
        username = data['username']
        password = data['password']
    except (ValueError, KeyError, TypeError):
        return '格式错误', 400  # malformed request body

    try:
        user = User.get(User.username == username)
    except User.DoesNotExist:
        return '用户名或密码错误', 401  # same reply for unknown user and bad password
    if not bcrypt.checkpw(str.encode(password), str.encode(user.password_hash)):
        return '用户名或密码错误', 401

    session['user_id'] = user.id
    session.permanent = True

    return 'OK'


@bp.route('/logout', methods=['POST'])
@login_required
def service_logout():
    session.pop('user_id')

    return 'OK'
--------------------------------------------------------------------------------
/web/data_labelling/settings.example.py:
--------------------------------------------------------------------------------
from datetime import timedelta

SECRET_KEY = 'SECRET_KEY_FOR_FLASK_SESSION'
PERMANENT_SESSION_LIFETIME = timedelta(days=30)

FLASK_ADMIN_SWATCH = 'united'
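
Besides settings.py, the app reads its environment from a .env file (see .env.example in the tree). The three variables referenced in models.py and redis.py, with illustrative values:

    DATABASE_PATH=data.db
    # optional: path to sqlite's loadable json1 extension, if not built in
    SQLITE_EXT_JSON1_PATH=
    REDIS_URL=redis://localhost:6379/0
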
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-brands-400.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-brands-400.eot
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-brands-400.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-brands-400.ttf
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-brands-400.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-brands-400.woff
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-brands-400.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-brands-400.woff2
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-regular-400.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-regular-400.eot
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-regular-400.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-regular-400.ttf
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-regular-400.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-regular-400.woff
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-regular-400.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-regular-400.woff2
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-solid-900.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-solid-900.eot
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-solid-900.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-solid-900.ttf
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-solid-900.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-solid-900.woff
--------------------------------------------------------------------------------
/web/data_labelling/static/webfonts/fa-solid-900.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/CrossWOZ/df82c9fdff91b9b130f2d6b89110d3870ba6260e/web/data_labelling/static/webfonts/fa-solid-900.woff2
--------------------------------------------------------------------------------
/web/data_labelling/templates/admin/index.html:
--------------------------------------------------------------------------------
{% extends 'admin/master.html' %}
{% block access_control %}
{# markup lost in extraction (login-state controls) #}
{% endblock %}
{% block body %}
{# The admin action list's HTML was lost in extraction. Recoverable items:
   - 注册邀请码 (registration invitation code): {{ g.invitation_code }}
   - 导入任务 (import tasks): upload task-config JSON files to the input
     folder; the files are deleted after import. "点此导入" (import now) link.
   - 导出数据 (export data): everything is written to the output folder, one
     file per 100 tasks in import order. "点此导出" (export) and "下载全部"
     (download all) links.
   - 删除未标注任务 (remove unlabelled tasks): "点此执行" (run) link.
   - "没有正在执行的管理操作" (no admin operation in progress) status line. #}
{% endblock %}

{% block tail %}
{# page script lost in extraction #}
{% endblock %}
--------------------------------------------------------------------------------
/web/data_labelling/templates/base.html:
--------------------------------------------------------------------------------
{# The HTML skeleton (doctype, head with the CSS/JS includes, body wrapper)
   was lost in extraction; only the Jinja hooks survive: #}
{% block title %}{% endblock %}
{% block style %}{% endblock %}
{% block content %}{% endblock %}
{% block script %}{% endblock %}
--------------------------------------------------------------------------------
/web/data_labelling/templates/chatbox.html:
--------------------------------------------------------------------------------
{# This template did not survive extraction at all; room.html imports it and
   calls its render(room, role), script(room, role) and style() macros. #}
--------------------------------------------------------------------------------
/web/data_labelling/templates/clientside.html:
--------------------------------------------------------------------------------
{% macro render1(room) %}
{# surrounding card markup, titled 任务描述 (task description), was lost in
   extraction; the macro lists every line of the task description: #}
{% for line in room.task.content['description'] %}
{{ line }}
{% endfor %}
{% endmacro %}

{% macro render2(room) %}
{# the goal-annotation table's markup was lost in extraction; the surviving
   header cells are id and 领域 (domain) #}
{% endmacro %}

{% macro script(room) %}
{# page script (original lines 44-93) lost in extraction #}
{% endmacro %}

{% macro style() %}
{# styles (original lines 96-103) lost in extraction #}
{% endmacro %}
--------------------------------------------------------------------------------
/web/data_labelling/templates/dashboard.html:
--------------------------------------------------------------------------------
{% extends 'base.html' %}
{% import 'heading.html' as heading %}

{% block title %} 控制面板 {% endblock %}  {# "dashboard" #}

{% block content %}
{{ heading.render() }}
{# demo buttons lost in extraction, followed by a 注意 (note) list:
   - the buttons above do not exist in the final version; they are only for
     the staged demo
   - so far mainly the chatroom is implemented, without the extra checks,
     which the next update will complete
   - to get a feel for the chat, log in as root / root on several machines
     (or in several tabs) and take the system and user ends separately #}
{% endblock %}

{% block script %}
{# page script lost in extraction #}
{% endblock %}
--------------------------------------------------------------------------------
/web/data_labelling/templates/heading.html:
--------------------------------------------------------------------------------
{% macro render() %}
{# the shared page-heading markup (original lines 2-27) was lost in
   extraction #}
{% endmacro %}
--------------------------------------------------------------------------------
/web/data_labelling/templates/index.html:
--------------------------------------------------------------------------------
{% extends 'base.html' %}
{% import 'heading.html' as heading %}

{% block title %} 首页 {% endblock %}  {# "home" #}

{% block content %}
{{ heading.render() }}
{# role-selection markup lost in extraction; the surviving prompt is
   "选择您偏好的角色:" (choose your preferred role) #}
{% endblock %}
--------------------------------------------------------------------------------
/web/data_labelling/templates/login.html:
--------------------------------------------------------------------------------
{% extends "base.html" %}

{% block title %} 登录 {% endblock %}  {# "log in" #}

{% block content %}
{# form markup lost in extraction: a card titled 登录 · 任务导向对话系统
   (log in · task-oriented dialogue system) with username and password
   fields plus submit / register buttons #}
{% endblock %}

{% block script %}
{# login script (original lines 44-74) lost in extraction #}
{% endblock %}
--------------------------------------------------------------------------------
/web/data_labelling/templates/match.html:
--------------------------------------------------------------------------------
{% extends 'base.html' %}
{% import 'heading.html' as heading %}

{% block title %} 匹配 {% endblock %}  {# "match" #}

{% block style %}
{# styles lost in extraction #}
{% endblock %}

{% block content %}
{{ heading.render() }}
{# matching markup lost in extraction: a card titled 寻找匹配 (find a match)
   with the prompt "选择一个偏好的角色,与其它用户进行匹配。" (pick a preferred
   role to be matched with another user) and two panels, 系统 (system) and
   用户 (user), each with an enter-queue button and a 当前等待人数 (currently
   waiting) counter #}
{% endblock %}

{% block script %}
{# queue-polling script (original lines 55-160) lost in extraction #}
{% endblock %}
--------------------------------------------------------------------------------
/web/data_labelling/templates/register.html:
--------------------------------------------------------------------------------
{% extends "base.html" %}

{% block title %} 注册 {% endblock %}  {# "register" #}

{% block content %}
{# form markup lost in extraction: a card titled 注册 · 任务导向对话系统
   (register · task-oriented dialogue system) with four fields — username,
   password, what is likely a password confirmation, and the invitation
   code — plus submit / back-to-login buttons #}
{% endblock %}

{% block script %}
{# registration script (original lines 56-93) lost in extraction #}
{% endblock %}
--------------------------------------------------------------------------------
/web/data_labelling/templates/room.html:
--------------------------------------------------------------------------------
{% extends 'base.html' %}
{% import 'heading.html' as heading %}
{% import 'chatbox.html' as chatbox %}
{% import 'systemside.html' as systemside %}
{% import 'clientside.html' as clientside %}
{% set ourside = clientside if role else systemside %}

{% block title %} 任务 - {{ '用户' if role else '系统' }}端 {% endblock %}  {# task: user / system side #}

{% block content %}
{# layout wrappers lost in extraction; the panels appear in this order: #}
{{ heading.render() }}
{{ ourside.render1(room) }}
{{ ourside.render2(room) }}
{{ chatbox.render(room, role) }}
{% endblock %}

{% block script %}
{{ ourside.script(room) }}
{{ chatbox.script(room, role) }}
{% endblock %}

{% block style %}
{{ ourside.style() }}
{{ chatbox.style() }}
{% endblock %}
--------------------------------------------------------------------------------
/web/data_labelling/utils.py:
--------------------------------------------------------------------------------
import bcrypt

from flask import redirect, abort, g
from functools import wraps

from .models import *


def resetdb():
    all_tables = [
        User, Task, Room, Message
    ]
    db.drop_tables(all_tables)
    db.create_tables(all_tables)

    # seed the default admin account (username: root, password: root)
    username = 'root'
    password = 'root'
    password_hash = bcrypt.hashpw(str.encode(password), bcrypt.gensalt())
    User.create(username=username, password_hash=password_hash, is_admin=True)


def login_required(f):
    @wraps(f)
    def decorated_f(*args, **kwargs):
        if not g.me:
            return redirect('/login')
        return f(*args, **kwargs)

    return decorated_f


def admin_required(f):
    @wraps(f)
    def decorated_f(*args, **kwargs):
        if not g.me or not g.me.is_admin:
            return redirect('/')
        return f(*args, **kwargs)

    return decorated_f


def room_guard(f):
    # resolves room_id to a Room, checks that the current user actually plays
    # the claimed role in it, and rejects rooms that are no longer running
    @wraps(f)
    def decorated_f(room_id, **kwargs):
        if not g.me:
            return redirect('/')
        try:
            room = Room.get(Room.id == room_id)
            if 'role' in kwargs:
                role = kwargs['role']
                assert role in [0, 1]
                assert (g.me == (room.user0 if role == 0 else room.user1))
        except (Room.DoesNotExist, AssertionError):
            return abort(404)
        if room.status != Room.Status.RUNNING:
            return redirect('/')
        return f(room, **kwargs)

    return decorated_f
--------------------------------------------------------------------------------
/web/ecosystem.config.js:
--------------------------------------------------------------------------------
module.exports = {
  apps: [
    {
      name: "data_labelling",
      script: "run.py",
      interpreter: "venv/bin/python"
    }
  ]
};
--------------------------------------------------------------------------------
/web/requirements.txt:
--------------------------------------------------------------------------------
python-dotenv
flask
flask-socketio
flask-admin
flask-login
python-redis-lock
bcrypt
peewee
wtf-peewee
eventlet
redis
--------------------------------------------------------------------------------
/web/resetdb.py:
--------------------------------------------------------------------------------
from data_labelling.utils import resetdb

if __name__ == '__main__':
    resetdb()
--------------------------------------------------------------------------------
/web/run.py:
--------------------------------------------------------------------------------
from data_labelling import run

if __name__ == '__main__':
    run()
--------------------------------------------------------------------------------
/web/setup.bat:
--------------------------------------------------------------------------------
pip install -r requirements.txt
--------------------------------------------------------------------------------
/web/setup.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
pip install -r requirements.txt
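
Putting the pieces together, a minimal local bootstrap sketch (assuming .env and data_labelling/settings.py exist — copy them from .env.example and settings.example.py first — and that a Redis server is running):

    from data_labelling.utils import resetdb
    from data_labelling import run

    if __name__ == '__main__':
        resetdb()  # drop and recreate all tables; seeds the root/root admin
        run()      # start the web app (equivalent to `python run.py`)
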
--------------------------------------------------------------------------------