├── simdial ├── agent │ ├── __init__.py │ ├── core.py │ ├── nlg.py │ ├── user.py │ └── system.py ├── config.py ├── __init__.py ├── database.py ├── channel.py ├── domain.py ├── generator.py └── complexity.py ├── README.md ├── .gitignore ├── LICENSE └── multiple_domains.py /simdial/agent/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # author: Tiancheng Zhao 3 | -------------------------------------------------------------------------------- /simdial/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # author: Tiancheng Zhao 3 | 4 | 5 | class Config(object): 6 | debug = False  # False: simdial/__init__.py logs to 'simdial.log'; True: basicConfig(filename=None) logs to stderr 7 | -------------------------------------------------------------------------------- /simdial/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # author: Tiancheng Zhao 3 | import logging 4 | from simdial.config import Config 5 | # importing the package configures root logging at DEBUG level 6 | logging.basicConfig(filename='simdial.log' if Config.debug is False else None, level=logging.DEBUG, 7 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CMU SimDial: Synthetic Task-oriented Dialog Generator with Controllable Complexity 2 | This is the dialog data used 3 | by our SIGDIAL 2018 paper: [Zero-Shot Dialog Generation with Cross-Domain Latent Actions](https://arxiv.org/abs/1805.04803). 4 | See paper for details. The source code and data used for the paper can be found at [here](https://github.com/snakeztc/NeuralDialog-ZSDG).
5 | 6 | ## Prerequisites 7 | - Python 2.7 8 | - Numpy 9 | - NLTK 10 | - progressbar 11 | 12 | 13 | ## Usage 14 | Run the following command to generate dialog data for the multiple domains defined in the 15 | *multiple_domains.py* script. 16 | 17 | python multiple_domains.py 18 | 19 | The data will be saved into two folders: 20 | 21 | test/ for testing data 22 | train/ for training data 23 | 24 | ## References 25 | If you use any source code or datasets included in this toolkit in your work, please cite the following paper. The BibTeX entry is listed below: 26 | 27 | @article{zhao2018zero, 28 | title={Zero-Shot Dialog Generation with Cross-Domain Latent Actions}, 29 | author={Zhao, Tiancheng and Eskenazi, Maxine}, 30 | journal={arXiv preprint arXiv:1805.04803}, 31 | year={2018} 32 | } 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .idea/ 106 | -------------------------------------------------------------------------------- /simdial/database.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | 5 | class Database(object): 6 | """ 7 | A table-based database class. Each row is an entry and each column is an attribute. Each attribute 8 | has a vocabulary size, called its modality.
9 | 10 | :ivar usr_dirichlet_priors: the prior for each attribute : 2D list [[]*modality] 11 | :ivar num_usr_slots: the number of columns: Int 12 | :ivar usr_modalities: the vocab size of each column : List 13 | :ivar usr_pdf: the PDF for each column : 2D list 14 | :ivar num_rows: the number of entries 15 | :ivar table: the content : 2D list [[] *num_rows] 16 | :ivar indexes: for efficient SELECT : [{value_id -> set of matching row ids}] per searchable column 17 | """ 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | def __init__(self, usr_dirichlet_priors, sys_dirichlet_priors, num_rows): 22 | """ 23 | :param usr_dirichlet_priors: 2D list [[]_0, []_1, ... []_k], one prior per searchable attribute 24 | :param sys_dirichlet_priors: 2D list, one prior per non-searchable attribute 25 | :param num_rows: the number of rows in the database 26 | """ 27 | self.usr_dirichlet_priors = usr_dirichlet_priors 28 | self.sys_dirichlet_priors = sys_dirichlet_priors 29 | 30 | self.num_usr_slots = len(usr_dirichlet_priors) 31 | self.usr_modalities = [len(p) for p in usr_dirichlet_priors] 32 | 33 | self.num_sys_slots = len(sys_dirichlet_priors) 34 | self.sys_modalities = [len(p) for p in sys_dirichlet_priors] 35 | 36 | # sample attr_pdf for each attribute from the dirichlet prior 37 | self.usr_pdf = [np.random.dirichlet(d_p) for d_p in self.usr_dirichlet_priors] 38 | self.sys_pdf = [np.random.dirichlet(d_p) for d_p in self.sys_dirichlet_priors] 39 | self.num_rows = num_rows 40 | 41 | # begin to generate the table 42 | usr_table, usr_index = self._gen_table(self.usr_pdf, self.usr_modalities, self.num_usr_slots, num_rows) 43 | sys_table, sys_index = self._gen_table(self.sys_pdf, self.sys_modalities, self.num_sys_slots, num_rows) 44 | 45 | # prepend the row UID as the first column of the system table 46 | sys_table.insert(0, range(self.num_rows)) 47 | 48 | self.table = np.array(usr_table).transpose() 49 | self.indexes = usr_index 50 | self.sys_table = np.array(sys_table).transpose() 51 | 52 | @staticmethod 53 | def
_gen_table(pdf, modalities, num_cols, num_rows): 54 | list_table = [] 55 | indexes = [] 56 | for idx in range(num_cols): 57 | col = np.random.choice(range(modalities[idx]), p=pdf[idx], size=num_rows) 58 | list_table.append(col) 59 | # indexing 60 | index = {} 61 | for m_id in range(modalities[idx]): 62 | matched_list = np.squeeze(np.argwhere(col == m_id)).tolist() 63 | matched_list = set(matched_list) if type(matched_list) is list else {matched_list} 64 | index[m_id] = matched_list 65 | indexes.append(index) 66 | return list_table, indexes 67 | 68 | def sample_unique_row(self): 69 | """ 70 | :return: a unique row in the searchable table 71 | """ 72 | unique_rows = np.unique(self.table, axis=0) 73 | idxes = range(len(unique_rows)) 74 | np.random.shuffle(idxes) 75 | return unique_rows[idxes[0]] 76 | 77 | def select(self, query, return_index=False): 78 | """ 79 | Filter the database entries according the query. 80 | 81 | :param query: 1D [] equal to the number of attributes, None means don't care 82 | :param return_index: if return the db index 83 | :return return a list system_entries and (optional)index that satisfy all constrains 84 | 85 | """ 86 | valid_idx = set(range(self.num_rows)) 87 | for q, a_id in zip(query, range(self.num_usr_slots)): 88 | if q: 89 | valid_idx -= self.indexes[a_id][q] 90 | if len(valid_idx) == 0: 91 | break 92 | valid_idx = list(valid_idx) 93 | if return_index: 94 | return self.sys_table[valid_idx, :], valid_idx 95 | else: 96 | return self.sys_table[valid_idx, :] 97 | 98 | def pprint(self): 99 | """ 100 | print statistics of the database in a beautiful format. 
101 | """ 102 | 103 | self.logger.info("DB contains %d rows (%d unique ones), with %d attributes" 104 | % (self.num_rows, len(np.unique(self.table, axis=0)), self.num_usr_slots)) 105 | -------------------------------------------------------------------------------- /simdial/channel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # author: Tiancheng Zhao 3 | import numpy as np 4 | from simdial.agent.core import UserAct, BaseUsrSlot 5 | import copy 6 | 7 | 8 | class AbstractNoise(object): 9 | def __init__(self, domain, complexity): 10 | self.complexity = complexity 11 | self.domain = domain 12 | 13 | def transmit(self, actions): 14 | raise NotImplementedError 15 | 16 | def transmit_words(self, utt): 17 | pass 18 | 19 | 20 | class EnvironmentNoise(AbstractNoise): 21 | def __init__(self, domain, complexity): 22 | super(EnvironmentNoise, self).__init__(domain, complexity) 23 | self.dim_map = {slot.name: slot.dim for slot in domain.usr_slots} 24 | 25 | def transmit(self, actions): 26 | conf = np.random.normal(self.complexity.asr_acc, self.complexity.asr_std) 27 | conf = np.clip(conf, 0.1, 0.99) 28 | noisy_actions = [] 29 | # check has yes no 30 | has_confirm = False 31 | for a in actions: 32 | if a.act in [UserAct.DISCONFIRM, UserAct.CONFIRM]: 33 | has_confirm = True 34 | break 35 | if has_confirm: 36 | conf = np.clip(conf+0.1, 0.1, 0.99) 37 | 38 | for a in actions: 39 | if a.act == UserAct.CONFIRM: 40 | if np.random.rand() > conf: 41 | a.act = UserAct.DISCONFIRM 42 | elif a.act == UserAct.DISCONFIRM: 43 | if np.random.rand() > conf: 44 | a.act = UserAct.CONFIRM 45 | elif a.act == UserAct.INFORM: 46 | if np.random.rand() > conf: 47 | slot, value = a.parameters[0] 48 | choices = range(self.dim_map[slot]) + [None] 49 | a.parameters[0] = (slot, np.random.choice(choices)) 50 | 51 | noisy_actions.append(a) 52 | 53 | return noisy_actions, conf 54 | 55 | 56 | class InteractionNoise(AbstractNoise): 57 | 58 | 
58 | def transmit(self, actions): 59 | return self.add_self_correct(actions) 60 | 61 | def transmit_words(self, utt): 62 | # hesitation: maybe insert a filler word 63 | utt = self.add_hesitation(utt) 64 | 65 | # self-restart: maybe repeat a short prefix before the full utterance 66 | return self.add_self_restart(utt) 67 | 68 | def add_hesitation(self, utt): 69 | tokens = utt.split(" ") 70 | if len(tokens) > 4 and np.random.rand() < self.complexity.hesitation: 71 | pos = np.random.randint(1, len(tokens)-1)  # the filler never becomes the first or last token 72 | tokens.insert(pos, np.random.choice(["hmm", "uhm", "hmm ...",])) 73 | return " ".join(tokens) 74 | return utt 75 | 76 | def add_self_restart(self, utt): 77 | tokens = utt.split(" ") 78 | if len(tokens) > 4 and np.random.rand() < self.complexity.self_restart: 79 | length = np.random.randint(1, 3) 80 | tokens = tokens[0:length] + ["uhm yeah"] + tokens  # first 1-2 tokens, "uhm yeah", then the whole utterance 81 | return " ".join(tokens) 82 | return utt 83 | 84 | def add_self_correct(self, actions): 85 | for a in actions: 86 | if a.act == UserAct.INFORM and np.random.rand() < self.complexity.self_correct: 87 | a.parameters.append((BaseUsrSlot.SELF_CORRECT, True))  # flag this inform as self-corrected 88 | return actions 89 | 90 | 91 | class SocialNoise(AbstractNoise): 92 | def transmit(self, actions): 93 | return actions  # social noise is currently a no-op 94 | 95 | 96 | # Channels at action-level and word-level 97 | 98 | class ActionChannel(object): 99 | """ 100 | A class to simulate the complex behavior of human-computer conversation at the action level. 101 | """ 102 | 103 | def __init__(self, domain, complexity): 104 | self.environment = EnvironmentNoise(domain, complexity) 105 | self.interaction = InteractionNoise(domain, complexity) 106 | self.social = SocialNoise(domain, complexity) 107 | 108 | def transmit2sys(self, actions): 109 | """ 110 | Given a list of actions from the user to the system, add noise to copies of the actions. 111 | 112 | :param actions: a list of clean actions from the user to the system 113 | :return: a tuple (list of corrupted actions, simulated ASR confidence).
114 | """ 115 | action_copy = [copy.deepcopy(a) for a in actions]  # the noise models mutate actions in place, so work on deep copies 116 | noisy_actions = self.interaction.transmit(action_copy) 117 | noisy_actions = self.social.transmit(noisy_actions) 118 | noisy_actions, conf = self.environment.transmit(noisy_actions) 119 | return noisy_actions, conf 120 | 121 | 122 | class WordChannel(object): 123 | """ 124 | A class to simulate the complex behavior of human-computer conversation at the word level. 125 | """ 126 | 127 | def __init__(self, domain, complexity): 128 | self.interaction = InteractionNoise(domain, complexity) 129 | 130 | def transmit2sys(self, utt): 131 | """ 132 | Given a user utterance string, add word-level noise (hesitation, self-restart) to it. 133 | 134 | :param utt: the clean user utterance 135 | :return: the noisy utterance string. 136 | """ 137 | return self.interaction.transmit_words(utt) -------------------------------------------------------------------------------- /simdial/domain.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # author: Tiancheng Zhao 3 | from simdial.database import Database 4 | import numpy as np 5 | from simdial.agent.core import BaseSysSlot 6 | import logging 7 | 8 | 9 | class DomainSpec(object): 10 | """ 11 | Abstract specification template. 12 | 13 | :cvar usr_slots: [(slot_name, slot_description, vocabulary) ...] 14 | :cvar sys_slots: [(slot_name, slot_description, vocabulary) ...]
15 | :cvar nlg_spec: {slot_name -> {inform: [], request: [], yn_question (optional): {expected_value -> [utt, ...]}}} 16 | :cvar db_size: the number of rows in the database 17 | """ 18 | nlg_spec = None 19 | usr_slots = None 20 | sys_slots = None 21 | db_size = None 22 | name = None 23 | greet = None 24 | 25 | def to_dict(self):  # plain-dict form, stored as 'meta' when a corpus is dumped to JSON 26 | return {'nlg_spec': self.nlg_spec, 27 | 'usr_slots': self.usr_slots, 28 | 'sys_slots': self.sys_slots, 29 | 'db_size': self.db_size, 30 | 'name': self.name, 31 | 'greet': self.greet} 32 | 33 | 34 | class Slot(object): 35 | """ 36 | A system or user slot: name, description, value vocabulary and its NLG utterance pools 37 | """ 38 | def __init__(self, name, description, vocabulary): 39 | self.name = name 40 | self.description = description 41 | self.vocabulary = vocabulary 42 | self.dim = len(vocabulary) 43 | self.requests = [] 44 | self.informs = [] 45 | self.yn_questions = {} 46 | 47 | def sample_request(self): 48 | if self.requests: 49 | return np.random.choice(self.requests) 50 | else: 51 | raise ValueError("Sample from empty request_utt pool") 52 | 53 | def sample_inform(self): 54 | if self.informs: 55 | return np.random.choice(self.informs) 56 | else: 57 | raise ValueError("Sample from empty inform_utt pool") 58 | 59 | def sample_yn_question(self, expect_val): 60 | questions = self.yn_questions.get(expect_val, []) 61 | if questions: 62 | return np.random.choice(questions) 63 | else: 64 | raise ValueError("Sample from empty yn_questions pool") 65 | 66 | def sample_different(self, value): 67 | if value is None:  # any concrete value id 68 | return np.random.randint(0, self.dim) 69 | else:  # None (don't care) or any value id other than `value` 70 | return np.random.choice([None] + [i for i in range(self.dim) if i != value]) 71 | 72 | 73 | class Domain(object): 74 | """ 75 | A class that contains sufficient info about a slot-filling domain. Including: 76 | 77 | :ivar db: table with N items, each has I+R attributes 78 | :ivar sys_slots: a list of slots that the system can tell the users.
Each slot is a Slot object 79 | (name, description, vocabulary); index 0 is the BaseSysSlot.DEFAULT row-id slot 80 | :ivar usr_slots: a list of slots on which users can impose constraints. Each slot is a Slot object 81 | (name, description, vocabulary) 82 | """ 83 | 84 | logger = logging.getLogger(__name__) 85 | 86 | def __init__(self, domain_spec): 87 | """ 88 | :param domain_spec: an implementation of DomainSpec (slot names get a '#' prefix here) 89 | """ 90 | self.name = domain_spec.name 91 | self.greet = domain_spec.greet 92 | self.usr_slots = [Slot("#"+name, desc, vocab) for name, desc, vocab in domain_spec.usr_slots] 93 | self.sys_slots = [Slot("#"+name, desc, vocab) for name, desc, vocab in domain_spec.sys_slots] 94 | self.sys_slots.insert(0, Slot(BaseSysSlot.DEFAULT, "", [str(i) for i in range(domain_spec.db_size)])) 95 | 96 | for slot_name, slot_nlg in domain_spec.nlg_spec.items(): 97 | slot_name = "#"+slot_name 98 | slot = self.get_usr_slot(slot_name) if self.is_usr_slot(slot_name) else self.get_sys_slot(slot_name) 99 | if slot: 100 | slot.informs.extend(slot_nlg['inform']) 101 | slot.requests.extend(slot_nlg['request']) 102 | slot.yn_questions = slot_nlg.get('yn_question', {}) 103 | else: 104 | raise Exception("Fail to align %s nlg spec with the rest of domain" % slot_name) 105 | usr_slot_priors = [np.ones(s.dim) for s in self.usr_slots] # we assume a uniform prior 106 | # DEFAULT is left out of the priors since it is the KEY (row id) column 107 | sys_slot_priors = [np.ones(s.dim) for s in self.sys_slots[1:]] 108 | 109 | self.db = Database(usr_slot_priors, sys_slot_priors, num_rows=domain_spec.db_size) 110 | self.db.pprint() 111 | 112 | def get_usr_slot(self, slot_name, return_idx=False): 113 | """ 114 | :param slot_name: the target slot name 115 | :param return_idx: True/False to return slot index 116 | :return: the slot, or (slot, index) if return_idx; None if it's not a user slot 117 | """ 118 | for s_id, s in enumerate(self.usr_slots): 119 | if s.name == slot_name: 120 | if return_idx: 121 | return s, s_id 122 | else: 123 | return s 124 | return
None 125 | 126 | def get_sys_slot(self, slot_name, return_idx=False): 127 | """ 128 | :param slot_name: the target slot name 129 | :param return_idx: True/False to return slot index 130 | :return: the slot, or (slot, index) if return_idx; None if it's not a system slot 131 | """ 132 | for s_id, s in enumerate(self.sys_slots): 133 | if s.name == slot_name: 134 | if return_idx: 135 | return s, s_id 136 | else: 137 | return s 138 | return None 139 | 140 | def is_usr_slot(self, query_name): 141 | """ 142 | :param query_name: a slot name 143 | :return: True if query_name is a user slot name, False o/w 144 | """ 145 | return query_name in [s.name for s in self.usr_slots] 146 | -------------------------------------------------------------------------------- /simdial/agent/core.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # author: Tiancheng Zhao 3 | 4 | import logging 5 | import copy 6 | 7 | 8 | class Agent(object): 9 | """ 10 | Abstract class for Agent (user or system) 11 | """ 12 | 13 | def __init__(self, domain, complexity): 14 | self.domain = domain 15 | self.complexity = complexity 16 | 17 | def step(self, *args, **kwargs): 18 | """ 19 | Given the new inputs, generate the next response 20 | 21 | :return: reward, terminal, response 22 | """ 23 | raise NotImplementedError("Implement step function is required") 24 | 25 | 26 | class Action(dict): 27 | """ 28 | A generic class that corresponds to a discourse unit. An action is made of an Act and a list of parameters; it is also a dict with keys 'act' and 'parameters'. 29 | 30 | :ivar act: dialog act String (see SystemAct / UserAct) 31 | :ivar parameters: [{slot -> usr_constrain}, {sys_slot -> value}] for INFORM, and [(type, value)...] for other acts.
32 | 33 | """ 34 | 35 | def __init__(self, act, parameters=None): 36 | self.act = act 37 | if parameters is None: 38 | self.parameters = [] 39 | elif type(parameters) is not list:  # wrap a single parameter into a list 40 | self.parameters = [parameters] 41 | else: 42 | self.parameters = parameters 43 | super(Action, self).__init__(act=self.act, parameters=self.parameters)  # dict view: 'parameters' shares the list above, but 'act' is copied once, so reassigning self.act later does not update self['act'] 44 | 45 | def add_parameter(self, type, value): 46 | self.parameters.append((type, value)) 47 | 48 | def dump_string(self): 49 | str_paras = [] 50 | for p in self.parameters: 51 | if type(p) is not str: 52 | str_paras.append(str(p)) 53 | else: 54 | str_paras.append(p) 55 | str_paras = "-".join(str_paras) 56 | return "%s:%s" % (self.act, str_paras)  # e.g. 'act:param1-param2' 57 | 58 | 59 | class State(object): 60 | """ 61 | The base class for a dialog state 62 | 63 | :ivar history: a list of (speaker, actions) turns 64 | :cvar USR: user name 65 | :cvar SYS: system name 66 | :cvar LISTEN: the agent is waiting for other's input 67 | :cvar SPEAK: the agent is generating its output 68 | :cvar EXIT: the agent leaves the session 69 | """ 70 | 71 | USR = "usr" 72 | SYS = "sys" 73 | 74 | LISTEN = "listen" 75 | SPEAK = "speak" 76 | EXIT = "exit" 77 | 78 | def __init__(self): 79 | self.history = [] 80 | 81 | def yield_floor(self, *args, **kwargs): 82 | """ 83 | Base function that decides if the agent should yield the conversation floor 84 | """ 85 | raise NotImplementedError("Yield is required") 86 | 87 | def is_terminal(self, *args, **kwargs): 88 | """ 89 | Base function that decides if the agent has left the session 90 | """ 91 | raise NotImplementedError("is_terminal is required") 92 | 93 | def last_actions(self, target_speaker): 94 | """ 95 | Search the dialog history backwards for the given speaker. 96 | 97 | :param target_speaker: the target speaker 98 | :return: the actions of the last turn produced by the given speaker. None if not found.
99 | """ 100 | for spk, utt in self.history[::-1]:  # newest turn first 101 | if spk == target_speaker: 102 | return utt 103 | return None 104 | 105 | def update_history(self, speaker, actions): 106 | """ 107 | Append a new (speaker, actions) turn to the history 108 | 109 | :param speaker: SYS or USR 110 | :param actions: a list of Action 111 | """ 112 | # deep copy so later mutation of the caller's actions cannot rewrite the history 113 | self.history.append((speaker, copy.deepcopy(actions))) 114 | 115 | 116 | class SystemAct(object): 117 | """ 118 | :cvar IMPLICIT_CONFIRM: you said XX 119 | :cvar EXPLICIT_CONFIRM: do you mean XX 120 | :cvar INFORM: I think XX is a good fit 121 | :cvar REQUEST: which location? 122 | :cvar GREET: hello 123 | :cvar GOODBYE: goodbye 124 | :cvar CLARIFY: I think you want either A or B. Which one is right? 125 | :cvar ASK_REPHRASE: can you please say it in another way? 126 | :cvar ASK_REPEAT: what did you say? 127 | """ 128 | 129 | IMPLICIT_CONFIRM = "implicit_confirm" 130 | EXPLICIT_CONFIRM = "explicit_confirm" 131 | INFORM = "inform" 132 | REQUEST = "request" 133 | GREET = "greet" 134 | GOODBYE = "goodbye" 135 | CLARIFY = "clarify" 136 | ASK_REPHRASE = "ask_rephrase" 137 | ASK_REPEAT = "ask_repeat" 138 | QUERY = "query" 139 | 140 | 141 | class UserAct(object): 142 | """ 143 | :cvar CONFIRM: yes 144 | :cvar DISCONFIRM: no 145 | :cvar YN_QUESTION: Is it going to rain? 146 | :cvar INFORM: I like Chinese food. 147 | :cvar REQUEST: find me a place to eat. 148 | :cvar GREET: hello 149 | :cvar NEW_SEARCH: I have a new request.
150 | :cvar GOODBYE: goodbye 151 | :cvar CHAT: how is your day 152 | """ 153 | GREET = "greet" 154 | INFORM = "inform" 155 | REQUEST = "request" 156 | YN_QUESTION = "yn_question" 157 | CONFIRM = "confirm" 158 | DISCONFIRM = "disconfirm" 159 | GOODBYE = "goodbye" 160 | NEW_SEARCH = "new_search" 161 | CHAT = "chat" 162 | SATISFY = "satisfy" 163 | MORE_REQUEST = "more_request" 164 | KB_RETURN = "kb_return" 165 | 166 | 167 | class BaseSysSlot(object): 168 | """ 169 | :cvar DEFAULT: the db entry (row id) slot 170 | :cvar PURPOSE: what's the purpose of the system 171 | """ 172 | 173 | PURPOSE = "#purpose" 174 | DEFAULT = "#default" 175 | 176 | 177 | class BaseUsrSlot(object): 178 | """ 179 | :cvar NEED: what the user wants 180 | :cvar HAPPY: if the user is satisfied with the system's results 181 | :cvar AGAIN: the user rephrases the same sentence. 182 | :cvar SELF_CORRECT: the user corrects themselves. 183 | """ 184 | NEED = "#need" 185 | HAPPY = "#happy" 186 | AGAIN = "#again" 187 | SELF_CORRECT = "#self_correct" 188 | -------------------------------------------------------------------------------- /simdial/generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # author: Tiancheng Zhao 3 | 4 | from simdial.agent.user import User 5 | from simdial.agent.system import System 6 | from simdial.channel import ActionChannel, WordChannel 7 | from simdial.agent.nlg import SysNlg, UserNlg 8 | from simdial.complexity import Complexity 9 | from simdial.domain import Domain 10 | import progressbar 11 | import json 12 | import numpy as np 13 | import sys 14 | import os 15 | import re 16 | 17 | class Generator(object): 18 | """ 19 | The generator class used to generate synthetic slot-filling human-computer conversation in any domain. 20 | The generator can be configured to generate data with varying complexity at the environment, propositional, interaction and social 21 | level.
22 | 23 | The required input is a DomainSpec + a ComplexitySpec class (see gen_corpus). 24 | """ 25 | 26 | @staticmethod 27 | def pack_msg(speaker, utt, **kwargs): 28 | resp = {k: v for k, v in kwargs.items()}  # copy kwargs; 'speaker' and 'utt' below override same-named keys 29 | resp["speaker"] = speaker 30 | resp["utt"] = utt 31 | return resp 32 | 33 | @staticmethod 34 | def pprint(dialogs, in_json, domain_spec, output_file=None): 35 | """ 36 | Print the dialogs to a file or STDOUT 37 | 38 | :param dialogs: a list of dialogs generated (in_json=True writes {'dialogs', 'meta': domain_spec.to_dict()} as JSON, else plain text) 39 | :param output_file: None if print to STDOUT. Otherwise write the file in the path 40 | """ 41 | f = sys.stdout if output_file is None else open(output_file, "wb")  # NOTE(review): binary mode only suits Python 2 str writes; Python 3 needs "w" 42 | 43 | if in_json: 44 | combo = {'dialogs': dialogs, 'meta': domain_spec.to_dict()} 45 | json.dump(combo, f, indent=2) 46 | else: 47 | for idx, d in enumerate(dialogs): 48 | f.write("## DIALOG %d ##\n" % idx) 49 | for turn in d: 50 | speaker, utt, actions = turn["speaker"], turn["utt"], turn["actions"] 51 | if utt: 52 | str_actions = utt 53 | else: 54 | str_actions = " ".join([a.dump_string() for a in actions]) 55 | if speaker == "USR": 56 | f.write("%s(%f)-> %s\n" % (speaker, turn['conf'], str_actions)) 57 | else: 58 | f.write("%s -> %s\n" % (speaker, str_actions)) 59 | 60 | if output_file is not None: 61 | f.close() 62 | 63 | @staticmethod 64 | def print_stats(dialogs): 65 | """ 66 | Print some basic stats of the dialogs. 67 | 68 | :param dialogs: A list of dialogs generated. 69 | """ 70 | print("%d dialogs" % len(dialogs)) 71 | all_lens = [len(d) for d in dialogs] 72 | print("Avg len {} Max Len {}".format(np.mean(all_lens), np.max(all_lens))) 73 | 74 | total_cnt = 0. 75 | kb_cnt = 0. 76 | ratio = [] 77 | for d in dialogs: 78 | local_cnt = 0.
79 | for t in d: 80 | total_cnt +=1 81 | if 'QUERY' in t['utt']:  # turns whose utterance mentions QUERY are counted as KB turns 82 | kb_cnt += 1 83 | local_cnt += 1 84 | ratio.append(local_cnt/len(d)) 85 | print(kb_cnt/total_cnt)  # share of KB turns over all turns 86 | print(np.mean(ratio))  # mean per-dialog share of KB turns 87 | 88 | def gen(self, domain, complexity, num_sess=1): 89 | """ 90 | Generate synthetic dialogs in the given domain. 91 | 92 | :param domain: a Domain instance (gen_corpus builds it from a DomainSpec) 93 | :param complexity: a Complexity instance 94 | :param num_sess: how many dialogs to generate 95 | :return: a list of dialogs. Each dialog is a list of turns. 96 | """ 97 | dialogs = [] 98 | action_channel = ActionChannel(domain, complexity) 99 | word_channel = WordChannel(domain, complexity) 100 | 101 | # natural language generators 102 | sys_nlg = SysNlg(domain, complexity) 103 | usr_nlg = UserNlg(domain, complexity) 104 | 105 | bar = progressbar.ProgressBar(max_value=num_sess) 106 | for i in range(num_sess): 107 | bar.update(i) 108 | usr = User(domain, complexity) 109 | sys = System(domain, complexity)  # NOTE: shadows the sys module inside gen() 110 | 111 | # begin conversation 112 | noisy_usr_as = [] 113 | dialog = [] 114 | conf = 1.0 115 | while True: 116 | # system turn: act on the (noisy) user actions and their ASR confidence 117 | sys_r, sys_t, sys_as, sys_s = sys.step(noisy_usr_as, conf) 118 | sys_utt, sys_str_as = sys_nlg.generate_sent(sys_as, domain=domain) 119 | dialog.append(self.pack_msg("SYS", sys_utt, actions=sys_str_as, domain=domain.name, state=sys_s)) 120 | 121 | if sys_t: 122 | break 123 | 124 | usr_r, usr_t, usr_as = usr.step(sys_as) 125 | 126 | # passing through noise, nlg and noise!
127 | noisy_usr_as, conf = action_channel.transmit2sys(usr_as) 128 | usr_utt = usr_nlg.generate_sent(noisy_usr_as) 129 | noisy_usr_utt = word_channel.transmit2sys(usr_utt) 130 | 131 | dialog.append(self.pack_msg("USR", noisy_usr_utt, actions=noisy_usr_as, conf=conf, domain=domain.name)) 132 | 133 | dialogs.append(dialog) 134 | 135 | return dialogs 136 | 137 | def gen_corpus(self, name, domain_spec, complexity_spec, size): 138 | if not os.path.exists(name):  # name is the output directory 139 | os.mkdir(name) 140 | 141 | # create meta specifications 142 | domain = Domain(domain_spec) 143 | complex = Complexity(complexity_spec)  # NOTE: shadows the builtin complex() 144 | 145 | # generate the corpus conditioned on domain & complexity 146 | corpus = self.gen(domain, complex, num_sess=size) 147 | 148 | # txt_file = "{}-{}-{}.{}".format(domain_spec.name, 149 | # complexity_spec.__name__, 150 | # size, 'txt') 151 | 152 | json_file = "{}-{}-{}.{}".format(domain_spec.name, 153 | complexity_spec.__name__, 154 | size, 'json') 155 | 156 | json_file = os.path.join(name, json_file) 157 | self.pprint(corpus, True, domain_spec, json_file) 158 | self.print_stats(corpus) 159 | -------------------------------------------------------------------------------- /simdial/complexity.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: Tiancheng Zhao 3 | # Date: 9/13/17 4 | 5 | 6 | class ComplexitySpec(object): 7 | """ 8 | Base class of complexity specification 9 | 10 | :cvar environment: configs for environmental noise 11 | :cvar proposition: configs for propositional noise 12 | :cvar interaction: configs for interactional noise 13 | :cvar social: configs for social noise 14 | """ 15 | environment = None 16 | proposition = None 17 | interaction = None 18 | social = None 19 | 20 | 21 | class Complexity(object): 22 | """ 23 | Complexity object used to decide the task difficulty 24 | 25 | :ivar asr_acc: the mean value of asr confidence 26 | :ivar asr_std: the std of asr confidence
distribution 27 | :ivar yn_question: the chance the user will ask a yes/no question 28 | :ivar reject_style: the distribution over different rejection styles 29 | :ivar multi_slots: the distribution over how many slots are in an inform 30 | :ivar multi_goals: the distribution over how many goals in a dialog 31 | :ivar dont_care: the chance that the user doesn't care about a slot 32 | :ivar hesitation: the chance that user will hesitate in an utterance 33 | :ivar self_restart: the chance that user will restart 34 | :ivar self_correct: the chance that the user will correct themselves in an utterance. 35 | :ivar self_disclosure: the chance that the system will do self disclosure 36 | :ivar ref_shared: the chance that the system will do reference 37 | :ivar violation_sn: the chance that system will do VSN 38 | """ 39 | 40 | def __init__(self, complexity_spec):  # complexity_spec: a ComplexitySpec (sub)class whose dict attributes are read below 41 | # environment 42 | self.asr_acc = complexity_spec.environment['asr_acc'] 43 | self.asr_std = complexity_spec.environment['asr_std'] 44 | 45 | # propositional 46 | self.yn_question = complexity_spec.proposition['yn_question'] 47 | self.reject_style = complexity_spec.proposition['reject_style'] 48 | self.multi_slots = complexity_spec.proposition['multi_slots'] 49 | self.multi_goals = complexity_spec.proposition['multi_goals'] 50 | self.dont_care = complexity_spec.proposition['dont_care'] 51 | 52 | # interactional 53 | self.hesitation = complexity_spec.interaction['hesitation'] 54 | self.self_restart = complexity_spec.interaction['self_restart'] 55 | self.self_correct = complexity_spec.interaction['self_correct'] 56 | 57 | # social 58 | self.self_disclosure = complexity_spec.social['self_disclosure'] 59 | self.ref_shared = complexity_spec.social['ref_shared'] 60 | self.violation_sn = complexity_spec.social['violation_sn'] 61 | 62 | def get_name(self): 63 | return self.__class__.__name__  # NOTE(review): always 'Complexity', not the spec's class name 64 | 65 | 66 | class MixSpec(ComplexitySpec): 67 | """ 68 | An example spec mixing environment, propositional and interaction noise 69 | """ 70 | 71 | environment = {'asr_acc': 0.7, 72 | 'asr_std':
0.15} 73 | 74 | proposition = {'yn_question': 0.4, 75 | 'reject_style': {'reject': 0.5, 'reject+inform': 0.5}, 76 | 'multi_slots': {1: 0.7, 2: 0.3}, 77 | 'dont_care': 0.1, 78 | 'multi_goals': {1: 0.6, 2: 0.4}, 79 | } 80 | 81 | interaction = {'hesitation': 0.4, 82 | 'self_restart': 0.1, 83 | 'self_correct': 0.2} 84 | 85 | social = {'self_disclosure': None, 86 | 'ref_shared': None, 87 | 'violation_sn': None} 88 | 89 | 90 | class PropSpec(ComplexitySpec): 91 | """ 92 | An example spec with propositional complexity only (perfect ASR, no interaction noise) 93 | """ 94 | 95 | environment = {'asr_acc': 1.0, 96 | 'asr_std': 0.0} 97 | 98 | proposition = {'yn_question': 0.4, 99 | 'reject_style': {'reject': 0.5, 'reject+inform': 0.5}, 100 | 'multi_slots': {1: 0.7, 2: 0.3}, 101 | 'dont_care': 0.1, 102 | 'multi_goals': {1: 0.7, 2: 0.3}, 103 | } 104 | 105 | interaction = {'hesitation': 0.0, 106 | 'self_restart': 0.0, 107 | 'self_correct': 0.0} 108 | 109 | social = {'self_disclosure': None, 110 | 'ref_shared': None, 111 | 'violation_sn': None} 112 | 113 | 114 | class EnvSpec(ComplexitySpec): 115 | """ 116 | An example spec with environment (ASR) noise only 117 | """ 118 | 119 | environment = {'asr_acc': 0.7, 120 | 'asr_std': 0.2} 121 | 122 | proposition = {'yn_question': 0.0, 123 | 'reject_style': {'reject': 1.0, 'reject+inform': 0.0}, 124 | 'multi_slots': {1: 1.0, 2: 0.0}, 125 | 'dont_care': 0.0, 126 | 'multi_goals': {1: 1.0, 2: 0.0}, 127 | } 128 | 129 | interaction = {'hesitation': 0.0, 130 | 'self_restart': 0.0, 131 | 'self_correct': 0.0} 132 | 133 | social = {'self_disclosure': None, 134 | 'ref_shared': None, 135 | 'violation_sn': None} 136 | 137 | 138 | class InteractSpec(ComplexitySpec): 139 | """ 140 | An example spec with interaction noise only (perfect ASR, simple propositions) 141 | """ 142 | 143 | environment = {'asr_acc': 1.0, 144 | 'asr_std': 0.0} 145 | 146 | proposition = {'yn_question': 0.0, 147 | 'reject_style': {'reject': 1.0, 'reject+inform': 0.0}, 148 | 'multi_slots': {1: 1.0, 2: 0.0}, 149 | 'dont_care': 0.0, 150 | 'multi_goals': {1: 1.0, 2: 0.0}, 151 | } 152 |
153 | interaction = {'hesitation': 0.4, 154 | 'self_restart': 0.1, 155 | 'self_correct': 0.2} 156 | 157 | social = {'self_disclosure': None, 158 | 'ref_shared': None, 159 | 'violation_sn': None} 160 | 161 | 162 | class CleanSpec(ComplexitySpec): 163 | """ 164 | An example spec for the easy setting 165 | """ 166 | 167 | environment = {'asr_acc': 1.0, 168 | 'asr_std': 0.0} 169 | 170 | proposition = {'yn_question': 0.0, 171 | 'reject_style': {'reject': 1.0, 'reject+inform': 0.0}, 172 | 'multi_slots': {1: 1.0, 2: 0.0}, 173 | 'dont_care': 0.0, 174 | 'multi_goals': {1: 1.0, 2: 0.0}, 175 | } 176 | 177 | interaction = {'hesitation': 0.0, 178 | 'self_restart': 0.0, 179 | 'self_correct': 0.0} 180 | 181 | social = {'self_disclosure': None, 182 | 'ref_shared': None, 183 | 'violation_sn': None} -------------------------------------------------------------------------------- /simdial/agent/nlg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: Tiancheng Zhao 3 | # Date: 9/13/17 4 | 5 | import numpy as np 6 | from simdial.agent.core import SystemAct, UserAct, BaseUsrSlot 7 | from simdial.agent import core 8 | import json 9 | import copy 10 | 11 | 12 | class AbstractNlg(object): 13 | """ 14 | Abstract class of NLG 15 | """ 16 | 17 | def __init__(self, domain, complexity): 18 | self.domain = domain 19 | self.complexity = complexity 20 | 21 | def generate_sent(self, actions, **kwargs): 22 | """ 23 | Map a list of actions to a string. 
24 | 25 | :param actions: a list of actions 26 | :return: the utterance as a string 27 | """ 28 | raise NotImplementedError("Generate sent is required for NLG") 29 | 30 | def sample(self, examples): # uniform random pick among the candidate strings 31 | return np.random.choice(examples) 32 | 33 | 34 | class SysCommonNlg(object): # shared system-side templates, keyed by act (or act + slot / "dont_care") 35 | templates = {SystemAct.GREET: ["Hello.", "Hi.", "Greetings.", "How are you doing?"], 36 | SystemAct.ASK_REPEAT: ["Can you please repeat that?", "What did you say?"], 37 | SystemAct.ASK_REPHRASE: ["Can you please rephrase that?", "Can you say it in another way?"], 38 | SystemAct.GOODBYE: ["Goodbye.", "See you next time."], 39 | SystemAct.CLARIFY: ["I didn't catch you."], 40 | SystemAct.REQUEST+core.BaseUsrSlot.NEED: ["What can I do for you?", 41 | "What do you need?", 42 | "How can I help?"], 43 | SystemAct.REQUEST+core.BaseUsrSlot.HAPPY: ["What else can I do?", 44 | "Are you happy about my answer?", 45 | "Anything else?"], 46 | SystemAct.EXPLICIT_CONFIRM+"dont_care": ["Okay, you dont_care, do you?", 47 | "You dont_care, right?"], 48 | SystemAct.IMPLICIT_CONFIRM+"dont_care": ["Okay, you dont_care.", 49 | "Alright, dont_care."]} 50 | 51 | class SysNlg(AbstractNlg): 52 | """ 53 | NLG class to generate utterances for the system side. 54 | """ 55 | 56 | def generate_sent(self, actions, domain=None, templates=SysCommonNlg.templates): 57 | """ 58 | Map a list of system actions to a string (GREET uses domain.greet when a domain is given).
59 | 60 | :param actions: a list of actions 61 | :param templates: a common NLG template that uses the default one if not given 62 | :return: (utterance string, list of lexicalized deep copies of the actions) 63 | """ 64 | str_actions = [] 65 | lexicalized_actions = [] 66 | for a in actions: 67 | a_copy = copy.deepcopy(a) 68 | if a.act == SystemAct.GREET: 69 | if domain: 70 | str_actions.append(domain.greet) 71 | else: 72 | str_actions.append(self.sample(templates[a.act])) 73 | 74 | elif a.act == SystemAct.QUERY: 75 | usr_constrains = a.parameters[0] 76 | sys_goals = a.parameters[1] 77 | 78 | # create string list for KB_SEARCH 79 | search_dict = {} 80 | for k, v in usr_constrains: 81 | slot = self.domain.get_usr_slot(k) 82 | if v is None: 83 | search_dict[k] = 'dont_care' 84 | else: 85 | search_dict[k] = slot.vocabulary[v] 86 | 87 | a_copy.parameters[0] = search_dict 88 | a_copy.parameters[1] = sys_goals 89 | str_actions.append(json.dumps({"QUERY": search_dict, 90 | "GOALS": sys_goals})) 91 | 92 | elif a.act == SystemAct.INFORM: 93 | sys_goals = a.parameters[1] 94 | 95 | # create string list for RET + Informs 96 | informs = [] 97 | sys_goal_dict = {} 98 | for k, (v, e_v) in sys_goals.items(): 99 | slot = self.domain.get_sys_slot(k) 100 | sys_goal_dict[k] = slot.vocabulary[v] 101 | 102 | if e_v is not None: 103 | prefix = "Yes, " if v == e_v else "No, " 104 | else: 105 | prefix = "" 106 | informs.append(prefix + slot.sample_inform() 107 | % slot.vocabulary[v]) 108 | a_copy['parameters'] = [sys_goal_dict] # NOTE(review): item assignment, unlike the a_copy.parameters[...] updates in the other branches; confirm Action supports both and which one is consumed 109 | str_actions.append(" ".join(informs)) 110 | 111 | elif a.act == SystemAct.REQUEST: 112 | slot_type, _ = a.parameters[0] 113 | if slot_type in [core.BaseUsrSlot.NEED, core.BaseUsrSlot.HAPPY]: 114 | str_actions.append(self.sample(templates[SystemAct.REQUEST+slot_type])) 115 | else: 116 | target_slot = self.domain.get_usr_slot(slot_type) 117 | if target_slot is None: 118 | raise ValueError("none slot %s" % slot_type) 119 | str_actions.append(target_slot.sample_request()) 120 | 121 | elif a.act ==
SystemAct.EXPLICIT_CONFIRM: 122 | slot_type, slot_val = a.parameters[0] 123 | if slot_val is None: 124 | str_actions.append(self.sample(templates[SystemAct.EXPLICIT_CONFIRM+"dont_care"])) 125 | a_copy.parameters[0] = (slot_type, "dont_care") 126 | else: 127 | slot = self.domain.get_usr_slot(slot_type) 128 | str_actions.append("Do you mean %s?" 129 | % slot.vocabulary[slot_val]) 130 | a_copy.parameters[0] = (slot_type, slot.vocabulary[slot_val]) 131 | 132 | elif a.act == SystemAct.IMPLICIT_CONFIRM: 133 | slot_type, slot_val = a.parameters[0] 134 | if slot_val is None: 135 | str_actions.append(self.sample(templates[SystemAct.IMPLICIT_CONFIRM+"dont_care"])) 136 | a_copy.parameters[0] = (slot_type, "dont_care") 137 | else: 138 | slot = self.domain.get_usr_slot(slot_type) 139 | str_actions.append("I believe you said %s." 140 | % slot.vocabulary[slot_val]) 141 | a_copy.parameters[0] = (slot_type, slot.vocabulary[slot_val]) 142 | 143 | elif a.act in templates.keys(): 144 | str_actions.append(self.sample(templates[a.act])) 145 | 146 | else: 147 | raise ValueError("Unknown dialog act %s" % a.act) 148 | 149 | lexicalized_actions.append(a_copy) 150 | 151 | return " ".join(str_actions), lexicalized_actions 152 | 153 | 154 | class UserNlg(AbstractNlg): 155 | """ 156 | NLG class to generate utterances for the user side. 157 | """ 158 | 159 | def generate_sent(self, actions): 160 | """ 161 | Map a list of user actions to a string.
162 | 163 | :param actions: a list of actions 164 | :return: the utterance as a string 165 | """ 166 | str_actions = [] 167 | for a in actions: 168 | if a.act == UserAct.KB_RETURN: 169 | sys_goals = a.parameters[1] 170 | sys_goal_dict = {} 171 | for k, v in sys_goals.items(): 172 | slot = self.domain.get_sys_slot(k) 173 | sys_goal_dict[k] = slot.vocabulary[v] 174 | 175 | str_actions.append(json.dumps({"RET": sys_goal_dict})) 176 | elif a.act == UserAct.GREET: 177 | str_actions.append(self.sample(["Hi.", "Hello robot.", "What's up?"])) 178 | 179 | elif a.act == UserAct.GOODBYE: 180 | str_actions.append(self.sample(["That's all.", "Thank you.", "See you."])) 181 | 182 | elif a.act == UserAct.REQUEST: 183 | slot_type, _ = a.parameters[0] 184 | target_slot = self.domain.get_sys_slot(slot_type) 185 | str_actions.append(target_slot.sample_request()) 186 | 187 | elif a.act == UserAct.INFORM: 188 | has_self_correct = a.parameters[-1][0] == BaseUsrSlot.SELF_CORRECT 189 | slot_type, slot_value = a.parameters[0] 190 | target_slot = self.domain.get_usr_slot(slot_type) 191 | 192 | def get_inform_utt(val): 193 | if val is None: 194 | return self.sample(["Anything is fine.", "I don't care.", "Whatever is good."]) 195 | else: 196 | return target_slot.sample_inform() % target_slot.vocabulary[val] 197 | 198 | if has_self_correct: 199 | wrong_value = target_slot.sample_different(slot_value) 200 | wrong_utt = get_inform_utt(wrong_value) 201 | correct_utt = get_inform_utt(slot_value) 202 | connector = self.sample(["Oh no,", "Uhm sorry,", "Oh sorry,"]) 203 | str_actions.append("%s %s %s" % (wrong_utt, connector, correct_utt)) 204 | else: 205 | str_actions.append(get_inform_utt(slot_value)) 206 | 207 | elif a.act == UserAct.CHAT: 208 | str_actions.append(self.sample(["What's your name?", "Where are you from?"])) 209 | 210 | elif a.act == UserAct.YN_QUESTION: 211 | slot_type, expect_id = a.parameters[0] 212 | target_slot = self.domain.get_sys_slot(slot_type) 213 | expect_val =
target_slot.vocabulary[expect_id] 214 | str_actions.append(target_slot.sample_yn_question(expect_val)) 215 | 216 | elif a.act == UserAct.CONFIRM: 217 | str_actions.append(self.sample(["Yes.", "Yep.", "Yeah.", "That's correct.", "Uh-huh."])) 218 | 219 | elif a.act == UserAct.DISCONFIRM: 220 | str_actions.append(self.sample(["No.", "Nope.", "Wrong.", "That's wrong.", "Nay."])) 221 | 222 | elif a.act == UserAct.SATISFY: 223 | str_actions.append(self.sample(["No more questions.", "I have all I need.", "All good."])) 224 | 225 | elif a.act == UserAct.MORE_REQUEST: 226 | str_actions.append(self.sample(["I have more requests.", "One more thing.", "Not done yet."])) 227 | 228 | elif a.act == UserAct.NEW_SEARCH: 229 | str_actions.append(self.sample(["I want to search a new one.", "New request.", "A new search."])) 230 | 231 | else: 232 | raise ValueError("Unknown user act %s for NLG" % a.act) 233 | 234 | return " ".join(str_actions) 235 | 236 | def add_hesitation(self, sents, actions): 237 | pass # TODO: not implemented yet 238 | 239 | def add_self_restart(self, sents, actions): 240 | pass # TODO: not implemented yet 241 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /simdial/agent/user.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # author: Tiancheng Zhao 3 | from simdial.agent.core import Agent, Action, UserAct, SystemAct, BaseSysSlot, BaseUsrSlot, State 4 | import logging 5 | import numpy as np 6 | import copy 7 | from collections import OrderedDict 8 | 9 | 10 | class User(Agent): 11 | """ 12 | Basic user agent 13 | 14 | :ivar usr_constrains: a combination of user slots 15 | :ivar domain: the given domain 16 | :ivar state: the dialog state 17 | """ 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | class DialogState(State): 22 | """ 23 | The dialog state object for this user simulator 24 | 25 | :ivar history: a list of tuple [(speaker, actions) ... 
] 26 | :ivar spk_state: LISTEN, SPEAK or EXIT 27 | :ivar goals_met: if the system propose anything that's in user's goal 28 | :ivar: input_buffer: a list of system action that is not being handled in this turn 29 | """ 30 | def __init__(self, sys_goals): 31 | super(State, self).__init__() 32 | self.history = [] 33 | self.spk_state = self.LISTEN 34 | self.input_buffer = [] 35 | self.goals_met = OrderedDict([(g, False) for g in sys_goals]) 36 | 37 | def update_history(self, speaker, actions): 38 | """ 39 | :param speaker: SYS or USR 40 | :param actions: a list of Action 41 | """ 42 | self.history.append((speaker, actions)) 43 | 44 | def is_terminal(self): 45 | """ 46 | :return: the user wants to terminate the session 47 | """ 48 | return self.spk_state == self.EXIT 49 | 50 | def yield_floor(self): 51 | """ 52 | :return: True if user want to stop speaking 53 | """ 54 | return self.spk_state == self.LISTEN 55 | 56 | def unmet_goal(self): 57 | for k, v in self.goals_met.items(): 58 | if v is False: 59 | return k 60 | return None 61 | 62 | def update_goals_met(self, top_action): 63 | proposed_sys = top_action.parameters[1] 64 | completed_goals = [] 65 | for goal in proposed_sys.keys(): 66 | if goal in self.goals_met.keys(): 67 | self.goals_met[goal] = True 68 | completed_goals.append(goal) 69 | return completed_goals 70 | 71 | def reset_goal(self, sys_goals): 72 | self.goals_met = {g: False for g in sys_goals} 73 | 74 | def __init__(self, domain, complexity): 75 | super(User, self).__init__(domain, complexity) 76 | self.goal_cnt = np.random.choice(complexity.multi_goals.keys(), p=complexity.multi_goals.values()) 77 | self.goal_ptr = 0 78 | self.usr_constrains, self.sys_goals = self._sample_goal() 79 | self.state = self.DialogState(self.sys_goals) 80 | 81 | def state_update(self, sys_actions): 82 | """ 83 | Update the dialog state given system's action in a new turn 84 | 85 | :param sys_actions: a list of system action 86 | """ 87 | 
self.state.update_history(self.state.SYS, sys_actions) 88 | self.state.spk_state = self.DialogState.SPEAK 89 | self.state.input_buffer = copy.deepcopy(sys_actions) 90 | 91 | def _sample_goal(self): 92 | """ 93 | :return: {slot_name -> value} for user constrains, [slot_name, ..] for system goals 94 | """ 95 | temp_constrains = self.domain.db.sample_unique_row().tolist() 96 | temp_constrains = [None if np.random.rand() < self.complexity.dont_care 97 | else c for c in temp_constrains] 98 | # there is a chance user does not care 99 | usr_constrains = {s.name: temp_constrains[i] for i, s in enumerate(self.domain.usr_slots)} 100 | 101 | # sample the number of attribute about the system 102 | num_interest = np.random.randint(0, len(self.domain.sys_slots)-1) 103 | goal_candidates = [s.name for s in self.domain.sys_slots if s.name != BaseSysSlot.DEFAULT] 104 | selected_goals = np.random.choice(goal_candidates, size=num_interest, replace=False) 105 | np.random.shuffle(selected_goals) 106 | sys_goals = [BaseSysSlot.DEFAULT] + selected_goals.tolist() 107 | return usr_constrains, sys_goals 108 | 109 | def _constrain_equal(self, top_action): 110 | proposed_constrains = top_action.parameters[0] 111 | for k, v in self.usr_constrains.items(): 112 | if k in proposed_constrains: 113 | if v != proposed_constrains[k]: 114 | return False, k 115 | else: 116 | return False, k 117 | return True, None 118 | 119 | def _increment_goal(self): 120 | if self.goal_ptr >= self.goal_cnt-1: 121 | return None 122 | else: 123 | self.goal_ptr += 1 124 | _, self.sys_goals = self._sample_goal() 125 | change_key = np.random.choice(self.usr_constrains.keys()) 126 | change_slot = self.domain.get_usr_slot(change_key) 127 | old_value = self.usr_constrains[change_key] 128 | old_value = -1 if old_value is None else old_value 129 | new_value = np.random.randint(0, change_slot.dim-1) % change_slot.dim 130 | self.logger.info("Filp user constrain %s from %d to %d" % 131 | (change_key, old_value, new_value)) 132 | 
self.usr_constrains[change_key] = new_value 133 | self.state.reset_goal(self.sys_goals) 134 | return change_key 135 | 136 | def policy(self): 137 | if self.state.spk_state == self.DialogState.EXIT: 138 | return None 139 | 140 | if len(self.state.input_buffer) == 0: 141 | self.state.spk_state = self.DialogState.LISTEN 142 | return None 143 | 144 | if len(self.state.history) > 100: 145 | self.state.input_buffer = [] 146 | return Action(UserAct.GOODBYE) 147 | 148 | top_action = self.state.input_buffer[0] 149 | self.state.input_buffer.pop(0) 150 | 151 | if top_action.act == SystemAct.GREET: 152 | return Action(UserAct.GREET) 153 | 154 | elif top_action.act == SystemAct.GOODBYE: 155 | return Action(UserAct.GOODBYE) 156 | 157 | elif top_action.act == SystemAct.IMPLICIT_CONFIRM: 158 | if len(top_action.parameters) == 0: 159 | raise ValueError("IMPLICIT_CONFIRM is required to have parameter") 160 | slot_type, slot_val = top_action.parameters[0] 161 | if self.domain.is_usr_slot(slot_type): 162 | # if the confirm is right or usr does not care about this slot 163 | if slot_val == self.usr_constrains[slot_type] or self.usr_constrains[slot_type] is None: 164 | return None 165 | else: 166 | strategy = np.random.choice(self.complexity.reject_style.keys(), 167 | p=self.complexity.reject_style.values()) 168 | if strategy == "reject": 169 | return Action(UserAct.DISCONFIRM, (slot_type, slot_val)) 170 | elif strategy == "reject+inform": 171 | return [Action(UserAct.DISCONFIRM, (slot_type, slot_val)), 172 | Action(UserAct.INFORM, (slot_type, self.usr_constrains[slot_type]))] 173 | else: 174 | raise ValueError("Unknown reject strategy") 175 | else: 176 | raise ValueError("Usr cannot handle imp_confirm to non-usr slots") 177 | 178 | elif top_action.act == SystemAct.EXPLICIT_CONFIRM: 179 | if len(top_action.parameters) == 0: 180 | raise ValueError("EXPLICIT_CONFIRM is required to have parameter") 181 | slot_type, slot_val = top_action.parameters[0] 182 | if 
self.domain.is_usr_slot(slot_type): 183 | # if the confirm is right or usr does not care about this slot 184 | if slot_val == self.usr_constrains[slot_type]: 185 | return Action(UserAct.CONFIRM, (slot_type, slot_val)) 186 | else: 187 | return Action(UserAct.DISCONFIRM, (slot_type, slot_val)) 188 | else: 189 | raise ValueError("Usr cannot handle imp_confirm to non-usr slots") 190 | 191 | elif top_action.act == SystemAct.INFORM: 192 | if len(top_action.parameters) != 2: 193 | raise ValueError("INFORM needs to contain the constrains and goal (2 parameters)") 194 | 195 | # check if the constrains are the same 196 | valid_constrain, wrong_slot = self._constrain_equal(top_action) 197 | if valid_constrain: 198 | # update the state for goal met 199 | complete_goals = self.state.update_goals_met(top_action) 200 | next_goal = self.state.unmet_goal() 201 | 202 | if next_goal is None: 203 | slot_key = self._increment_goal() 204 | if slot_key is not None: 205 | return [Action(UserAct.NEW_SEARCH, (BaseSysSlot.DEFAULT, None)), 206 | Action(UserAct.INFORM, (slot_key, self.usr_constrains[slot_key]))] 207 | else: 208 | return [Action(UserAct.SATISFY, [(g, None) for g in complete_goals]), 209 | Action(UserAct.GOODBYE)] 210 | else: 211 | ack_act = Action(UserAct.MORE_REQUEST, [(g, None) for g in complete_goals]) 212 | if np.random.rand() < self.complexity.yn_question: 213 | # find a system slot with yn_templates 214 | slot = self.domain.get_sys_slot(next_goal) 215 | expected_val = np.random.randint(0, slot.dim) 216 | if len(slot.yn_questions.get(slot.vocabulary[expected_val], [])) > 0: 217 | # sample a expected value 218 | return [ack_act, Action(UserAct.YN_QUESTION, (slot.name, expected_val))] 219 | 220 | return [ack_act, Action(UserAct.REQUEST, (next_goal, None))] 221 | else: 222 | # find the wrong concept 223 | return Action(UserAct.INFORM, (wrong_slot, self.usr_constrains[wrong_slot])) 224 | 225 | elif top_action.act == SystemAct.REQUEST: 226 | if len(top_action.parameters) == 0: 
227 | raise ValueError("Request is required to have parameter") 228 | 229 | slot_type, slot_val = top_action.parameters[0] 230 | 231 | if slot_type == BaseUsrSlot.NEED: 232 | next_goal = self.state.unmet_goal() 233 | return Action(UserAct.REQUEST, (next_goal, None)) 234 | 235 | elif slot_type == BaseUsrSlot.HAPPY: 236 | return None 237 | 238 | elif self.domain.is_usr_slot(slot_type): 239 | if len(self.domain.usr_slots) > 1: 240 | num_informs = np.random.choice(self.complexity.multi_slots.keys(), 241 | p=self.complexity.multi_slots.values(), 242 | replace=False) 243 | if num_informs > 1: 244 | candidates = [k for k, v in self.usr_constrains.items() if k != slot_type and v is not None] 245 | num_extra = min(num_informs-1, len(candidates)) 246 | if num_extra > 0: 247 | extra_keys = np.random.choice(candidates, size=num_extra, replace=False) 248 | actions = [Action(UserAct.INFORM, (key, self.usr_constrains[key])) for key in extra_keys] 249 | actions.insert(0, Action(UserAct.INFORM, (slot_type, self.usr_constrains[slot_type]))) 250 | return actions 251 | 252 | return Action(UserAct.INFORM, (slot_type, self.usr_constrains[slot_type])) 253 | 254 | else: 255 | raise ValueError("Usr cannot handle request to this type of parameters") 256 | 257 | elif top_action.act == SystemAct.CLARIFY: 258 | raise ValueError("Cannot handle clarify now") 259 | 260 | elif top_action.act == SystemAct.ASK_REPEAT: 261 | last_usr_actions = self.state.last_actions(self.state.USR) 262 | if last_usr_actions is None: 263 | raise ValueError("Unexpected ask repeat") 264 | return last_usr_actions 265 | 266 | elif top_action.act == SystemAct.ASK_REPHRASE: 267 | last_usr_actions = self.state.last_actions(self.state.USR) 268 | if last_usr_actions is None: 269 | raise ValueError("Unexpected ask rephrase") 270 | for a in last_usr_actions: 271 | a.add_parameter(BaseUsrSlot.AGAIN, True) 272 | return last_usr_actions 273 | 274 | elif top_action.act == SystemAct.QUERY: 275 | query, goals = 
top_action.parameters[0], top_action.parameters[1] 276 | valid_entries = self.domain.db.select([v for name, v in query]) 277 | chosen_entry = valid_entries[np.random.randint(0, len(valid_entries)), :] 278 | 279 | results = {} 280 | if chosen_entry.shape[0] > 0: 281 | for goal in goals: 282 | _, slot_id = self.domain.get_sys_slot(goal, return_idx=True) 283 | results[goal] = chosen_entry[slot_id] 284 | else: 285 | print(chosen_entry) 286 | raise ValueError("No valid entries") 287 | 288 | return Action(UserAct.KB_RETURN, [query, results]) 289 | else: 290 | raise ValueError("Unknown system act %s" % top_action.act) 291 | 292 | def step(self, inputs): 293 | """ 294 | Given a list of inputs from the system, generate a response 295 | 296 | :param inputs: a list of Action 297 | :return: reward, terminal, [Action] 298 | """ 299 | turn_actions = [] 300 | # update the dialog state 301 | self.state_update(inputs) 302 | while True: 303 | action = self.policy() 304 | if action is not None: 305 | if type(action) is list: 306 | turn_actions.extend(action) 307 | else: 308 | turn_actions.append(action) 309 | 310 | if self.state.is_terminal(): 311 | reward = 1.0 if self.state.unmet_goal() is None else -1.0 312 | self.state.update_history(self.state.USR, turn_actions) 313 | return reward, True, turn_actions 314 | 315 | if self.state.yield_floor(): 316 | self.state.update_history(self.state.USR, turn_actions) 317 | return 0.0, False, turn_actions 318 | -------------------------------------------------------------------------------- /simdial/agent/system.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # author: Tiancheng Zhao 3 | 4 | from simdial.agent.core import Agent, Action, State, SystemAct, UserAct, BaseSysSlot, BaseUsrSlot 5 | import logging 6 | from collections import OrderedDict 7 | import numpy as np 8 | import copy 9 | 10 | 11 | class BeliefSlot(object): 12 | """ 13 | A slot with a probabilistic distribution 
over the possible values 14 | 15 | :ivar value_map: entity_value -> (score, norm_value) 16 | :ivar last_update_turn: the last turn ID this slot is modified 17 | :ivar uid: the unique ID, i.e. slot name 18 | """ 19 | 20 | EXPLICIT_THRESHOLD = 0.2 21 | IMPLICIT_THRESHOLD = 0.6 22 | GROUND_THRESHOLD = 0.95 23 | 24 | def __init__(self, uid, vocabulary): 25 | self.uid = uid 26 | self.value_map = {} 27 | self.last_update_turn = -1 28 | self.logger = logging.getLogger(__name__) 29 | 30 | def add_new_observation(self, value, conf, turn_id): 31 | self.last_update_turn = turn_id 32 | 33 | if value in self.value_map.keys(): 34 | prev_conf = self.value_map[value] 35 | self.value_map[value] = max([prev_conf, conf]) + 0.2 36 | self.logger.info("Update %s conf to %f at turn %d" % (value, conf, turn_id)) 37 | else: 38 | self.value_map = {k: c/2 for k, c in self.value_map.items()} 39 | self.value_map[value] = conf 40 | self.logger.info("Add %s conf as %f at turn %d" % (value, conf, turn_id)) 41 | 42 | def add_grounding(self, confirm_conf, disconfirm_conf, turn_id, target_value=None): 43 | if len(self.value_map) > 0: 44 | self.last_update_turn = turn_id 45 | if target_value is None: 46 | grounded_value = self.get_maxconf_value() 47 | else: 48 | grounded_value = target_value 49 | up_conf = confirm_conf * (1.0 - self.EXPLICIT_THRESHOLD) 50 | down_conf = disconfirm_conf * (1.0 - self.EXPLICIT_THRESHOLD) 51 | old_conf = self.value_map[grounded_value] 52 | new_conf = max(0.0, min((old_conf + up_conf - down_conf), 1.5)) 53 | self.value_map[grounded_value] = new_conf 54 | self.logger.info( 55 | "Ground %s from %f to %f at turn %d" % (grounded_value, old_conf, new_conf, turn_id)) 56 | else: 57 | self.logger.warn("Warn an concept without value") 58 | 59 | def get_maxconf_value(self): 60 | if len(self.value_map) == 0: 61 | return None 62 | max_s, max_v = max([(s, v) for v, s in self.value_map.items()]) 63 | return max_v 64 | 65 | def max_conf(self): 66 | """ 67 | :return: the highest 
confidence of all potential values. 0.0 if its empty 68 | """ 69 | if len(self.value_map) == 0: 70 | return 0.0 71 | return max([s for s in self.value_map.values()]) 72 | 73 | def clear(self, turn_id): 74 | middle = (self.IMPLICIT_THRESHOLD+self.EXPLICIT_THRESHOLD)/2. 75 | self.value_map = {k: middle for k in self.value_map.keys()} 76 | 77 | 78 | class BeliefGoal(object): 79 | THRESHOLD = 0.7 80 | 81 | def __init__(self, uid, conf=0.0): 82 | self.uid = uid 83 | self.conf = conf 84 | self.delivered = False 85 | self.value = None 86 | self.expected_value = None 87 | 88 | def add_observation(self, conf, expected_value): 89 | self.conf = max(conf, self.conf) + 0.2 90 | self.expected_value = expected_value 91 | 92 | def get_conf(self): 93 | return self.conf 94 | 95 | def deliver(self): 96 | self.delivered = True 97 | 98 | def clear(self): 99 | self.conf = 0 100 | self.delivered = False 101 | self.expected_value = None 102 | 103 | 104 | class DialogState(State): 105 | """ 106 | The dialog state class for a system 107 | 108 | :ivar history: the raw dialog history 109 | :ivar spk_state: the FSM state for turn-taking. 
SPK, LISTEN or EXIT 110 | :ivar valid_entries: a list of valid system entries satisfy the user belief 111 | :ivar usr_beliefs: a dict of slot name -> BeliefSlot() 112 | :ivar sys_goals: a dict of system goal that is obligated to answer 113 | """ 114 | INFORM_THRESHOLD = 5 115 | 116 | def __init__(self, domain): 117 | super(State, self).__init__() 118 | self.history = [] 119 | self.spk_state = self.SPEAK 120 | self.usr_beliefs = OrderedDict([(s.name, BeliefSlot(s.name, s.vocabulary)) for s in domain.usr_slots]) 121 | self.sys_goals = OrderedDict([(s.name, BeliefGoal(s.name)) for s in domain.sys_slots]) 122 | self.sys_goals[BaseSysSlot.DEFAULT] = BeliefGoal(BaseSysSlot.DEFAULT, conf=1.0) 123 | self.valid_entries = domain.db.select(self.gen_query()) 124 | self.pending_return = None 125 | self.domain = domain 126 | 127 | def turn_id(self): 128 | return len(self.history) 129 | 130 | def gen_query(self): 131 | """ 132 | :return: a DB compatible query given the current usr beliefs 133 | """ 134 | query = [] 135 | for s in self.usr_beliefs.values(): 136 | max_val = s.get_maxconf_value() 137 | query.append(max_val) 138 | return query 139 | 140 | def has_pending_return(self): 141 | return self.pending_return is not None 142 | 143 | def ready_to_inform(self): 144 | # if len(self.valid_entries) <= self.INFORM_THRESHOLD: 145 | # return True 146 | 147 | for slot in self.usr_beliefs.values(): 148 | if slot.max_conf() < slot.GROUND_THRESHOLD: 149 | return False 150 | 151 | for goal in self.sys_goals.values(): 152 | if BeliefGoal.THRESHOLD > goal.get_conf() > 0: 153 | return False 154 | 155 | return True 156 | 157 | def yield_floor(self, actions): 158 | if type(actions) is list: 159 | last_action = actions[-1] 160 | else: 161 | last_action = actions 162 | return last_action.act in [SystemAct.REQUEST, SystemAct.EXPLICIT_CONFIRM, SystemAct.QUERY] 163 | 164 | def is_terminal(self): 165 | return self.spk_state == State.EXIT 166 | 167 | def reset_sys_goals(self): 168 | for goal in 
self.sys_goals.values(): 169 | goal.clear() 170 | self.sys_goals[BaseSysSlot.DEFAULT] = BeliefGoal(BaseSysSlot.DEFAULT, conf=1.0) 171 | 172 | def reset_slots(self): 173 | for slot in self.usr_beliefs.values(): 174 | slot.clear(self.turn_id) 175 | 176 | def state_summary(self): 177 | # return a dump of the dialog state 178 | usr_slots = [] 179 | for slot in self.usr_beliefs.values(): 180 | max_conf = slot.max_conf() 181 | max_val = slot.get_maxconf_value() 182 | if max_val is not None: 183 | usr_slot = self.domain.get_usr_slot(slot.uid) 184 | max_val = usr_slot.vocabulary[max_val] 185 | usr_slots.append({'name':slot.uid, 'max_conf': max_conf, 'max_val': max_val}) 186 | 187 | sys_goals = [] 188 | for goal in self.sys_goals.values(): 189 | value = goal.value 190 | exp_value = goal.expected_value 191 | if value is not None: 192 | sys_goal = self.domain.get_sys_slot(goal.uid) 193 | value = sys_goal.vocabulary[value] 194 | 195 | if exp_value is not None: 196 | sys_goal = self.domain.get_sys_slot(goal.uid) 197 | exp_value = sys_goal.vocabulary[exp_value] 198 | 199 | sys_goals.append({'name': goal.uid, 'delivered': goal.delivered, 200 | 'value': value, 'expected': exp_value, 201 | 'conf': goal.conf}) 202 | 203 | return {'usr_slots': usr_slots, 'sys_goals': sys_goals, 204 | 'kb_update': self.has_pending_return()} 205 | 206 | 207 | class System(Agent): 208 | """ 209 | basic system agent 210 | """ 211 | logger = logging.getLogger(__name__) 212 | 213 | def __init__(self, domain, complexity): 214 | super(System, self).__init__(domain, complexity) 215 | self.state = DialogState(domain) 216 | 217 | def state_update(self, usr_actions, conf): 218 | """ 219 | Update the dialog state given system's action in a new turn 220 | 221 | :param usr_actions: a list of system action, None if no action 222 | :param conf: float [0, 1] confidence of the parsing 223 | """ 224 | if usr_actions is None or len(usr_actions) == 0: 225 | return 226 | 227 | self.state.update_history(self.state.USR, 
usr_actions) 228 | self.state.spk_state = DialogState.SPEAK 229 | 230 | for action in usr_actions: 231 | # check for user confirm/disconfirm 232 | if action.act == UserAct.CONFIRM: 233 | slot, _ = action.parameters[0] 234 | self.state.usr_beliefs[slot].add_grounding(conf, 1.0 - conf, self.state.turn_id()) 235 | elif action.act == UserAct.DISCONFIRM: 236 | slot, _ = action.parameters[0] 237 | self.state.usr_beliefs[slot].add_grounding(1.0 - conf, conf, self.state.turn_id()) 238 | elif action.act == UserAct.INFORM: 239 | slot, value = action.parameters[0] 240 | self.state.usr_beliefs[slot].add_new_observation(value, conf, self.state.turn_id()) 241 | elif action.act == UserAct.REQUEST: 242 | slot, _ = action.parameters[0] 243 | self.state.sys_goals[slot].add_observation(conf, None) 244 | elif action.act == UserAct.NEW_SEARCH: 245 | self.state.reset_sys_goals() 246 | self.state.reset_slots() 247 | elif action.act == UserAct.YN_QUESTION: 248 | slot, value = action.parameters[0] 249 | self.state.sys_goals[slot].add_observation(conf, value) 250 | 251 | elif action.act == UserAct.SATISFY or action.act == UserAct.MORE_REQUEST: 252 | for para, _ in action.parameters: 253 | self.state.sys_goals[para].deliver() 254 | 255 | elif action.act == UserAct.KB_RETURN: 256 | query = action.parameters[0] 257 | results = action.parameters[1] 258 | self.state.pending_return = query 259 | for slot_name, goal in self.state.sys_goals.items(): 260 | if slot_name in results.keys(): 261 | goal.value = results[slot_name] 262 | 263 | 264 | def update_grounding(self, sys_actions): 265 | if type(sys_actions) is not list: 266 | sys_actions = [sys_actions] 267 | 268 | for a in sys_actions: 269 | if a.act == SystemAct.IMPLICIT_CONFIRM: 270 | slot, value = a.parameters[0] 271 | self.state.usr_beliefs[slot].add_grounding(1.0, 0.0, self.state.turn_id()) 272 | 273 | def policy(self): 274 | if self.state.spk_state == State.EXIT: 275 | return None 276 | 277 | # dialog opener 278 | if len(self.state.history) 
== 0: 279 | return [Action(SystemAct.GREET), Action(SystemAct.REQUEST, (BaseUsrSlot.NEED, None))] 280 | 281 | last_usr = self.state.last_actions(DialogState.USR) 282 | if last_usr is None: 283 | raise ValueError("System should talk first") 284 | 285 | actions = [] 286 | for usr_act in last_usr: 287 | if usr_act.act == UserAct.GOODBYE: 288 | self.state.spk_state = State.EXIT 289 | return Action(SystemAct.GOODBYE) 290 | 291 | if self.state.has_pending_return(): 292 | # system goal 293 | query = self.state.pending_return 294 | goals = {} 295 | for goal in self.state.sys_goals.values(): 296 | if goal.delivered is False and goal.conf >= BeliefGoal.THRESHOLD: 297 | goals[goal.uid] = (goal.value, goal.expected_value) 298 | 299 | actions.append(Action(SystemAct.INFORM, [dict(query), goals])) 300 | actions.append(Action(SystemAct.REQUEST, (BaseUsrSlot.HAPPY, None))) 301 | self.state.pending_return = None 302 | 303 | return actions 304 | 305 | # check if it's ready to inform 306 | elif self.state.ready_to_inform(): 307 | # INFORM + {slot -> usr_constrain} + {goal: goal_value} 308 | # user constrains 309 | query = [(key, slot.get_maxconf_value()) for key, slot in self.state.usr_beliefs.items()] 310 | # system goal 311 | goals = [] 312 | for goal in self.state.sys_goals.values(): 313 | if goal.delivered is False and goal.conf >= BeliefGoal.THRESHOLD: 314 | goals.append(goal.uid) 315 | if len(goals) == 0: 316 | raise ValueError("Empty goal. 
Debug!") 317 | actions.append(Action(SystemAct.QUERY, [query, goals])) 318 | return actions 319 | else: 320 | implicit_confirms = [] 321 | exp_confirms = [] 322 | requests = [] 323 | for slot in self.state.usr_beliefs.values(): 324 | if slot.max_conf() < slot.EXPLICIT_THRESHOLD: 325 | exp_confirms.append(Action(SystemAct.REQUEST, (slot.uid, None))) 326 | elif slot.max_conf() < slot.IMPLICIT_THRESHOLD: 327 | requests.append(Action(SystemAct.EXPLICIT_CONFIRM, (slot.uid, slot.get_maxconf_value()))) 328 | elif slot.max_conf() < slot.GROUND_THRESHOLD: 329 | implicit_confirms.append(Action(SystemAct.IMPLICIT_CONFIRM, (slot.uid, slot.get_maxconf_value()))) 330 | 331 | for goal in self.state.sys_goals.values(): 332 | if BeliefGoal.THRESHOLD > goal.get_conf() > 0: 333 | requests.append(Action(SystemAct.REQUEST, (BaseUsrSlot.NEED, None))) 334 | break 335 | 336 | if len(exp_confirms) > 0: 337 | actions.extend(implicit_confirms + exp_confirms[0:1]) 338 | return actions 339 | elif len(requests) > 0: 340 | actions.extend(implicit_confirms + requests[0:1]) 341 | return actions 342 | else: 343 | return implicit_confirms 344 | 345 | def step(self, inputs, conf): 346 | """ 347 | Given a list of inputs from the system, generate a response 348 | 349 | :param inputs: a list of Action 350 | :param conf: the probability that this user input is correct 351 | :return: reward, terminal, [Action], state 352 | """ 353 | turn_actions = [] 354 | # update the dialog state 355 | self.state_update(inputs, conf) 356 | state = self.state.state_summary() 357 | while True: 358 | action = self.policy() 359 | 360 | if action is not None: 361 | if type(action) is list: 362 | turn_actions.extend(action) 363 | else: 364 | turn_actions.append(action) 365 | 366 | self.update_grounding(action) 367 | 368 | if self.state.is_terminal(): 369 | self.state.update_history(self.state.SYS, turn_actions) 370 | return 0.0, True, turn_actions, state 371 | 372 | if self.state.yield_floor(turn_actions): 373 | 
self.state.update_history(self.state.SYS, turn_actions) 374 | return 0.0, False, turn_actions, state 375 | -------------------------------------------------------------------------------- /multiple_domains.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # author: Tiancheng Zhao 3 | from simdial.domain import Domain, DomainSpec 4 | from simdial.generator import Generator 5 | from simdial import complexity 6 | import string 7 | 8 | 9 | class RestSpec(DomainSpec): 10 | name = "restaurant" 11 | greet = "Welcome to restaurant recommendation system." 12 | nlg_spec = {"loc": {"inform": ["I am at %s.", "%s.", "I'm interested in food at %s.", "At %s.", "In %s."], 13 | "request": ["Which city are you interested in?", "Which place?"]}, 14 | 15 | "food_pref": {"inform": ["I like %s food.", "%s food.", "%s restaurant.", "%s."], 16 | "request": ["What kind of food do you like?", "What type of restaurant?"]}, 17 | 18 | "open": {"inform": ["The restaurant is %s.", "It is %s right now."], 19 | "request": ["Tell me if the restaurant is open.", "What's the hours?"], 20 | "yn_question": {'open': ["Is the restaurant open?"], 21 | 'closed': ["Is it closed?"] 22 | }}, 23 | 24 | "parking": {"inform": ["The restaurant has %s.", "This place has %s."], 25 | "request": ["What kind of parking does it have?.", "How easy is it to park?"], 26 | "yn_question": {'street parking': ["Does it have street parking?"], 27 | "valet parking": ["Does it have valet parking?"] 28 | }}, 29 | 30 | "price": {"inform": ["The restaurant serves %s food.", "The price is %s."], 31 | "request": ["What's the average price?", "How expensive it is?"], 32 | "yn_question": {'expensive': ["Is it expensive?"], 33 | 'moderate': ["Does it have moderate price?"], 34 | 'cheap': ["Is it cheap?"] 35 | }}, 36 | 37 | "default": {"inform": ["Restaurant %s is a good choice."], 38 | "request": ["I need a restaurant.", 39 | "I am looking for a restaurant.", 40 | 
"Recommend me a place to eat."]} 41 | } 42 | 43 | usr_slots = [("loc", "location city", ["Pittsburgh", "New York", "Boston", "Seattle", 44 | "Los Angeles", "San Francisco", "San Jose", 45 | "Philadelphia", "Washington DC", "Austin"]), 46 | ("food_pref", "food preference", ["Thai", "Chinese", "Korean", "Japanese", 47 | "American", "Italian", "Indian", "French", 48 | "Greek", "Mexican", "Russian", "Hawaiian"])] 49 | 50 | sys_slots = [("open", "if it's open now", ["open", "closed"]), 51 | ("price", "average price per person", ["cheap", "moderate", "expensive"]), 52 | ("parking", "if it has parking", ["street parking", "valet parking", "no parking"])] 53 | 54 | db_size = 100 55 | 56 | 57 | class RestStyleSpec(DomainSpec): 58 | name = "restaurant_style" 59 | greet = "Hello there. I know a lot about places to eat." 60 | nlg_spec = {"loc": {"inform": ["I am at %s.", "%s.", "I'm interested in food at %s.", "At %s.", "In %s."], 61 | "request": ["Which area are you currently locating at?", "well, what is the place?"]}, 62 | 63 | "food_pref": {"inform": ["I like %s food.", "%s food.", "%s restaurant.", "%s."], 64 | "request": ["What cusine type are you interested", "What do you like to eat?"]}, 65 | 66 | "open": {"inform": ["This wonderful place is %s.", "Currently, this place is %s."], 67 | "request": ["Tell me if the restaurant is open.", "What's the hours?"], 68 | "yn_question": {'open': ["Is the restaurant open?"], 69 | 'closed': ["Is it closed?"] 70 | }}, 71 | 72 | "parking": {"inform": ["The parking status is %s.", "For parking, it does have %s."], 73 | "request": ["What kind of parking does it have?.", "How easy is it to park?"], 74 | "yn_question": {'street parking': ["Does it have street parking?"], 75 | "valet parking": ["Does it have valet parking?"] 76 | }}, 77 | 78 | "price": {"inform": ["This eating place provides %s food.", "Let me check that for you. 
The price is %s."], 79 | "request": ["What's the average price?", "How expensive it is?"], 80 | "yn_question": {'expensive': ["Is it expensive?"], 81 | 'moderate': ["Does it have moderate price?"], 82 | 'cheap': ["Is it cheap?"] 83 | }}, 84 | 85 | "default": {"inform": ["Let me look up in my database. A good choice is %s."], 86 | "request": ["I need a restaurant.", 87 | "I am looking for a restaurant.", 88 | "Recommend me a place to eat."]} 89 | } 90 | 91 | usr_slots = [("loc", "location city", ["Pittsburgh", "New York", "Boston", "Seattle", 92 | "Los Angeles", "San Francisco", "San Jose", 93 | "Philadelphia", "Washington DC", "Austin"]), 94 | ("food_pref", "food preference", ["Thai", "Chinese", "Korean", "Japanese", 95 | "American", "Italian", "Indian", "French", 96 | "Greek", "Mexican", "Russian", "Hawaiian"])] 97 | 98 | sys_slots = [("open", "if it's open now", ["open", "closed"]), 99 | ("price", "average price per person", ["cheap", "moderate", "expensive"]), 100 | ("parking", "if it has parking", ["street parking", "valet parking", "no parking"])] 101 | 102 | db_size = 100 103 | 104 | 105 | class RestPittSpec(DomainSpec): 106 | name = "rest_pitt" 107 | greet = "I am an expert about Pittsburgh restaurant." 
108 | 109 | nlg_spec = {"loc": {"inform": ["I am at %s.", "%s.", "I'm interested in food at %s.", "At %s.", "In %s."], 110 | "request": ["Which city are you interested in?", "Which place?"]}, 111 | 112 | "food_pref": {"inform": ["I like %s food.", "%s food.", "%s restaurant.", "%s."], 113 | "request": ["What kind of food do you like?", "What type of restaurant?"]}, 114 | 115 | "open": {"inform": ["The restaurant is %s.", "It is %s right now."], 116 | "request": ["Tell me if the restaurant is open.", "What's the hours?"], 117 | "yn_question": {'open': ["Is the restaurant open?"], 118 | 'closed': ["Is it closed?"] 119 | }}, 120 | 121 | "parking": {"inform": ["The restaurant has %s.", "This place has %s."], 122 | "request": ["What kind of parking does it have?.", "How easy is it to park?"], 123 | "yn_question": {'street parking': ["Does it have street parking?"], 124 | "valet parking": ["Does it have valet parking?"] 125 | }}, 126 | 127 | "price": {"inform": ["The restaurant serves %s food.", "The price is %s."], 128 | "request": ["What's the average price?", "How expensive it is?"], 129 | "yn_question": {'expensive': ["Is it expensive?"], 130 | 'moderate': ["Does it have moderate price?"], 131 | 'cheap': ["Is it cheap?"] 132 | }}, 133 | 134 | "default": {"inform": ["Restaurant %s is a good choice."], 135 | "request": ["I need a restaurant.", 136 | "I am looking for a restaurant.", 137 | "Recommend me a place to eat."]} 138 | } 139 | 140 | usr_slots = [("loc", "location city", ["Downtown", "CMU", "Forbes and Murray", "Craig", 141 | "Waterfront", "Airport", "U Pitt", "Mellon Park", 142 | "Lawrance", "Monroveil", "Shadyside", "Squrill Hill"]), 143 | ("food_pref", "food preference", ["healthy", "fried", "panned", "steamed", "hot pot", 144 | "grilled", "salad", "boiled", "raw", "stewed"])] 145 | 146 | sys_slots = [("open", "if it's open now", ["open", "going to start", "going to close", "closed"]), 147 | ("price", "average price per person", ["cheap", "average", 
"fancy"]), 148 | ("parking", "if it has parking", ["garage parking", "street parking", "no parking"])] 149 | 150 | db_size = 150 151 | 152 | 153 | class BusSpec(DomainSpec): 154 | name = "bus" 155 | greet = "Ask me about bus information." 156 | 157 | nlg_spec = {"from_loc": {"inform": ["I am at %s.", "%s.", "Leaving from %s.", "At %s.", "Departure place is %s."], 158 | "request": ["Where are you leaving from?", "What's the departure place?"]}, 159 | 160 | "to_loc": {"inform": ["Going to %s.", "%s.", "Destination is %s.", "Go to %s.", "To %s"], 161 | "request": ["Where are you going?", "Where do you want to take off?"]}, 162 | 163 | "datetime": {"inform": ["At %s.", "%s.", "I am leaving on %s.", "Departure time is %s."], 164 | "request": ["When are you going?", "What time do you need the bus?"]}, 165 | 166 | "arrive_in": {"inform": ["The bus will arrive in %s minutes.", "Arrive in %s minutes.", 167 | "Will be here in %s minutes"], 168 | "request": ["When will the bus arrive?", "How long do I need to wait?", 169 | "What's the estimated arrival time"], 170 | "yn_question": {k: ["Is it a long wait?"] if k>15 else ["Will it be here shortly?"] 171 | for k in range(0, 30, 5)}}, 172 | 173 | "duration": {"inform": ["It will take %s minutes.", "The ride is %s minutes long."], 174 | "request": ["How long will it take?.", "How much tim will it take?"], 175 | "yn_question": {k: ["Will it take long to get there?"] if k>30 else ["Is it a short trip?"] 176 | for k in range(0, 60, 5)}}, 177 | 178 | "default": {"inform": ["Bus %s can take you there."], 179 | "request": ["Look for bus information.", 180 | "I need a bus.", 181 | "Recommend me a bus to take."]} 182 | } 183 | 184 | usr_slots = [("from_loc", "departure place", ["Downtown", "CMU", "Forbes and Murray", "Craig", 185 | "Waterfront", "Airport", "U Pitt", "Mellon Park", 186 | "Lawrance", "Monroveil", "Shadyside", "Squrill Hill"]), 187 | ("to_loc", "arrival place", ["Downtown", "CMU", "Forbes and Murray", "Craig", 188 | 
"Waterfront", "Airport", "U Pitt", "Mellon Park", 189 | "Lawrance", "Monroveil", "Shadyside", "Squrill Hill"]), 190 | ("datetime", "leaving time", ["today", "tomorrow", "tonight", "this morning", 191 | "this afternoon"] + [str(t+1) for t in range(24)]) 192 | ] 193 | 194 | sys_slots = [("arrive_in", "how soon it arrives", [str(t) for t in range(0, 30, 5)]), 195 | ("duration", "how long it takes", [str(t) for t in range(0, 60, 5)]) 196 | ] 197 | 198 | db_size = 150 199 | 200 | 201 | class WeatherSpec(DomainSpec): 202 | name = "weather" 203 | greet = "Weather bot is here." 204 | 205 | nlg_spec = {"loc": {"inform": ["I am at %s.", "%s.", "Weather at %s.", "At %s.", "In %s."], 206 | "request": ["Which city are you interested in?", "Which place?"]}, 207 | 208 | "datetime": {"inform": ["Weather %s", "%s.", "I am interested in %s."], 209 | "request": ["What time's weather?", "What date are you interested?"]}, 210 | 211 | "temperature": {"inform": ["The temperature will be %s.", "The temperature that time will be %s."], 212 | "request": ["What's the temperature?", "What will be the temperature?"]}, 213 | 214 | "weather_type": {"inform": ["The weather will be %s.", "The weather type will be %s."], 215 | "request": ["What's the weather type?.", "What will be the weather like"], 216 | "yn_question": {k: ["Is it going to be %s?" 
% k] for k in 217 | ["raining", "snowing", "windy", "sunny", "foggy", "cloudy"]} 218 | }, 219 | 220 | "default": {"inform": ["Your weather report %s is here."], 221 | "request": ["What's the weather?.", 222 | "What will the weather be?"]} 223 | } 224 | 225 | usr_slots = [("loc", "location city", ["Pittsburgh", "New York", "Boston", "Seattle", 226 | "Los Angeles", "San Francisco", "San Jose", 227 | "Philadelphia", "Washington DC", "Austin"]), 228 | ("datetime", "which time's weather?", ["today", "tomorrow", "tonight", "this morning", 229 | "the day after tomorrow", "this weekend"])] 230 | 231 | sys_slots = [("temperature", "the temperature", [str(t) for t in range(20, 40, 2)]), 232 | ("weather_type", "the type", ["raining", "snowing", "windy", "sunny", "foggy", "cloudy"])] 233 | 234 | db_size = 40 235 | 236 | 237 | class MovieSpec(DomainSpec): 238 | name = "movie" 239 | greet = "Want to know about movies?" 240 | 241 | nlg_spec = {"genre": {"inform": ["I like %s movies.", "%s.", "I love %s ones.", "%s movies."], 242 | "request": ["What genre do you like?", "Which type of movie?"]}, 243 | 244 | "years": {"inform": ["Movies in %s", "In %s."], 245 | "request": ["What's the time period?", "Movie in what years?"]}, 246 | 247 | "country": {"inform": ["Movie from %s", "%s.", "From %s."], 248 | "request": ["Which country's movie?", "Movie from what country?"]}, 249 | 250 | "rating": {"inform": ["This movie has a rating of %s.", "The rating is %s."], 251 | "request": ["What's the rating?", "How people rate this movie?"], 252 | "yn_question": {"5": ["Does it have a perfect rating?"], 253 | "4": ["Does it have a rating of 4/5?"], 254 | "1": ["Does it have a very bad rating?"]} 255 | }, 256 | 257 | "company": {"inform": ["It's made by %s.", "The movie is from %s."], 258 | "request": ["Which company produced this movie?.", "Which company?"], 259 | "yn_question": {k: ["Is this movie from %s?" 
% k] for k in 260 | ["20th Century Fox", "Sony", "MGM", "Walt Disney", "Universal"]} 261 | }, 262 | 263 | "director": {"inform": ["The director is %s.", "It's director by %s."], 264 | "request": ["Who is the director?.", "Who directed it?"], 265 | "yn_question": {k: ["Is it directed by %s?" % k] for k in 266 | list(string.ascii_uppercase)} 267 | }, 268 | 269 | "default": {"inform": ["Movie %s is a good choice."], 270 | "request": ["Recommend a movie.", 271 | "Give me some good suggestions about movies.", 272 | "What should I watch now"]} 273 | } 274 | 275 | usr_slots = [("genre", "type of movie", ["Action", "Sci-Fi", "Comedy", "Crime", 276 | "Sport", "Documentary", "Drama", 277 | "Family", "Horror", "War", "Music", "Fantasy", "Romance", "Western"]), 278 | 279 | ("years", "when", ["60s", "70s", "80s", "90s", "2000-2010", "2010-present"]), 280 | 281 | ("country", "where ", ["USA", "France", "China", "Korea", 282 | "Japan", "Germany", "Mexico", "Russia", "Thailand"]) 283 | ] 284 | 285 | sys_slots = [("rating", "user rating", [str(t) for t in range(5)]), 286 | ("company", "the production company", ["20th Century Fox", "Sony", "MGM", "Walt Disney", "Universal"]), 287 | ("director", "the director's name", list(string.ascii_uppercase)) 288 | ] 289 | 290 | db_size = 200 291 | 292 | 293 | if __name__ == "__main__": 294 | # pipeline here 295 | # generate a fix 500 test set and 5000 training set. 296 | # generate them separately so the model can choose a subset for train and 297 | # test on all the test set to see generalization. 
298 | 299 | test_size = 500 300 | train_size = 2000 301 | gen_bot = Generator() 302 | 303 | rest_spec = RestSpec() 304 | rest_style_spec = RestStyleSpec() 305 | rest_pitt_spec = RestPittSpec() 306 | bus_spec = BusSpec() 307 | movie_spec = MovieSpec() 308 | weather_spec = WeatherSpec() 309 | 310 | # restaurant 311 | gen_bot.gen_corpus("test", rest_spec, complexity.CleanSpec, test_size) 312 | gen_bot.gen_corpus("test", rest_spec, complexity.MixSpec, test_size) 313 | gen_bot.gen_corpus("train", rest_spec, complexity.CleanSpec, train_size) 314 | gen_bot.gen_corpus("train", rest_spec, complexity.MixSpec, train_size) 315 | 316 | # restaurant style 317 | gen_bot.gen_corpus("test", rest_style_spec, complexity.CleanSpec, test_size) 318 | gen_bot.gen_corpus("test", rest_style_spec, complexity.MixSpec, test_size) 319 | gen_bot.gen_corpus("train", rest_style_spec, complexity.CleanSpec, train_size) 320 | gen_bot.gen_corpus("train", rest_style_spec, complexity.MixSpec, train_size) 321 | 322 | # bus 323 | gen_bot.gen_corpus("test", bus_spec, complexity.CleanSpec, test_size) 324 | gen_bot.gen_corpus("test", bus_spec, complexity.MixSpec, test_size) 325 | gen_bot.gen_corpus("train", bus_spec, complexity.CleanSpec, train_size) 326 | gen_bot.gen_corpus("train", bus_spec, complexity.MixSpec, train_size) 327 | 328 | # weather 329 | gen_bot.gen_corpus("test", weather_spec, complexity.CleanSpec, test_size) 330 | gen_bot.gen_corpus("test", weather_spec, complexity.MixSpec, test_size) 331 | gen_bot.gen_corpus("train", weather_spec, complexity.CleanSpec, train_size) 332 | gen_bot.gen_corpus("train", weather_spec, complexity.MixSpec, train_size) 333 | 334 | # movie 335 | gen_bot.gen_corpus("test", movie_spec, complexity.CleanSpec, test_size) 336 | gen_bot.gen_corpus("test", movie_spec, complexity.MixSpec, test_size) 337 | gen_bot.gen_corpus("train", movie_spec, complexity.CleanSpec, train_size) 338 | gen_bot.gen_corpus("train", movie_spec, complexity.MixSpec, train_size) 339 | 340 | # 
restaurant Pitt 341 | gen_bot.gen_corpus("test", rest_pitt_spec, complexity.MixSpec, test_size) 342 | gen_bot.gen_corpus("train", rest_pitt_spec, complexity.MixSpec, train_size) 343 | --------------------------------------------------------------------------------