├── .gitignore ├── README.md ├── agenda.py ├── config.py ├── controller.py ├── datamanager.py ├── dbquery.py ├── evaluator.py ├── fetch_data.sh ├── goal_generator.py ├── hybridv.py ├── learner.py ├── main.py ├── policy.py ├── rule.py ├── tracker.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data / Saved model / Tensorboard 2 | multiwoz/ 3 | model*/ 4 | runs/ 5 | 6 | # Log 7 | log/ 8 | result*.txt 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MADPL

Code for the paper "Multi-Agent Task-Oriented Dialog Policy Learning with Role-Aware Reward Decomposition". The paper is available on [arXiv](https://arxiv.org/abs/2004.03809).

Cite this paper:

```
@inproceedings{takanobu2020multi,
  title={Multi-Agent Task-Oriented Dialog Policy Learning with Role-Aware Reward Decomposition},
  author={Takanobu, Ryuichi and Liang, Runze and Huang, Minlie},
  booktitle={ACL},
  pages={625--638},
  year={2020}
}
```

## Data

Unzip the [zip archive](https://drive.google.com/open?id=1S2RXrXwsajrdzyyvM0ca_BLfGdb0PBgD) under the `data` directory, or simply run

```
sh fetch_data.sh
```

The pre-processed data are placed under the `data/processed_data` directory.

- Data preprocessing is done automatically when running `main.py` if the `processed_data` directory does not exist.

### Use

The best trained model is under the `data/model_madpl` directory:

```
python main.py --test True --load data/model_madpl/selected > result.txt
```

## Run

Command:

```
python main.py {--[option1]=[value1] --[option2]=[value2] ...}
```

Change the corresponding options to set hyper-parameters:

```
parser.add_argument('--log_dir', type=str, default='log', help='Logging directory')
parser.add_argument('--data_dir', type=str, default='data', help='Data directory')
parser.add_argument('--save_dir', type=str, default='model_multi', help='Directory to store model')
parser.add_argument('--load', type=str, default='', help='File name to load trained model')
parser.add_argument('--pretrain', type=bool, default=False, help='Set to pretrain')
parser.add_argument('--test', type=bool, default=False, help='Set to inference')
parser.add_argument('--config', type=str, default='multiwoz', help='Dataset to use')
parser.add_argument('--test_case', type=int, default=1000, help='Number of test cases')
parser.add_argument('--save_per_epoch', type=int, default=4, help="Save model every XXX epochs")
parser.add_argument('--print_per_batch', type=int, default=200, help="Print log every XXX batches")

parser.add_argument('--epoch', type=int, default=48, help='Max number of epochs')
parser.add_argument('--process', type=int, default=8, help='Number of processes')
parser.add_argument('--batchsz', type=int, default=32, help='Batch size')
parser.add_argument('--batchsz_traj', type=int, default=512, help='Batch size to collect trajectories')
parser.add_argument('--policy_weight_sys', type=float, default=2.5, help='Positive weight on system policy pretraining')
parser.add_argument('--policy_weight_usr', type=float, default=4, help='Positive weight on user policy pretraining')
parser.add_argument('--lr_policy', type=float, default=1e-3, help='Learning rate of dialog policy')
parser.add_argument('--lr_vnet', type=float, default=3e-5, help='Learning rate of value network')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay (L2 penalty)')
parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor')
parser.add_argument('--clip', type=float, default=10, help='Gradient clipping')
parser.add_argument('--interval', type=int, default=400, help='Update interval of target network')
```

We have implemented *distributed RL* for parallel trajectory sampling. You can set `--process` to change the number of sampling processes, and `--batchsz_traj` to change the number of trajectories each process collects before one update iteration.
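
For example, to sample with more parallel workers and a larger per-process trajectory batch (the values here are only illustrative, not tuned settings):

```
python main.py --process 16 --batchsz_traj 1024
```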

### pretrain

```
python main.py --pretrain True --save_dir model_pre
```

**NOTE**: pretrain the model first; the RL training below loads the pretrained model from `model_pre/best`.

### train

```
python main.py --load model_pre/best --lr_policy 1e-4 --save_dir model_RL --save_per_epoch 1
```

### test

```
python main.py --test True --load model_RL/best
```

## Requirements

Python 3

PyTorch >= 1.2
--------------------------------------------------------------------------------
/agenda.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
@author: keshuichonglx
"""

import random
import json
import torch
from copy import deepcopy
from goal_generator import GoalGenerator
from utils import init_session, init_goal
from tracker import StateTracker

REF_USR_DA = {
    'Attraction': {
        'area': 'Area', 'type': 'Type', 'name': 'Name',
        'entrance fee': 'Fee', 'address': 'Addr',
        'postcode': 'Post', 'phone': 'Phone'
    },
    'Hospital': {
        'department': 'Department', 'address': 'Addr', 'postcode': 'Post',
        'phone': 'Phone'
    },
    'Hotel': {
        'type': 'Type', 'parking': 'Parking', 'pricerange': 'Price',
        'internet': 'Internet', 'area': 'Area', 'stars': 'Stars',
        'name': 'Name', 'stay': 'Stay', 'day': 'Day', 'people': 'People',
        'address': 'Addr', 'postcode': 'Post', 'phone': 'Phone'
    },
    'Police': {
        'address': 'Addr', 'postcode': 'Post', 'phone': 'Phone'
    },
    'Restaurant': {
        'food': 'Food', 'pricerange': 'Price', 'area': 'Area',
        'name': 'Name', 'time': 'Time', 'day': 'Day', 'people': 'People',
        'phone': 'Phone', 'postcode': 'Post', 'address': 'Addr'
    },
    'Taxi': {
        'leaveAt': 'Leave', 'destination': 'Dest', 'departure': 'Depart', 'arriveBy': 'Arrive',
        'car type': 'Car', 'phone': 'Phone'
    },
    'Train': {
        'destination': 'Dest', 'day': 'Day', 'arriveBy': 'Arrive',
        'departure': 'Depart', 'leaveAt': 'Leave', 'people': 'People',
        'duration': 'Time', 'price': 'Ticket', 'trainID': 'Id'
    }
}

REF_SYS_DA = {
    'Attraction': {
        'Addr': "address", 'Area': "area", 'Choice': "choice",
        'Fee': "entrance fee", 'Name': "name", 'Phone': "phone",
        'Post': "postcode", 'Price': "pricerange", 'Type': "type",
        'none': None, 'Open': None
    },
    'Hospital': {
        'Department': 'department', 'Addr': 'address', 'Post': 'postcode',
        'Phone': 'phone', 'none': None
    },
    'Booking': {
        'Day': 'day', 'Name': 'name', 'People': 'people',
        'Ref': 'ref', 'Stay': 'stay', 'Time': 'time',
        'none': None
    },
    'Hotel': {
        'Addr': "address", 'Area': "area", 'Choice': "choice",
        'Internet': "internet", 'Name': "name", 'Parking': "parking",
        'Phone': "phone", 'Post': "postcode", 'Price': "pricerange",
        'Ref': "ref", 'Stars': "stars", 'Type': "type",
        'none': None
    },
    'Restaurant': {
        'Addr': "address", 'Area': "area", 'Choice': "choice",
        'Name': "name", 'Food': "food", 'Phone': "phone",
        'Post': "postcode", 'Price': "pricerange", 'Ref': "ref",
        'none': None
    },
    'Taxi': {
        'Arrive': "arriveBy", 'Car': "car type", 'Depart': "departure",
        'Dest': "destination", 'Leave': "leaveAt", 'Phone': "phone",
        'none': None
    },
    'Train': {
        'Arrive': "arriveBy", 'Choice': "choice", 'Day': "day",
        'Depart': "departure", 'Dest': "destination", 'Id': "trainID",
        'Leave': "leaveAt", 'People': "people", 'Ref': "ref",
        'Time': "duration", 'none': None, 'Ticket': 'price',
    },
    'Police': {
        'Addr': "address", 'Post': "postcode", 'Phone': "phone"
    },
}


DEF_VAL_UNK = '?'  # unknown
DEF_VAL_DNC = 'don\'t care'  # do not care
DEF_VAL_NUL = 'none'  # for none
DEF_VAL_BOOKED = 'yes'  # booking succeeded
DEF_VAL_NOBOOK = 'no'  # booking failed
NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK]

# build lowercased reflect tables,
# e.g. REF_SYS_DA_M['train']['dest'] == 'destination'
REF_USR_DA_M = deepcopy(REF_USR_DA)
REF_SYS_DA_M = {}
for dom, ref_slots in REF_SYS_DA.items():
    dom = dom.lower()
    REF_SYS_DA_M[dom] = {}
    for slot_a, slot_b in ref_slots.items():
        REF_SYS_DA_M[dom][slot_a.lower()] = slot_b
    REF_SYS_DA_M[dom]['none'] = None

# booking-related slots
BOOK_SLOT = ['people', 'day', 'stay', 'time']

class UserAgenda(StateTracker):
    """ Rule-based user policy model driven by an agenda stack """

    def __init__(self, data_dir, cfg):
        super(UserAgenda, self).__init__(data_dir, cfg)
        self.max_turn = 40
        self.max_initiative = 4

        # load the standard value set (ontology) used for value normalization
        with open(data_dir + '/' + cfg.ontology_file) as f:
            self.stand_value_dict = json.load(f)

        self.goal_generator = GoalGenerator(data_dir, cfg,
                                            goal_model_path='processed_data/goal_model.pkl',
                                            corpus_path=cfg.data_file)

        self.goal = None
        self.agenda = None

    def _action_to_dict(self, das):
        # convert a flat act dict, e.g. {'hotel-inform-area-1': 'north'},
        # into {'hotel-inform': [['area', 'north']]}
        da_dict = {}
        for da, value in das.items():
            domain, intent, slot, p = da.split('-')
            domint = '-'.join((domain, intent))
            if domint not in da_dict:
                da_dict[domint] = []
            da_dict[domint].append([slot, value])
        return da_dict

    def _dict_to_vec(self, das):
        # encode the act dict as a binary vector over the user act vocabulary
        da_vector = torch.zeros(self.cfg.a_dim_usr, dtype=torch.int32)
        for domint in das:
            pairs = das[domint]
            for slot, value in pairs:
                da = '-'.join((domint, slot)).lower()
                if da in self.cfg.da2idx_u:
                    idx = self.cfg.da2idx_u[da]
                    da_vector[idx] = 1
        return da_vector

    def reset(self, random_seed=None):
        """ Build a new Goal and Agenda for the next session """
        self.time_step = 0
        self.topic = ''
        self.goal = Goal(self.goal_generator, seed=random_seed)
        self.agenda = Agenda(self.goal)

        dummy_state, dummy_goal = init_session(-1, self.cfg)
        init_goal(dummy_goal, dummy_state['goal_state'], self.goal.domain_goals, self.cfg)

        domain_ordering = self.goal.domains
        dummy_state['next_available_domain'] = domain_ordering[0]
        dummy_state['invisible_domains'] = domain_ordering[1:]

        dummy_state['user_goal'] = dummy_goal
        self.evaluator.add_goal(dummy_goal)

        usr_a, terminal = self.predict(None, {})
        usr_a = self._dict_to_vec(usr_a)
        usr_a[-1] = 1 if terminal else 0
        init_state = self.update_belief_usr(dummy_state, usr_a)
        return init_state

    def step(self, s, sys_a):
        """
        Interact with the simulator for one sys-user turn
        """
        # update state with sys_act
        current_s = self.update_belief_sys(s, sys_a)
        if current_s['others']['terminal']:
            # user has terminated the session at the last turn
            usr_a, terminal = torch.zeros(self.cfg.a_dim_usr, dtype=torch.int32), True
        else:
            da_dict = self._action_to_dict(current_s['sys_action'])
            usr_a, terminal = self.predict(None, da_dict)
            usr_a = self._dict_to_vec(usr_a)

        # update state with user_act
        usr_a[-1] = 1 if terminal else 0
        next_s = self.update_belief_usr(current_s, usr_a)
        return next_s, terminal

    def predict(self, state, sys_action):
        """
        Predict a user act based on the state and the preceding system action.
        Args:
            state (tuple): Dialog state.
            sys_action (dict): Preceding system action.
        Returns:
            action (dict): User act.
            session_over (boolean): True to terminate the session, otherwise the session continues.
        """
        if self.time_step >= self.max_turn:
            self.agenda.close_session()
        else:
            sys_action = self._transform_sysact_in(sys_action)
            self.agenda.update(sys_action, self.goal)
            if self.goal.task_complete():
                self.agenda.close_session()

        # A -> A' + user_action
        action = self.agenda.get_action(self.max_initiative)

        # Is there any action left to say?
        session_over = self.agenda.is_empty()

        # transform to DA
        action = self._transform_usract_out(action)

        return action, session_over

    def _transform_usract_out(self, action):
        new_action = {}
        for act in action.keys():
            if '-' in act:
                if 'general' not in act:
                    (dom, intent) = act.split('-')
                    new_act = dom.capitalize() + '-' + intent.capitalize()
                    new_action[new_act] = []
                    for pairs in action[act]:
                        slot = REF_USR_DA_M[dom.capitalize()].get(pairs[0], None)
                        if slot is not None:
                            new_action[new_act].append([slot, pairs[1]])
                else:
                    new_action[act] = action[act]
            else:
                pass
        return new_action

    def _transform_sysact_in(self, action):
        new_action = {}
        if not isinstance(action, dict):
            print('illegal da:', action)
            return new_action

        for act in action.keys():
            if not isinstance(act, str) or '-' not in act:
                print('illegal act: %s' % act)
                continue

            if 'general' not in act:
                (dom, intent) = act.lower().split('-')
                if dom in REF_SYS_DA_M.keys():
                    new_list = []
                    for pairs in action[act]:
                        if (not isinstance(pairs, list) and not isinstance(pairs, tuple)) or\
                                (len(pairs) < 2) or\
                                (not isinstance(pairs[0], str) or not isinstance(pairs[1], str)):
                            print('illegal pairs:', pairs)
                            continue

                        if REF_SYS_DA_M[dom].get(pairs[0].lower(), None) is not None:
                            new_list.append([REF_SYS_DA_M[dom][pairs[0].lower()],
                                             self._normalize_value(dom, intent, REF_SYS_DA_M[dom][pairs[0].lower()], pairs[1])])

                    if len(new_list) > 0:
                        new_action[act.lower()] = new_list
            else:
                new_action[act.lower()] = action[act]

        return new_action

    def _normalize_value(self, domain, intent, slot, value):
        if intent == 'request':
            return DEF_VAL_UNK

        if domain not in self.stand_value_dict.keys():
            return value

        if slot not in self.stand_value_dict[domain]:
            return value

        if domain == 'taxi' and slot == 'phone':
            return value

        value_list = self.stand_value_dict[domain][slot]
        if value not in value_list and value != 'none':
            v0 = ' '.join(value.split())
            v0N = ''.join(value.split())
            for val in value_list:
                v1 = ' '.join(val.split())
                if v0 in v1 or v1 in v0 or v0N in v1 or v1 in v0N:
                    return v1
            print('illegal value: %s, slot: %s domain: %s' % (value, slot, domain))
        return value


def check_constraint(slot, val_usr, val_sys):
    # return True if the system-provided value conflicts with the user's constraint
    try:
        if slot == 'arriveBy':
            val1 = int(val_usr.split(':')[0]) * 100 + int(val_usr.split(':')[1])
            val2 = int(val_sys.split(':')[0]) * 100 + int(val_sys.split(':')[1])
            if val1 < val2:
                return True
        elif slot == 'leaveAt':
            val1 = int(val_usr.split(':')[0]) * 100 + int(val_usr.split(':')[1])
            val2 = int(val_sys.split(':')[0]) * 100 + int(val_sys.split(':')[1])
            if val1 > val2:
                return True
        else:
            if val_usr != val_sys:
                return True
        return False
    except:
        return False
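
# e.g., check_constraint('arriveBy', '12:00', '12:30') returns True:
# the offered arrival time (12:30) is later than the user's 12:00 constraint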

class Goal(object):
    """ User Goal Model Class. """

    def __init__(self, goal_generator: GoalGenerator, seed=None):
        """
        Create a new random Goal
        Args:
            goal_generator (GoalGenerator): Goal Generator.
        """
        self.domain_goals = goal_generator.get_user_goal(seed)
        self.domains = list(self.domain_goals['domain_ordering'])
        del self.domain_goals['domain_ordering']

        for domain in self.domains:
            if 'reqt' in self.domain_goals[domain].keys():
                self.domain_goals[domain]['reqt'] = {slot: DEF_VAL_UNK for slot in self.domain_goals[domain]['reqt']}

            if 'book' in self.domain_goals[domain].keys():
                self.domain_goals[domain]['booked'] = DEF_VAL_UNK

    def task_complete(self):
        """
        Check whether all requests have been met
        Returns:
            (boolean): True if accomplished.
        """
        for domain in self.domains:
            if 'reqt' in self.domain_goals[domain]:
                reqt_vals = self.domain_goals[domain]['reqt'].values()
                for val in reqt_vals:
                    if val in NOT_SURE_VALS:
                        return False

            if 'booked' in self.domain_goals[domain]:
                if self.domain_goals[domain]['booked'] in NOT_SURE_VALS:
                    return False
        return True

    def next_domain_incomplete(self):
        # request
        for domain in self.domains:
            # reqt
            if 'reqt' in self.domain_goals[domain]:
                requests = self.domain_goals[domain]['reqt']
                unknow_reqts = [key for (key, val) in requests.items() if val in NOT_SURE_VALS]
                if len(unknow_reqts) > 0:
                    return domain, 'reqt', ['name'] if 'name' in unknow_reqts else unknow_reqts

            # book
            if 'booked' in self.domain_goals[domain]:
                if self.domain_goals[domain]['booked'] in NOT_SURE_VALS:
                    return domain, 'book', \
                        self.domain_goals[domain]['fail_book'] if 'fail_book' in self.domain_goals[domain].keys() else self.domain_goals[domain]['book']

        return None, None, None


class Agenda(object):
    def __init__(self, goal: Goal):
        """
        Build a new agenda from the goal
        Args:
            goal (Goal): User goal.
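
        The agenda is a stack of user dialog acts: a 'general-bye' act is
        pushed first (bottom of the stack), then the constraint slots of each
        goal domain are pushed as inform acts in reverse domain order, so the
        first domain's acts end up on top.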
383 | """ 384 | 385 | def random_sample(data, minimum=0, maximum=1000): 386 | return random.sample(data, random.randint(min(len(data), minimum), min(len(data), maximum))) 387 | 388 | self.CLOSE_ACT = 'general-bye' 389 | self.HELLO_ACT = 'general-greet' 390 | self.__cur_push_num = 0 391 | 392 | self.__stack = [] 393 | 394 | # there is a 'bye' action at the bottom of the stack 395 | self.__push(self.CLOSE_ACT) 396 | 397 | for idx in range(len(goal.domains) - 1, -1, -1): 398 | domain = goal.domains[idx] 399 | 400 | # inform 401 | if 'fail_info' in goal.domain_goals[domain]: 402 | for slot in random_sample(goal.domain_goals[domain]['fail_info'].keys(), 403 | len(goal.domain_goals[domain]['fail_info'])): 404 | self.__push(domain + '-inform', slot, goal.domain_goals[domain]['fail_info'][slot]) 405 | elif 'info' in goal.domain_goals[domain]: 406 | for slot in random_sample(goal.domain_goals[domain]['info'].keys(), 407 | len(goal.domain_goals[domain]['info'])): 408 | self.__push(domain + '-inform', slot, goal.domain_goals[domain]['info'][slot]) 409 | 410 | self.cur_domain = None 411 | 412 | def update(self, sys_action, goal: Goal): 413 | """ 414 | update Goal by current agent action and current goal. { A' + G" + sys_action -> A" } 415 | Args: 416 | sys_action (tuple): Preorder system action.s 417 | goal (Goal): User Goal 418 | """ 419 | self.__cur_push_num = 0 420 | self._update_current_domain(sys_action, goal) 421 | 422 | for diaact in sys_action.keys(): 423 | slot_vals = sys_action[diaact] 424 | if 'nooffer' in diaact: 425 | if self.update_domain(diaact, slot_vals, goal): 426 | return 427 | elif 'nobook' in diaact: 428 | if self.update_booking(diaact, slot_vals, goal): 429 | return 430 | 431 | for diaact in sys_action.keys(): 432 | if 'nooffer' in diaact or 'nobook' in diaact: 433 | continue 434 | 435 | slot_vals = sys_action[diaact] 436 | if 'booking' in diaact: 437 | if self.update_booking(diaact, slot_vals, goal): 438 | return 439 | elif 'general' in diaact: 440 | if self.update_general(diaact, slot_vals, goal): 441 | return 442 | else: 443 | if self.update_domain(diaact, slot_vals, goal): 444 | return 445 | 446 | unk_dom, unk_type, data = goal.next_domain_incomplete() 447 | if unk_dom is not None: 448 | if unk_type == 'reqt' and not self._check_reqt_info(unk_dom) and not self._check_reqt(unk_dom): 449 | for slot in data: 450 | self._push_item(unk_dom + '-request', slot, DEF_VAL_UNK) 451 | elif unk_type == 'book' and not self._check_reqt_info(unk_dom) and not self._check_book_info(unk_dom): 452 | for (slot, val) in data.items(): 453 | self._push_item(unk_dom + '-inform', slot, val) 454 | 455 | def update_booking(self, diaact, slot_vals, goal: Goal): 456 | """ 457 | Handel Book-XXX 458 | :param diaact: Dial-Act 459 | :param slot_vals: slot value pairs 460 | :param goal: Goal 461 | :return: True:user want to close the session. 
False:session is continue 462 | """ 463 | _, intent = diaact.split('-') 464 | domain = self.cur_domain 465 | 466 | if domain not in goal.domains: 467 | return False 468 | 469 | g_reqt = goal.domain_goals[domain].get('reqt', dict({})) 470 | g_info = goal.domain_goals[domain].get('info', dict({})) 471 | g_fail_info = goal.domain_goals[domain].get('fail_info', dict({})) 472 | g_book = goal.domain_goals[domain].get('book', dict({})) 473 | g_fail_book = goal.domain_goals[domain].get('fail_book', dict({})) 474 | 475 | if intent in ['book', 'inform']: 476 | info_right = True 477 | for [slot, value] in slot_vals: 478 | if domain == 'train' and slot == 'time': 479 | slot = 'duration' 480 | 481 | if slot in g_reqt: 482 | if not self._check_reqt_info(domain): 483 | self._remove_item(domain + '-request', slot) 484 | if value in NOT_SURE_VALS: 485 | g_reqt[slot] = '\"' + value + '\"' 486 | else: 487 | g_reqt[slot] = value 488 | 489 | elif slot in g_fail_info and value != g_fail_info[slot]: 490 | self._push_item(domain + '-inform', slot, g_fail_info[slot]) 491 | info_right = False 492 | elif len(g_fail_info) <= 0 and slot in g_info and check_constraint(slot, g_info[slot], value): 493 | self._push_item(domain + '-inform', slot, g_info[slot]) 494 | info_right = False 495 | 496 | elif slot in g_fail_book and value != g_fail_book[slot]: 497 | self._push_item(domain + '-inform', slot, g_fail_book[slot]) 498 | info_right = False 499 | elif len(g_fail_book) <= 0 and slot in g_book and value != g_book[slot]: 500 | self._push_item(domain + '-inform', slot, g_book[slot]) 501 | info_right = False 502 | 503 | else: 504 | pass 505 | 506 | if intent == 'book' and info_right: 507 | # booked ok 508 | if 'booked' in goal.domain_goals[domain]: 509 | goal.domain_goals[domain]['booked'] = DEF_VAL_BOOKED 510 | self._push_item('general-thank') 511 | 512 | elif intent in ['nobook']: 513 | if len(g_fail_book) > 0: 514 | # Discard fail_book data and update the book data to the stack 515 | for slot in g_book.keys(): 516 | if (slot not in g_fail_book) or (slot in g_fail_book and g_fail_book[slot] != g_book[slot]): 517 | self._push_item(domain + '-inform', slot, g_book[slot]) 518 | 519 | # change fail_info name 520 | goal.domain_goals[domain]['fail_book_fail'] = goal.domain_goals[domain].pop('fail_book') 521 | elif 'booked' in goal.domain_goals[domain].keys(): 522 | self.close_session() 523 | return True 524 | 525 | elif intent in ['request']: 526 | for [slot, _] in slot_vals: 527 | if domain == 'train' and slot == 'time': 528 | slot = 'duration' 529 | 530 | if slot in g_reqt: 531 | pass 532 | elif slot in g_fail_info: 533 | self._push_item(domain + '-inform', slot, g_fail_info[slot]) 534 | elif len(g_fail_info) <= 0 and slot in g_info: 535 | self._push_item(domain + '-inform', slot, g_info[slot]) 536 | 537 | elif slot in g_fail_book: 538 | self._push_item(domain + '-inform', slot, g_fail_book[slot]) 539 | elif len(g_fail_book) <= 0 and slot in g_book: 540 | self._push_item(domain + '-inform', slot, g_book[slot]) 541 | 542 | else: 543 | 544 | if domain == 'taxi' and (slot == 'destination' or slot == 'departure'): 545 | places = [dom for dom in goal.domains[: goal.domains.index('taxi')] if 546 | 'address' in goal.domain_goals[dom]['reqt']] 547 | 548 | if len(places) >= 1 and slot == 'destination' and \ 549 | goal.domain_goals[places[-1]]['reqt']['address'] not in NOT_SURE_VALS: 550 | self._push_item(domain + '-inform', slot, goal.domain_goals[places[-1]]['reqt']['address']) 551 | 552 | elif len(places) >= 2 and slot == 'departure' 
and \ 553 | goal.domain_goals[places[-2]]['reqt']['address'] not in NOT_SURE_VALS: 554 | self._push_item(domain + '-inform', slot, goal.domain_goals[places[-2]]['reqt']['address']) 555 | 556 | elif random.random() < 0.5: 557 | self._push_item(domain + '-inform', slot, DEF_VAL_DNC) 558 | 559 | elif random.random() < 0.5: 560 | self._push_item(domain + '-inform', slot, DEF_VAL_DNC) 561 | 562 | return False 563 | 564 | def update_domain(self, diaact, slot_vals, goal: Goal): 565 | """ 566 | Handel Domain-XXX 567 | :param diaact: Dial-Act 568 | :param slot_vals: slot value pairs 569 | :param goal: Goal 570 | :return: True:user want to close the session. False:session is continue 571 | """ 572 | domain, intent = diaact.split('-') 573 | 574 | if domain not in goal.domains: 575 | return False 576 | 577 | g_reqt = goal.domain_goals[domain].get('reqt', dict({})) 578 | g_info = goal.domain_goals[domain].get('info', dict({})) 579 | g_fail_info = goal.domain_goals[domain].get('fail_info', dict({})) 580 | g_book = goal.domain_goals[domain].get('book', dict({})) 581 | g_fail_book = goal.domain_goals[domain].get('fail_book', dict({})) 582 | 583 | if intent in ['inform', 'recommend', 'offerbook', 'offerbooked']: 584 | info_right = True 585 | for [slot, value] in slot_vals: 586 | if slot in g_reqt: 587 | if not self._check_reqt_info(domain): 588 | self._remove_item(domain + '-request', slot) 589 | if value in NOT_SURE_VALS: 590 | g_reqt[slot] = '\"' + value + '\"' 591 | else: 592 | g_reqt[slot] = value 593 | 594 | elif slot in g_fail_info and value != g_fail_info[slot]: 595 | self._push_item(domain + '-inform', slot, g_fail_info[slot]) 596 | info_right = False 597 | elif len(g_fail_info) <= 0 and slot in g_info and check_constraint(slot, g_info[slot], value): 598 | self._push_item(domain + '-inform', slot, g_info[slot]) 599 | info_right = False 600 | 601 | elif slot in g_fail_book and value != g_fail_book[slot]: 602 | self._push_item(domain + '-inform', slot, g_fail_book[slot]) 603 | info_right = False 604 | elif len(g_fail_book) <= 0 and slot in g_book and value != g_book[slot]: 605 | self._push_item(domain + '-inform', slot, g_book[slot]) 606 | info_right = False 607 | 608 | else: 609 | pass 610 | 611 | if intent == 'offerbooked' and info_right: 612 | # booked ok 613 | if 'booked' in goal.domain_goals[domain]: 614 | goal.domain_goals[domain]['booked'] = DEF_VAL_BOOKED 615 | self._push_item('general-thank') 616 | 617 | elif intent in ['request']: 618 | for [slot, _] in slot_vals: 619 | if slot in g_reqt: 620 | pass 621 | elif slot in g_fail_info: 622 | self._push_item(domain + '-inform', slot, g_fail_info[slot]) 623 | elif len(g_fail_info) <= 0 and slot in g_info: 624 | self._push_item(domain + '-inform', slot, g_info[slot]) 625 | 626 | elif slot in g_fail_book: 627 | self._push_item(domain + '-inform', slot, g_fail_book[slot]) 628 | elif len(g_fail_book) <= 0 and slot in g_book: 629 | self._push_item(domain + '-inform', slot, g_book[slot]) 630 | 631 | else: 632 | 633 | if domain == 'taxi' and (slot == 'destination' or slot == 'departure'): 634 | places = [dom for dom in goal.domains[: goal.domains.index('taxi')] if 635 | 'address' in goal.domain_goals[dom]['reqt']] 636 | 637 | if len(places) >= 1 and slot == 'destination' and \ 638 | goal.domain_goals[places[-1]]['reqt']['address'] not in NOT_SURE_VALS: 639 | self._push_item(domain + '-inform', slot, goal.domain_goals[places[-1]]['reqt']['address']) 640 | 641 | elif len(places) >= 2 and slot == 'departure' and \ 642 | 
goal.domain_goals[places[-2]]['reqt']['address'] not in NOT_SURE_VALS: 643 | self._push_item(domain + '-inform', slot, goal.domain_goals[places[-2]]['reqt']['address']) 644 | 645 | elif random.random() < 0.5: 646 | self._push_item(domain + '-inform', slot, DEF_VAL_DNC) 647 | 648 | elif random.random() < 0.5: 649 | self._push_item(domain + '-inform', slot, DEF_VAL_DNC) 650 | 651 | elif intent in ['nooffer']: 652 | if len(g_fail_info) > 0: 653 | # update info data to the stack 654 | for slot in g_info.keys(): 655 | if (slot not in g_fail_info) or (slot in g_fail_info and g_fail_info[slot] != g_info[slot]): 656 | self._push_item(domain + '-inform', slot, g_info[slot]) 657 | 658 | # change fail_info name 659 | goal.domain_goals[domain]['fail_info_fail'] = goal.domain_goals[domain].pop('fail_info') 660 | elif len(g_reqt.keys()) > 0: 661 | self.close_session() 662 | return True 663 | 664 | elif intent in ['select']: 665 | # delete Choice 666 | slot_vals = [[slot, val] for [slot, val] in slot_vals if slot != 'choice'] 667 | 668 | if len(slot_vals) > 0: 669 | slot = slot_vals[0][0] 670 | 671 | if slot in g_fail_info: 672 | self._push_item(domain + '-inform', slot, g_fail_info[slot]) 673 | elif len(g_fail_info) <= 0 and slot in g_info: 674 | self._push_item(domain + '-inform', slot, g_info[slot]) 675 | 676 | elif slot in g_fail_book: 677 | self._push_item(domain + '-inform', slot, g_fail_book[slot]) 678 | elif len(g_fail_book) <= 0 and slot in g_book: 679 | self._push_item(domain + '-inform', slot, g_book[slot]) 680 | 681 | else: 682 | if not self._check_reqt_info(domain): 683 | [slot, value] = random.choice(slot_vals) 684 | self._push_item(domain + '-inform', slot, value) 685 | 686 | if slot in g_reqt: 687 | self._remove_item(domain + '-request', slot) 688 | g_reqt[slot] = value 689 | 690 | return False 691 | 692 | def update_general(self, diaact, slot_vals, goal: Goal): 693 | domain, intent = diaact.split('-') 694 | 695 | if intent == 'bye': 696 | pass 697 | elif intent == 'greet': 698 | pass 699 | elif intent == 'reqmore': 700 | pass 701 | elif intent == 'welcome': 702 | pass 703 | 704 | return False 705 | 706 | def close_session(self): 707 | """ Clear up all actions """ 708 | self.__stack = [] 709 | self.__push(self.CLOSE_ACT) 710 | 711 | def get_action(self, initiative=1): 712 | """ 713 | get multiple acts based on initiative 714 | Args: 715 | initiative (int): number of slots , just for 'inform' 716 | Returns: 717 | action (dict): user diaact 718 | """ 719 | diaacts, slots, values = self.__pop(initiative) 720 | action = {} 721 | for (diaact, slot, value) in zip(diaacts, slots, values): 722 | if diaact not in action.keys(): 723 | action[diaact] = [] 724 | action[diaact].append([slot, value]) 725 | 726 | return action 727 | 728 | def is_empty(self): 729 | """ 730 | Is the agenda already empty 731 | Returns: 732 | (boolean): True for empty, False for not. 
733 | """ 734 | return len(self.__stack) <= 0 735 | 736 | def _update_current_domain(self, sys_action, goal: Goal): 737 | for diaact in sys_action.keys(): 738 | domain, _ = diaact.split('-') 739 | if domain in goal.domains: 740 | self.cur_domain = domain 741 | 742 | def _remove_item(self, diaact, slot=DEF_VAL_UNK): 743 | for idx in range(len(self.__stack)): 744 | if 'general' in diaact: 745 | if self.__stack[idx]['diaact'] == diaact: 746 | self.__stack.remove(self.__stack[idx]) 747 | break 748 | else: 749 | if self.__stack[idx]['diaact'] == diaact and self.__stack[idx]['slot'] == slot: 750 | self.__stack.remove(self.__stack[idx]) 751 | break 752 | 753 | def _push_item(self, diaact, slot=DEF_VAL_NUL, value=DEF_VAL_NUL): 754 | self._remove_item(diaact, slot) 755 | self.__push(diaact, slot, value) 756 | self.__cur_push_num += 1 757 | 758 | def _check_item(self, diaact, slot=None): 759 | for idx in range(len(self.__stack)): 760 | if slot is None: 761 | if self.__stack[idx]['diaact'] == diaact: 762 | return True 763 | else: 764 | if self.__stack[idx]['diaact'] == diaact and self.__stack[idx]['slot'] == slot: 765 | return True 766 | return False 767 | 768 | def _check_reqt(self, domain): 769 | for idx in range(len(self.__stack)): 770 | if self.__stack[idx]['diaact'] == domain + '-request': 771 | return True 772 | return False 773 | 774 | def _check_reqt_info(self, domain): 775 | for idx in range(len(self.__stack)): 776 | if self.__stack[idx]['diaact'] == domain + '-inform' and self.__stack[idx]['slot'] not in BOOK_SLOT: 777 | return True 778 | return False 779 | 780 | def _check_book_info(self, domain): 781 | for idx in range(len(self.__stack)): 782 | if self.__stack[idx]['diaact'] == domain + '-inform' and self.__stack[idx]['slot'] in BOOK_SLOT: 783 | return True 784 | return False 785 | 786 | def __check_next_diaact_slot(self): 787 | if len(self.__stack) > 0: 788 | return self.__stack[-1]['diaact'], self.__stack[-1]['slot'] 789 | return None, None 790 | 791 | def __check_next_diaact(self): 792 | if len(self.__stack) > 0: 793 | return self.__stack[-1]['diaact'] 794 | return None 795 | 796 | def __push(self, diaact, slot=DEF_VAL_NUL, value=DEF_VAL_NUL): 797 | self.__stack.append({'diaact': diaact, 'slot': slot, 'value': value}) 798 | 799 | def __pop(self, initiative=1): 800 | diaacts = [] 801 | slots = [] 802 | values = [] 803 | 804 | p_diaact, p_slot = self.__check_next_diaact_slot() 805 | if p_diaact.split('-')[1] == 'inform' and p_slot in BOOK_SLOT: 806 | for _ in range(10 if self.__cur_push_num == 0 else self.__cur_push_num): 807 | try: 808 | item = self.__stack.pop(-1) 809 | diaacts.append(item['diaact']) 810 | slots.append(item['slot']) 811 | values.append(item['value']) 812 | 813 | cur_diaact = item['diaact'] 814 | 815 | next_diaact, next_slot = self.__check_next_diaact_slot() 816 | if next_diaact is None or \ 817 | next_diaact != cur_diaact or \ 818 | next_diaact.split('-')[1] != 'inform' or next_slot not in BOOK_SLOT: 819 | break 820 | except: 821 | break 822 | else: 823 | for _ in range(initiative if self.__cur_push_num == 0 else self.__cur_push_num): 824 | try: 825 | item = self.__stack.pop(-1) 826 | diaacts.append(item['diaact']) 827 | slots.append(item['slot']) 828 | values.append(item['value']) 829 | 830 | cur_diaact = item['diaact'] 831 | 832 | next_diaact = self.__check_next_diaact() 833 | if next_diaact is None or \ 834 | next_diaact != cur_diaact or \ 835 | (cur_diaact.split('-')[1] == 'request' and item['slot'] == 'name'): 836 | break 837 | except: 838 | break 839 | 840 | 
return diaacts, slots, values 841 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: ryuichi takanobu 4 | """ 5 | 6 | class Config(): 7 | 8 | def __init__(self): 9 | self.domain = [] 10 | self.intent = [] 11 | self.slot = [] 12 | self.da = [] 13 | self.da_usr = [] 14 | self.data_file = '' 15 | self.db_domains = [] 16 | self.belief_domains = [] 17 | 18 | def init_inform_request(self): 19 | self.inform_da = [] 20 | self.request_da = [] 21 | self.requestable = [] 22 | 23 | for da in self.da_usr: 24 | d, i, s = da.split('-') 25 | if s == 'none': 26 | continue 27 | key = '-'.join([d,s]) 28 | if i == 'inform' and key not in self.inform_da: 29 | self.inform_da.append(key) 30 | elif i == 'request' and key not in self.request_da: 31 | self.request_da.append(key) 32 | if d in self.db_domains and s != 'ref': 33 | self.requestable.append(key) 34 | 35 | self.inform_da_usr = [] 36 | self.request_da_usr = [] 37 | for da in self.da_goal: 38 | d, i, s = da.split('-') 39 | key = '-'.join([d,s]) 40 | if i == 'inform': 41 | self.inform_da_usr.append(key) 42 | else: 43 | self.request_da_usr.append(key) 44 | 45 | def init_dict(self): 46 | self.domain2idx = dict((a, i) for i, a in enumerate(self.belief_domains)) 47 | self.idx2domain = dict((v, k) for k, v in self.domain2idx.items()) 48 | 49 | self.inform2idx = dict((a, i) for i, a in enumerate(self.inform_da)) 50 | self.idx2inform = dict((v, k) for k, v in self.inform2idx.items()) 51 | 52 | self.request2idx = dict((a, i) for i, a in enumerate(self.request_da)) 53 | self.idx2request = dict((v, k) for k, v in self.request2idx.items()) 54 | 55 | self.inform2idx_u = dict((a, i) for i, a in enumerate(self.inform_da_usr)) 56 | self.idx2inform_u = dict((v, k) for k, v in self.inform2idx_u.items()) 57 | 58 | self.request2idx_u = dict((a, i) for i, a in enumerate(self.request_da_usr)) 59 | self.idx2request_u = dict((v, k) for k, v in self.request2idx_u.items()) 60 | 61 | self.requestable2idx = dict((a, i) for i, a in enumerate(self.requestable)) 62 | self.idx2requestable = dict((v, k) for k, v in self.requestable2idx.items()) 63 | 64 | self.da2idx = dict((a, i) for i, a in enumerate(self.da)) 65 | self.idx2da = dict((v, k) for k, v in self.da2idx.items()) 66 | 67 | self.da2idx_u = dict((a, i) for i, a in enumerate(self.da_usr)) 68 | self.idx2da_u = dict((v, k) for k, v in self.da2idx_u.items()) 69 | 70 | def init_dim(self): 71 | self.s_dim = len(self.da) + len(self.da_usr) + len(self.inform_da) + len(self.request_da) + len(self.belief_domains) + 6*len(self.db_domains) + 1#len(self.requestable) + 1 72 | self.s_dim_usr = len(self.da) + len(self.da_usr) + len(self.inform_da_usr)*2 + len(self.request_da_usr) + len(self.belief_domains)#*2 73 | self.a_dim = len(self.da) 74 | self.a_dim_usr = len(self.da_usr) + 1 75 | 76 | 77 | class MultiWozConfig(Config): 78 | 79 | def __init__(self): 80 | self.domain = ['general', 'train', 'booking', 'hotel', 'restaurant', 'attraction', 'taxi', 'police', 'hospital'] 81 | self.intent = ['inform', 'request', 'reqmore', 'bye', 'book', 'welcome', 'recommend', 'offerbook', 'nooffer', 'offerbooked', 'greet', 'select', 'nobook', 'thank'] 82 | self.slot = ['none', 'name', 'area', 'choice', 'type', 'price', 'ref', 'leave', 'addr', 'phone', 'food', 'day', 'arrive', 'depart', 'dest', 'post', 'id', 'people', 'stars', 'ticket', 'time', 'fee', 'car', 'internet', 'parking', 'stay', 
'department'] 83 | self.da = ['general-reqmore-none-none', 'general-bye-none-none', 'booking-inform-none-none', 'booking-book-ref-1', 'general-welcome-none-none', 'restaurant-inform-name-1', 'hotel-inform-choice-1', 'train-inform-leave-1', 'hotel-inform-name-1', 'train-inform-id-1', 'restaurant-inform-choice-1', 'train-inform-arrive-1', 'restaurant-inform-food-1', 'train-offerbook-none-none', 'restaurant-inform-area-1', 'hotel-inform-type-1', 'attraction-inform-name-1', 'restaurant-inform-price-1', 'attraction-inform-area-1', 'train-offerbooked-ref-1', 'hotel-inform-area-1', 'hotel-inform-price-1', 'general-greet-none-none', 'attraction-inform-choice-1', 'train-inform-choice-1', 'hotel-request-area-?', 'attraction-inform-addr-1', 'train-request-leave-?', 'taxi-inform-car-1', 'attraction-inform-type-1', 'taxi-inform-phone-1', 'restaurant-inform-addr-1', 'attraction-inform-fee-1', 'restaurant-request-food-?', 'attraction-inform-phone-1', 'hotel-inform-stars-1', 'booking-request-day-?', 'train-inform-dest-1', 'train-request-depart-?', 'train-request-day-?', 'attraction-inform-post-1', 'hotel-recommend-name-1', 'restaurant-recommend-name-1', 'hotel-inform-internet-1', 'train-request-dest-?', 'attraction-recommend-name-1', 'restaurant-inform-phone-1', 'train-inform-depart-1', 'hotel-inform-parking-1', 'train-offerbooked-ticket-1', 'booking-book-name-1', 'hotel-request-price-?', 'train-inform-ticket-1', 'booking-nobook-none-none', 'restaurant-request-area-?', 'booking-request-people-?', 'hotel-inform-addr-1', 'train-request-arrive-?', 'train-inform-day-1', 'train-inform-time-1', 'booking-request-time-?', 'restaurant-inform-post-1', 'booking-book-day-1', 'booking-request-stay-?', 'restaurant-request-price-?', 'attraction-request-type-?', 'attraction-request-area-?', 'booking-book-people-1', 'restaurant-nooffer-none-none', 'taxi-request-leave-?', 'hotel-inform-phone-1', 'taxi-request-depart-?', 'restaurant-nooffer-food-1', 'hotel-inform-post-1', 'booking-book-time-1', 'train-request-people-?', 'attraction-inform-addr-2', 'taxi-request-dest-?', 'restaurant-inform-name-2', 'hotel-select-none-none', 'restaurant-select-none-none', 'booking-book-stay-1', 'train-offerbooked-id-1', 'hotel-inform-name-2', 'hotel-nooffer-type-1', 'train-offerbooked-people-1', 'taxi-request-arrive-?', 'attraction-recommend-addr-1', 'attraction-recommend-fee-1', 'hotel-recommend-area-1', 'hotel-request-stars-?', 'restaurant-nooffer-area-1', 'restaurant-recommend-food-1', 'restaurant-recommend-area-1', 'attraction-recommend-area-1', 'train-inform-leave-2', 'hotel-inform-choice-2', 'attraction-nooffer-area-1', 'attraction-nooffer-type-1', 'hotel-nooffer-none-none', 'hotel-recommend-price-1', 'attraction-inform-name-2', 'hotel-recommend-stars-1', 'restaurant-recommend-price-1', 'restaurant-inform-food-2', 'train-select-none-none', 'attraction-inform-type-2', 'booking-inform-name-1', 'hotel-inform-type-2', 'hotel-request-type-?', 'hotel-request-parking-?', 'hospital-inform-phone-1', 'hospital-inform-post-1', 'train-offerbooked-leave-1', 'attraction-select-none-none', 'hotel-select-type-1', 'taxi-inform-depart-1', 'hotel-inform-price-2', 'restaurant-recommend-addr-1', 'police-inform-phone-1', 'hospital-inform-addr-1', 'hotel-nooffer-area-1', 'hotel-inform-area-2', 'police-inform-post-1', 'police-inform-addr-1', 'attraction-recommend-type-1', 'attraction-inform-type-3', 'hotel-nooffer-stars-1', 'hotel-nooffer-price-1', 'taxi-inform-dest-1', 'hotel-request-internet-?', 'taxi-inform-leave-1', 'hotel-recommend-type-1', 
'restaurant-inform-choice-2', 'hotel-recommend-internet-1', 'restaurant-select-food-1', 'restaurant-nooffer-price-1', 'train-offerbook-id-1', 'restaurant-inform-name-3', 'hotel-recommend-parking-1', 'attraction-inform-addr-3', 'attraction-recommend-post-1', 'attraction-inform-choice-2', 'restaurant-inform-area-2', 'train-offerbook-leave-1', 'hotel-inform-addr-2', 'restaurant-inform-price-2', 'attraction-recommend-phone-1', 'hotel-select-type-2', 'train-offerbooked-arrive-1', 'attraction-inform-area-2', 'hotel-recommend-addr-1', 'restaurant-select-food-2', 'train-offerbooked-depart-1', 'attraction-select-type-1', 'train-offerbook-arrive-1', 'taxi-inform-arrive-1', 'restaurant-inform-post-2', 'attraction-inform-fee-2', 'restaurant-inform-food-3', 'train-offerbooked-dest-1', 'attraction-inform-name-3', 'hotel-select-price-1', 'train-inform-arrive-2', 'attraction-request-name-?', 'attraction-nooffer-none-none', 'train-inform-ref-1', 'booking-book-none-none', 'police-inform-name-1', 'hotel-inform-stars-2', 'restaurant-select-price-1', 'attraction-inform-type-4'] 84 | self.da_usr = ['general-thank-none', 'restaurant-inform-food', 'train-inform-dest', 'train-inform-day', 'train-inform-depart', 'restaurant-inform-price', 'restaurant-inform-area', 'hotel-inform-stay', 'restaurant-inform-time', 'hotel-inform-type', 'restaurant-inform-day', 'hotel-inform-day', 'attraction-inform-type', 'restaurant-inform-people', 'hotel-inform-people', 'hotel-inform-price', 'hotel-inform-stars', 'hotel-inform-area', 'train-inform-arrive', 'attraction-inform-area', 'train-inform-people', 'train-inform-leave', 'hotel-inform-parking', 'hotel-inform-internet', 'restaurant-inform-name', 'attraction-request-post', 'hotel-inform-name', 'attraction-request-phone', 'attraction-request-addr', 'restaurant-request-addr', 'restaurant-request-phone', 'attraction-inform-name', 'attraction-request-fee', 'general-bye-none', 'train-request-ticket', 'taxi-inform-leave', 'taxi-inform-none', 'train-request-ref', 'taxi-inform-depart', 'restaurant-inform-none', 'restaurant-request-post', 'taxi-inform-dest', 'train-request-time', 'hotel-inform-none', 'taxi-inform-arrive', 'train-inform-none', 'hotel-request-addr', 'restaurant-request-ref', 'hotel-request-post', 'hotel-request-phone', 'hotel-request-ref', 'train-request-id', 'taxi-request-car', 'attraction-request-area', 'train-request-arrive', 'train-request-leave', 'attraction-inform-none', 'attraction-request-type', 'hotel-request-price', 'hotel-request-internet', 'hospital-inform-none', 'hotel-request-parking', 'restaurant-request-price', 'hotel-request-area', 'restaurant-request-area', 'hospital-request-post', 'hotel-request-type', 'restaurant-request-food', 'hospital-request-phone', 'general-greet-none', 'police-inform-none', 'police-request-addr', 'hospital-request-addr', 'hospital-inform-department', 'police-request-post', 'police-inform-name', 'hotel-request-stars', 'police-request-phone', 'taxi-request-phone'] 85 | self.da_goal = ['restaurant-inform-day', 'restaurant-inform-people', 'restaurant-inform-area', 'restaurant-inform-food', 'restaurant-inform-time', 'restaurant-inform-price', 'restaurant-inform-name', 'hotel-inform-day', 'hotel-inform-parking', 'hotel-inform-type', 'hotel-inform-stay', 'hotel-inform-people', 'hotel-inform-area', 'hotel-inform-stars', 'hotel-inform-price', 'hotel-inform-name', 'hotel-inform-internet', 'attraction-inform-area', 'attraction-inform-name', 'attraction-inform-type', 'train-inform-arrive', 'train-inform-day', 'train-inform-depart', 
'train-inform-leave', 'train-inform-people', 'train-inform-dest', 'taxi-inform-arrive', 'taxi-inform-leave', 'taxi-inform-depart', 'taxi-inform-dest', 'hospital-inform-department', 'restaurant-request-addr', 'restaurant-request-post', 'restaurant-request-area', 'restaurant-request-food', 'restaurant-request-price', 'restaurant-request-phone', 'hotel-request-parking', 'hotel-request-type', 'hotel-request-addr', 'hotel-request-post', 'hotel-request-stars', 'hotel-request-area', 'hotel-request-price', 'hotel-request-internet', 'hotel-request-phone', 'attraction-request-type', 'attraction-request-fee', 'attraction-request-addr', 'attraction-request-post', 'attraction-request-area', 'attraction-request-phone', 'train-request-arrive', 'train-request-ticket', 'train-request-leave', 'train-request-id', 'train-request-time', 'taxi-request-car', 'taxi-request-phone', 'police-request-addr', 'police-request-post', 'police-request-phone', 'hospital-request-addr', 'hospital-request-post', 'hospital-request-phone'] 86 | self.data_file = 'annotated_user_da_with_span_full_patchName.json' 87 | self.ontology_file = 'value_set.json' 88 | self.db_domains = ['train', 'hotel', 'restaurant', 'attraction'] 89 | self.belief_domains = ['train', 'hotel', 'restaurant', 'attraction', 'taxi', 'police', 'hospital'] 90 | self.val_file = 'valListFile.json' 91 | self.test_file = 'testListFile.json' 92 | 93 | self.h_dim = 100 94 | self.hs_dim = 100 95 | self.ha_dim = 50 96 | self.hv_dim = 50 # for value function 97 | 98 | # da to db 99 | self.mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'name': 'name', 'phone': 'phone', 'post': 'postcode', 'price': 'pricerange'}, 100 | 'hotel': {'addr': 'address', 'area': 'area', 'internet': 'internet', 'parking': 'parking', 'name': 'name', 'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type'}, 101 | 'attraction': {'addr': 'address', 'area': 'area', 'fee': 'entrance fee', 'name': 'name', 'phone': 'phone', 'post': 'postcode', 'type': 'type'}, 102 | 'train': {'id': 'trainID', 'arrive': 'arriveBy', 'day': 'day', 'depart': 'departure', 'dest': 'destination', 'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price'}, 103 | 'taxi': {'car': 'taxi_type', 'phone': 'taxi_phone'}, 104 | 'hospital': {'department': 'department', 'phone': 'phone'}, 105 | 'police': {'addr': 'address', 'name': 'name', 'post': 'postcode'}} 106 | # goal to da 107 | self.map_inverse = {'restaurant': {'address': 'addr', 'area': 'area', 'day': 'day', 'food': 'food', 'name': 'name', 'people': 'people', 'phone': 'phone', 'postcode': 'post', 'pricerange': 'price', 'time': 'time'}, 108 | 'hotel': {'address': 'addr', 'area': 'area', 'day': 'day', 'internet': 'internet', 'name': 'name', 'parking': 'parking', 'people': 'people', 'phone': 'phone', 'postcode': 'post', 'pricerange': 'price', 'stars': 'stars', 'stay': 'stay', 'type': 'type'}, 109 | 'attraction': {'address': 'addr', 'area': 'area', 'entrance fee': 'fee', 'name': 'name', 'phone': 'phone', 'postcode': 'post', 'type': 'type'}, 110 | 'train': {'arriveBy': 'arrive', 'day': 'day', 'departure': 'depart', 'destination': 'dest', 'duration': 'time', 'leaveAt': 'leave', 'people': 'people', 'price': 'ticket', 'trainID': 'id'}, 111 | 'taxi': {'arriveBy': 'arrive', 'car type': 'car', 'departure': 'depart', 'destination': 'dest', 'leaveAt': 'leave', 'phone': 'phone'}, 112 | 'hospital': {'address': 'addr', 'department': 'department', 'phone': 'phone', 'postcode': 'post'}, 113 | 'police': {'address': 'addr', 'phone': 
'phone', 'postcode': 'post'}} 114 | 115 | self.init_inform_request() # call this first! 116 | self.init_dict() 117 | self.init_dim() 118 | -------------------------------------------------------------------------------- /controller.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: ryuichi takanobu 4 | """ 5 | 6 | from utils import init_goal, init_session 7 | from tracker import StateTracker 8 | from goal_generator import GoalGenerator 9 | 10 | 11 | class Controller(StateTracker): 12 | def __init__(self, data_dir, config): 13 | super(Controller, self).__init__(data_dir, config) 14 | self.goal_gen = GoalGenerator(data_dir, config, 15 | goal_model_path='processed_data/goal_model.pkl', 16 | corpus_path=config.data_file) 17 | 18 | def reset(self, random_seed=None): 19 | """ 20 | init a user goal and return init state 21 | """ 22 | self.time_step = 0 23 | self.topic = '' 24 | self.goal = self.goal_gen.get_user_goal(random_seed) 25 | 26 | dummy_state, dummy_goal = init_session(-1, self.cfg) 27 | init_goal(dummy_goal, dummy_state['goal_state'], self.goal, self.cfg) 28 | 29 | domain_ordering = self.goal['domain_ordering'] 30 | dummy_state['next_available_domain'] = domain_ordering[0] 31 | dummy_state['invisible_domains'] = domain_ordering[1:] 32 | 33 | dummy_state['user_goal'] = dummy_goal 34 | self.evaluator.add_goal(dummy_goal) 35 | 36 | return dummy_state 37 | 38 | def step_sys(self, s, sys_a): 39 | """ 40 | interact with simulator for one sys-user turn 41 | """ 42 | # update state with sys_act 43 | current_s = self.update_belief_sys(s, sys_a) 44 | 45 | return current_s 46 | 47 | def step_usr(self, s, usr_a): 48 | current_s = self.update_belief_usr(s, usr_a) 49 | terminal = current_s['others']['terminal'] 50 | return current_s, terminal 51 | -------------------------------------------------------------------------------- /datamanager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: ryuichi takanobu 4 | """ 5 | import os 6 | import json 7 | import logging 8 | import torch 9 | import torch.utils.data as data 10 | import copy 11 | from copy import deepcopy 12 | from evaluator import MultiWozEvaluator 13 | from utils import init_session, init_goal, state_vectorize, action_vectorize, \ 14 | state_vectorize_user, action_vectorize_user, discard, reload 15 | 16 | def expand_da(meta): 17 | for k, v in meta.items(): 18 | domain, intent = k.split('-') 19 | if intent.lower() == "request": 20 | for pair in v: 21 | pair.insert(1, '?') 22 | else: 23 | counter = {} 24 | for pair in v: 25 | if pair[0] == 'none': 26 | pair.insert(1, 'none') 27 | else: 28 | if pair[0] in counter: 29 | counter[pair[0]] += 1 30 | else: 31 | counter[pair[0]] = 1 32 | pair.insert(1, str(counter[pair[0]])) 33 | 34 | 35 | def add_domain_mask(data): 36 | all_domains = ["train", "hotel", "restaurant", "attraction", "taxi", "police", "hospital"] 37 | parts = ["train", "valid", "test"] 38 | for part in parts: 39 | dataset = data[part] 40 | domains_in_order_dict = {} # {session_id:domains_in_order} 41 | domains_in_order = [] 42 | current_session_id = "" 43 | for turn in dataset: 44 | session_id = turn["others"]["session_id"] 45 | if session_id != current_session_id: 46 | if current_session_id != "": 47 | domains_in_order_dict[current_session_id] = domains_in_order 48 | domains_in_order = [] 49 | current_session_id = session_id 50 | if "trg_user_action" in turn: 51 | user_das 
= turn["trg_user_action"] 52 | for user_da in user_das: 53 | [domain, intent, slot] = user_da.split('-') 54 | if domain in all_domains and domain not in domains_in_order: 55 | domains_in_order.append(domain) 56 | domains_in_order_dict[current_session_id] = domains_in_order # for last dialog 57 | 58 | current_session_id = "" 59 | next_available_domain = "" 60 | invisible_domains = [] 61 | for turn in dataset: 62 | session_id = turn["others"]["session_id"] 63 | if session_id != current_session_id: 64 | domains_in_order = domains_in_order_dict[session_id] 65 | if domains_in_order: 66 | next_available_domain = domains_in_order[0] 67 | invisible_domains = domains_in_order[1:] 68 | else: 69 | next_available_domain = "" 70 | invisible_domains = [] 71 | current_session_id = session_id 72 | turn["next_available_domain"] = next_available_domain 73 | turn["invisible_domains"] = copy.copy(invisible_domains) 74 | 75 | if "trg_user_action" in turn: 76 | user_das = turn["trg_user_action"] 77 | for user_da in user_das: 78 | [domain, intent, slot] = user_da.split('-') 79 | if domain == next_available_domain: 80 | if invisible_domains: 81 | next_available_domain = invisible_domains[0] 82 | invisible_domains.remove(next_available_domain) 83 | 84 | 85 | class DataManager(): 86 | """Offline data manager""" 87 | 88 | def __init__(self, data_dir, cfg): 89 | self.data = {} 90 | self.goal = {} 91 | 92 | self.data_dir_new = data_dir + '/processed_data' 93 | if os.path.exists(self.data_dir_new): 94 | logging.info('Load processed data file') 95 | for part in ['train','valid','test']: 96 | with open(self.data_dir_new + '/' + part + '.json', 'r') as f: 97 | self.data[part] = json.load(f) 98 | with open(self.data_dir_new + '/' + part + '_goal.json', 'r') as f: 99 | self.goal[part] = json.load(f) 100 | else: 101 | from dbquery import DBQuery 102 | db = DBQuery(data_dir, cfg) 103 | logging.info('Start preprocessing the dataset') 104 | self._build_data(data_dir, self.data_dir_new, cfg, db) 105 | 106 | for part in ['train', 'valid', 'test']: 107 | file_dir = self.data_dir_new + '/' + part + '_sys.pt' 108 | if not os.path.exists(file_dir): 109 | from dbquery import DBQuery 110 | db = DBQuery(data_dir, cfg) 111 | self.create_dataset_sys(part, file_dir, data_dir, cfg, db) 112 | 113 | file_dir = self.data_dir_new + '/' + part + '_usr.pt' 114 | if not os.path.exists(file_dir): 115 | from dbquery import DBQuery 116 | db = DBQuery(data_dir, cfg) 117 | self.create_dataset_usr(part, file_dir, data_dir, cfg, db) 118 | 119 | file_dir = self.data_dir_new + '/' + part + '_glo.pt' 120 | if not os.path.exists(file_dir): 121 | from dbquery import DBQuery 122 | db = DBQuery(data_dir, cfg) 123 | self.create_dataset_global(part, file_dir, data_dir, cfg, db) 124 | 125 | def _build_data(self, data_dir, data_dir_new, cfg, db): 126 | data_filename = data_dir + '/' + cfg.data_file 127 | with open(data_filename, 'r') as f: 128 | origin_data = json.load(f) 129 | 130 | for part in ['train','valid','test']: 131 | self.data[part] = [] 132 | self.goal[part] = {} 133 | 134 | valList = [] 135 | with open(data_dir + '/' + cfg.val_file) as f: 136 | for line in f: 137 | valList.append(line.split('.')[0]) 138 | testList = [] 139 | with open(data_dir + '/' + cfg.test_file) as f: 140 | for line in f: 141 | testList.append(line.split('.')[0]) 142 | 143 | for k_sess in origin_data: 144 | sess = origin_data[k_sess] 145 | if k_sess in valList: 146 | part = 'valid' 147 | elif k_sess in testList: 148 | part = 'test' 149 | else: 150 | part = 'train' 151 | turn_data, 
session_data = init_session(k_sess, cfg) 152 | belief_state = turn_data['belief_state'] 153 | goal_state = turn_data['goal_state'] 154 | init_goal(session_data, goal_state, sess['goal'], cfg) 155 | self.goal[part][k_sess] = deepcopy(session_data) 156 | current_domain = '' 157 | book_domain = '' 158 | turn_data['trg_user_action'] = {} 159 | turn_data['trg_sys_action'] = {} 160 | 161 | for i, turn in enumerate(sess['log']): 162 | turn_data['others']['turn'] = i 163 | turn_data['others']['terminal'] = i + 2 >= len(sess['log']) 164 | da_origin = turn['dialog_act'] 165 | expand_da(da_origin) 166 | turn_data['belief_state'] = deepcopy(belief_state) # from previous turn 167 | turn_data['goal_state'] = deepcopy(goal_state) 168 | 169 | if i % 2 == 0: # user 170 | turn_data['sys_action'] = deepcopy(turn_data['trg_sys_action']) 171 | del(turn_data['trg_sys_action']) 172 | turn_data['trg_user_action'] = dict() 173 | for domint in da_origin: 174 | domain_intent = da_origin[domint] 175 | _domint = domint.lower() 176 | _domain, _intent = _domint.split('-') 177 | if _domain in cfg.belief_domains: 178 | current_domain = _domain 179 | for slot, p, value in domain_intent: 180 | _slot = slot.lower() 181 | _value = value.strip() 182 | _da = '-'.join((_domint, _slot)) 183 | if _da in cfg.da_usr: 184 | turn_data['trg_user_action'][_da] = _value 185 | if _intent == 'inform': 186 | inform_da = _domain+'-'+_slot 187 | if inform_da in cfg.inform_da: 188 | belief_state[_domain][_slot] = _value 189 | if inform_da in cfg.inform_da_usr and _slot in session_data[_domain] \ 190 | and session_data[_domain][_slot] != '?': 191 | discard(goal_state[_domain], _slot) 192 | elif _intent == 'request': 193 | request_da = _domain+'-'+_slot 194 | if request_da in cfg.request_da: 195 | belief_state[_domain][_slot] = '?' 
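# --- Illustrative aside (not part of datamanager.py): how expand_da above rewrites
# a raw MultiWOZ 'dialog_act' annotation before the (slot, p, value) triples are
# consumed by the user/sys branches. Values below are hypothetical.
#
#   meta = {'Hotel-Request': [['Price', '?']],
#           'Hotel-Inform':  [['Area', 'north'], ['Area', 'south']]}
#   expand_da(meta)
#   # meta == {'Hotel-Request': [['Price', '?', '?']],
#   #          'Hotel-Inform':  [['Area', '1', 'north'], ['Area', '2', 'south']]}
#
# Request slots get a '?' marker inserted at position 1, while inform slots get an
# occurrence counter, so repeated slots stay distinguishable when joined into
# 'domain-intent-slot' (user) or 'domain-intent-slot-p' (sys) strings.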
196 | 197 | else: # sys 198 | book_status = turn['metadata'] 199 | for domain in cfg.belief_domains: 200 | if book_status[domain]['book']['booked']: 201 | entity = book_status[domain]['book']['booked'][0] 202 | if 'booked' in belief_state[domain]: 203 | continue 204 | book_domain = domain 205 | if domain in ['taxi', 'hospital', 'police']: 206 | belief_state[domain]['booked'] = f'{domain}-booked' 207 | elif domain == 'train': 208 | found = db.query(domain, [('trainID', entity['trainID'])]) 209 | belief_state[domain]['booked'] = found[0]['ref'] 210 | else: 211 | found = db.query(domain, [('name', entity['name'])]) 212 | belief_state[domain]['booked'] = found[0]['ref'] 213 | 214 | turn_data['user_action'] = deepcopy(turn_data['trg_user_action']) 215 | del(turn_data['trg_user_action']) 216 | turn_data['others']['change'] = False 217 | turn_data['trg_sys_action'] = dict() 218 | for domint in da_origin: 219 | domain_intent = da_origin[domint] 220 | _domint = domint.lower() 221 | _domain, _intent = _domint.split('-') 222 | for slot, p, value in domain_intent: 223 | _slot = slot.lower() 224 | _value = value.strip() 225 | _da = '-'.join((_domint, _slot, p)) 226 | if _da in cfg.da and current_domain: 227 | if _slot == 'ref': 228 | turn_data['trg_sys_action'][_da] = belief_state[book_domain]['booked'] 229 | else: 230 | turn_data['trg_sys_action'][_da] = _value 231 | if _intent in ['inform', 'recommend', 'offerbook', 'offerbooked', 'book']: 232 | inform_da = current_domain+'-'+_slot 233 | if inform_da in cfg.request_da: 234 | discard(belief_state[current_domain], _slot, '?') 235 | if inform_da in cfg.request_da_usr and _slot in session_data[current_domain] \ 236 | and session_data[current_domain][_slot] == '?': 237 | goal_state[current_domain][_slot] = _value 238 | elif _intent in ['nooffer', 'nobook']: 239 | # TODO: better transition 240 | for da in turn_data['user_action']: 241 | __domain, __intent, __slot = da.split('-') 242 | if __intent == 'inform' and __domain == current_domain: 243 | discard(belief_state[current_domain], __slot) 244 | turn_data['others']['change'] = True 245 | reload(goal_state, session_data, current_domain) 246 | 247 | if i + 1 == len(sess['log']): 248 | turn_data['final_belief_state'] = belief_state 249 | turn_data['final_goal_state'] = goal_state 250 | 251 | self.data[part].append(deepcopy(turn_data)) 252 | 253 | add_domain_mask(self.data) 254 | 255 | def _set_default(obj): 256 | if isinstance(obj, set): 257 | return list(obj) 258 | raise TypeError 259 | os.makedirs(data_dir_new) 260 | for part in ['train','valid','test']: 261 | with open(data_dir_new + '/' + part + '.json', 'w') as f: 262 | self.data[part] = json.dumps(self.data[part], default=_set_default) 263 | f.write(self.data[part]) 264 | self.data[part] = json.loads(self.data[part]) 265 | with open(data_dir_new + '/' + part + '_goal.json', 'w') as f: 266 | self.goal[part] = json.dumps(self.goal[part], default=_set_default) 267 | f.write(self.goal[part]) 268 | self.goal[part] = json.loads(self.goal[part]) 269 | 270 | def create_dataset_sys(self, part, file_dir, data_dir, cfg, db): 271 | datas = self.data[part] 272 | goals = self.goal[part] 273 | s, a, r, next_s, t = [], [], [], [], [] 274 | evaluator = MultiWozEvaluator(data_dir) 275 | for idx, turn_data in enumerate(datas): 276 | if turn_data['others']['turn'] % 2 == 0: 277 | if turn_data['others']['turn'] == 0: 278 | evaluator.add_goal(goals[turn_data['others']['session_id']]) 279 | evaluator.add_usr_da(turn_data['trg_user_action']) 280 | continue 281 | if 
turn_data['others']['turn'] != 1: 282 | next_s.append(s[-1]) 283 | 284 | s.append(torch.Tensor(state_vectorize(turn_data, cfg, db, True))) 285 | a.append(torch.Tensor(action_vectorize(turn_data['trg_sys_action'], cfg))) 286 | evaluator.add_sys_da(turn_data['trg_sys_action']) 287 | if turn_data['others']['terminal']: 288 | next_turn_data = deepcopy(turn_data) 289 | next_turn_data['others']['turn'] = -1 290 | next_turn_data['user_action'] = {} 291 | next_turn_data['sys_action'] = turn_data['trg_sys_action'] 292 | next_turn_data['trg_sys_action'] = {} 293 | next_turn_data['belief_state'] = turn_data['final_belief_state'] 294 | next_s.append(torch.Tensor(state_vectorize(next_turn_data, cfg, db, True))) 295 | reward = 20 if evaluator.task_success(False) else -5 296 | r.append(reward) 297 | t.append(1) 298 | else: 299 | reward = 0 300 | if evaluator.cur_domain: 301 | for slot, value in turn_data['belief_state'][evaluator.cur_domain].items(): 302 | if value == '?': 303 | for da in turn_data['trg_sys_action']: 304 | d, i, k, p = da.split('-') 305 | if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and k == slot: 306 | break 307 | else: 308 | # not answer request 309 | reward -= 1 310 | if not turn_data['trg_sys_action']: 311 | reward -= 5 312 | r.append(reward) 313 | t.append(0) 314 | 315 | torch.save((s, a, r, next_s, t), file_dir) 316 | 317 | def create_dataset_usr(self, part, file_dir, data_dir, cfg, db): 318 | datas = self.data[part] 319 | goals = self.goal[part] 320 | s, a, r, next_s, t = [], [], [], [], [] 321 | evaluator = MultiWozEvaluator(data_dir) 322 | current_goal = None 323 | for idx, turn_data in enumerate(datas): 324 | if turn_data['others']['turn'] % 2 == 1: 325 | evaluator.add_sys_da(turn_data['trg_sys_action']) 326 | continue 327 | 328 | if turn_data['others']['turn'] == 0: 329 | current_goal = goals[turn_data['others']['session_id']] 330 | evaluator.add_goal(current_goal) 331 | else: 332 | next_s.append(s[-1]) 333 | if turn_data['others']['change'] and evaluator.cur_domain: 334 | if 'final' in current_goal[evaluator.cur_domain]: 335 | for key in current_goal[evaluator.cur_domain]['final']: 336 | current_goal[evaluator.cur_domain][key] = current_goal[evaluator.cur_domain]['final'][key] 337 | del(current_goal[evaluator.cur_domain]['final']) 338 | turn_data['user_goal'] = deepcopy(current_goal) 339 | 340 | s.append(torch.Tensor(state_vectorize_user(turn_data, cfg, evaluator.cur_domain))) 341 | a.append(torch.Tensor(action_vectorize_user(turn_data['trg_user_action'], turn_data['others']['terminal'], cfg))) 342 | evaluator.add_usr_da(turn_data['trg_user_action']) 343 | if turn_data['others']['terminal']: 344 | next_turn_data = deepcopy(turn_data) 345 | next_turn_data['others']['turn'] = -1 346 | next_turn_data['user_action'] = turn_data['trg_user_action'] 347 | next_turn_data['sys_action'] = datas[idx+1]['trg_sys_action'] 348 | next_turn_data['trg_user_action'] = {} 349 | next_turn_data['goal_state'] = datas[idx+1]['final_goal_state'] 350 | next_s.append(torch.Tensor(state_vectorize_user(next_turn_data, cfg, evaluator.cur_domain))) 351 | reward = 20 if evaluator.inform_F1(ansbysys=False)[1] == 1. else -5 352 | r.append(reward) 353 | t.append(1) 354 | else: 355 | reward = 0 356 | if evaluator.cur_domain: 357 | for da in turn_data['trg_user_action']: 358 | d, i, k = da.split('-') 359 | if i == 'request': 360 | for slot, value in turn_data['goal_state'][d].items(): 361 | if value != '?' 
and slot in turn_data['user_goal'][d]\ 362 | and turn_data['user_goal'][d][slot] != '?': 363 | # request before express constraint 364 | reward -= 1 365 | if not turn_data['trg_user_action']: 366 | reward -= 5 367 | r.append(reward) 368 | t.append(0) 369 | 370 | torch.save((s, a, r, next_s, t), file_dir) 371 | 372 | def create_dataset_global(self, part, file_dir, data_dir, cfg, db): 373 | datas = self.data[part] 374 | goals = self.goal[part] 375 | s_usr, s_sys, r_g, next_s_usr, next_s_sys, t = [], [], [], [], [], [] 376 | evaluator = MultiWozEvaluator(data_dir) 377 | for idx, turn_data in enumerate(datas): 378 | if turn_data['others']['turn'] % 2 == 0: 379 | if turn_data['others']['turn'] == 0: 380 | current_goal = goals[turn_data['others']['session_id']] 381 | evaluator.add_goal(current_goal) 382 | else: 383 | next_s_usr.append(s_usr[-1]) 384 | 385 | if turn_data['others']['change'] and evaluator.cur_domain: 386 | if 'final' in current_goal[evaluator.cur_domain]: 387 | for key in current_goal[evaluator.cur_domain]['final']: 388 | current_goal[evaluator.cur_domain][key] = current_goal[evaluator.cur_domain]['final'][key] 389 | del(current_goal[evaluator.cur_domain]['final']) 390 | 391 | turn_data['user_goal'] = deepcopy(current_goal) 392 | s_usr.append(torch.Tensor(state_vectorize_user(turn_data, cfg, evaluator.cur_domain))) 393 | evaluator.add_usr_da(turn_data['trg_user_action']) 394 | 395 | if turn_data['others']['terminal']: 396 | next_turn_data = deepcopy(turn_data) 397 | next_turn_data['others']['turn'] = -1 398 | next_turn_data['user_action'] = turn_data['trg_user_action'] 399 | next_turn_data['sys_action'] = datas[idx+1]['trg_sys_action'] 400 | next_turn_data['trg_user_action'] = {} 401 | next_turn_data['goal_state'] = datas[idx+1]['final_goal_state'] 402 | next_s_usr.append(torch.Tensor(state_vectorize_user(next_turn_data, cfg, evaluator.cur_domain))) 403 | 404 | else: 405 | if turn_data['others']['turn'] != 1: 406 | next_s_sys.append(s_sys[-1]) 407 | 408 | s_sys.append(torch.Tensor(state_vectorize(turn_data, cfg, db, True))) 409 | evaluator.add_sys_da(turn_data['trg_sys_action']) 410 | 411 | if turn_data['others']['terminal']: 412 | next_turn_data = deepcopy(turn_data) 413 | next_turn_data['others']['turn'] = -1 414 | next_turn_data['user_action'] = {} 415 | next_turn_data['sys_action'] = turn_data['trg_sys_action'] 416 | next_turn_data['trg_sys_action'] = {} 417 | next_turn_data['belief_state'] = turn_data['final_belief_state'] 418 | next_s_sys.append(torch.Tensor(state_vectorize(next_turn_data, cfg, db, True))) 419 | reward_g = 20 if evaluator.task_success() else -5 420 | r_g.append(reward_g) 421 | t.append(1) 422 | else: 423 | reward_g = 5 if evaluator.cur_domain and evaluator.domain_success(evaluator.cur_domain) else -1 424 | r_g.append(reward_g) 425 | t.append(0) 426 | 427 | torch.save((s_usr, s_sys, r_g, next_s_usr, next_s_sys, t), file_dir) 428 | 429 | def create_dataset_policy(self, part, batchsz, cfg, db, character='sys'): 430 | assert part in ['train', 'valid', 'test'] 431 | logging.debug('start loading {}'.format(part)) 432 | 433 | if character == 'sys': 434 | file_dir = self.data_dir_new + '/' + part + '_sys.pt' 435 | elif character == 'usr': 436 | file_dir = self.data_dir_new + '/' + part + '_usr.pt' 437 | else: 438 | raise NotImplementedError('Unknown character {}'.format(character)) 439 | 440 | s, a, *_ = torch.load(file_dir) 441 | new_s, new_a = [], [] 442 | for state, action in zip(s, a): 443 | if action.nonzero().size(0): 444 | new_s.append(state) 445 | 
new_a.append(action) 446 | dataset = Dataset_Policy(new_s, new_a) 447 | dataloader = data.DataLoader(dataset, batchsz, True) 448 | 449 | logging.debug('finish loading {}'.format(part)) 450 | return dataloader 451 | 452 | def create_dataset_vnet(self, part, batchsz, cfg, db): 453 | assert part in ['train', 'valid', 'test'] 454 | logging.debug('start loading {}'.format(part)) 455 | 456 | file_dir_1 = self.data_dir_new + '/' + part + '_sys.pt' 457 | file_dir_2 = self.data_dir_new + '/' + part + '_usr.pt' 458 | file_dir_3 = self.data_dir_new + '/' + part + '_glo.pt' 459 | 460 | s, _, r, next_s, t = torch.load(file_dir_1) 461 | dataset_sys = Dataset_Vnet(s, r, next_s, t) 462 | dataloader_sys = data.DataLoader(dataset_sys, batchsz, True) 463 | 464 | s, _, r, next_s, t = torch.load(file_dir_2) 465 | dataset_usr = Dataset_Vnet(s, r, next_s, t) 466 | dataloader_usr = data.DataLoader(dataset_usr, batchsz, True) 467 | 468 | s_usr, s_sys, r_g, next_s_usr, next_s_sys, t = torch.load(file_dir_3) 469 | dataset_global = Dataset_Vnet_G(s_usr, s_sys, r_g, next_s_usr, next_s_sys, t) 470 | dataloader_global = data.DataLoader(dataset_global, batchsz, True) 471 | 472 | logging.debug('finish loading {}'.format(part)) 473 | return dataloader_sys, dataloader_usr, dataloader_global 474 | 475 | 476 | class Dataset_Policy(data.Dataset): 477 | def __init__(self, s, a): 478 | self.s = s 479 | self.a = a 480 | self.num_total = len(s) 481 | 482 | def __getitem__(self, index): 483 | s = self.s[index] 484 | a = self.a[index] 485 | return s, a 486 | 487 | def __len__(self): 488 | return self.num_total 489 | 490 | class Dataset_Vnet(data.Dataset): 491 | def __init__(self, s, r, next_s, t): 492 | self.s = s 493 | self.r = r 494 | self.next_s = next_s 495 | self.t = t 496 | self.num_total = len(s) 497 | 498 | def __getitem__(self, index): 499 | s = self.s[index] 500 | r = self.r[index] 501 | next_s = self.next_s[index] 502 | t = self.t[index] 503 | return s, r, next_s, t 504 | 505 | def __len__(self): 506 | return self.num_total 507 | 508 | class Dataset_Vnet_G(data.Dataset): 509 | def __init__(self, s_usr, s_sys, r, next_s_usr, next_s_sys, t): 510 | self.s_usr = s_usr 511 | self.s_sys = s_sys 512 | self.r = r 513 | self.next_s_usr = next_s_usr 514 | self.next_s_sys = next_s_sys 515 | self.t = t 516 | self.num_total = len(s_sys) 517 | 518 | def __getitem__(self, index): 519 | s_usr = self.s_usr[index] 520 | s_sys = self.s_sys[index] 521 | r = self.r[index] 522 | next_s_usr = self.next_s_usr[index] 523 | next_s_sys = self.next_s_sys[index] 524 | t = self.t[index] 525 | return s_usr, s_sys, r, next_s_usr, next_s_sys, t 526 | 527 | def __len__(self): 528 | return self.num_total 529 | -------------------------------------------------------------------------------- /dbquery.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import numpy as np 4 | 5 | class DBQuery(): 6 | 7 | def __init__(self, data_dir, cfg): 8 | # loading databases 9 | self.cfg = cfg 10 | self.dbs = {} 11 | for domain in cfg.belief_domains: 12 | with open('{}/{}_db.json'.format(data_dir, domain)) as f: 13 | self.dbs[domain] = json.load(f) 14 | 15 | def query(self, domain, constraints, ignore_open=True): 16 | """Returns the list of entities for a given domain 17 | based on the annotation of the belief state""" 18 | # query the db 19 | if domain == 'taxi': 20 | return [{'taxi_type': random.choice(self.dbs[domain]['taxi_colors']) + ' ' + random.choice(self.dbs[domain]['taxi_types']), 21 | 
'taxi_phone': ''.join([str(random.randint(0, 9)) for _ in range(10)])}] 22 | elif domain == 'hospital': 23 | return self.dbs['hospital'] 24 | elif domain == 'police': 25 | return self.dbs['police'] 26 | 27 | found = [] 28 | for i, record in enumerate(self.dbs[domain]): 29 | for key, val in constraints: 30 | if val == "" or val == "dont care" or val == 'not mentioned' or val == "don't care" or val == "dontcare" or val == "do n't care": 31 | pass 32 | else: 33 | try: 34 | record_keys = [key.lower() for key in record] 35 | if key.lower() not in record_keys: 36 | continue 37 | if key == 'leaveAt': 38 | val1 = int(val.split(':')[0]) * 100 + int(val.split(':')[1]) 39 | val2 = int(record['leaveAt'].split(':')[0]) * 100 + int(record['leaveAt'].split(':')[1]) 40 | if val1 > val2: 41 | break 42 | elif key == 'arriveBy': 43 | val1 = int(val.split(':')[0]) * 100 + int(val.split(':')[1]) 44 | val2 = int(record['arriveBy'].split(':')[0]) * 100 + int(record['arriveBy'].split(':')[1]) 45 | if val1 < val2: 46 | break 47 | elif ignore_open and key in ['destination', 'departure']: 48 | continue 49 | else: 50 | if val.strip() != record[key].strip(): 51 | break 52 | except: 53 | continue 54 | else: 55 | record['ref'] = f'{domain}-{i:08d}' 56 | found.append(record) 57 | 58 | return found 59 | 60 | def pointer(self, turn, mapping, db_domains, requestable, noisy): 61 | """Create database pointer for all related domains.""" 62 | pointer_vector = np.zeros(6 * len(db_domains)) 63 | entropy = np.zeros(len(requestable)) 64 | for domain in db_domains: 65 | constraint = [] 66 | for k, v in turn[domain].items(): 67 | if k in mapping[domain] and v != '?': 68 | constraint.append((mapping[domain][k], v)) 69 | entities = self.query(domain, constraint, noisy) 70 | pointer_vector = self.one_hot_vector(len(entities), domain, pointer_vector, db_domains) 71 | entropy = self.calc_slot_entropy(entities, domain, entropy, requestable) 72 | 73 | return pointer_vector, entropy 74 | 75 | def one_hot_vector(self, num, domain, vector, db_domains): 76 | """Return number of available entities for particular domain.""" 77 | if domain != 'train': 78 | idx = db_domains.index(domain) 79 | if num == 0: 80 | vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0, 0]) 81 | elif num == 1: 82 | vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0]) 83 | elif num == 2: 84 | vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0]) 85 | elif num == 3: 86 | vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0]) 87 | elif num == 4: 88 | vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0]) 89 | elif num >= 5: 90 | vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1]) 91 | else: 92 | idx = db_domains.index(domain) 93 | if num == 0: 94 | vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0, 0]) 95 | elif num <= 2: 96 | vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0]) 97 | elif num <= 5: 98 | vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0]) 99 | elif num <= 10: 100 | vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0]) 101 | elif num <= 40: 102 | vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0]) 103 | elif num > 40: 104 | vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1]) 105 | 106 | return vector 107 | 108 | def calc_slot_entropy(self, entities, domain, vector, requestable): 109 | """Calculate entropy of requestable slot values in results""" 110 | N = len(entities) 111 | if not N: 112 | return vector 113 | 114 | # Count the values 115 | value_probabilities = {} 116 | for index, 
entity in enumerate(entities): 117 | if index == 0: 118 | for key, value in entity.items(): 119 | if key in self.cfg.map_inverse[domain] and \ 120 | domain+'-'+self.cfg.map_inverse[domain][key] in requestable: 121 | value_probabilities[key] = {value:1} 122 | else: 123 | for key, value in entity.items(): 124 | if key in value_probabilities: 125 | if value not in value_probabilities[key]: 126 | value_probabilities[key][value] = 1 127 | else: 128 | value_probabilities[key][value] += 1 129 | 130 | # Calculate entropies 131 | for key in value_probabilities: 132 | entropy = 0 133 | for count in value_probabilities[key].values(): 134 | entropy -= count/N * np.log(count/N) 135 | vector[requestable.index(domain+'-'+self.cfg.map_inverse[domain][key])] = entropy 136 | 137 | return vector 138 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | from copy import deepcopy 4 | 5 | from dbquery import DBQuery 6 | 7 | informable = \ 8 | {'attraction': ['area', 'name', 'type'], 9 | 'restaurant': ['addr', 'day', 'food', 'name', 'people', 'price', 'time'], 10 | 'train': ['day', 'people', 'arrive', 'leave', 'depart', 'dest'], 11 | 'hotel': ['area', 'day', 'internet', 'name', 'parking', 'people', 'price', 'stars', 'stay', 'type'], 12 | 'taxi': ['arrive', 'leave', 'depart', 'dest'], 13 | 'hospital': ['department'], 14 | 'police': []} 15 | 16 | requestable = \ 17 | {'attraction': ['post', 'phone', 'addr', 'fee', 'area', 'type'], 18 | 'restaurant': ['addr', 'phone', 'post', 'price', 'area', 'food'], 19 | 'train': ['ticket', 'time', 'id', 'arrive', 'leave'], 20 | 'hotel': ['addr', 'post', 'phone', 'price', 'internet', 'parking', 'area', 'type', 'stars'], 21 | 'taxi': ['car', 'phone'], 22 | 'hospital': ['phone'], 23 | 'police': ['addr', 'post']} 24 | 25 | time_re = re.compile(r'^(([01]\d|2[0-3]):([0-5]\d)|24:00)$') 26 | NUL_VALUE = ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"] 27 | 28 | class MultiWozEvaluator(): 29 | def __init__(self, data_dir): 30 | self.sys_da_array = [] 31 | self.usr_da_array = [] 32 | self.goal = {} 33 | self.booked = {} 34 | self.cur_domain = '' 35 | self.complete_domain = [] 36 | from config import MultiWozConfig 37 | cfg = MultiWozConfig() 38 | self.belief_domains = cfg.belief_domains 39 | self.mapping = cfg.mapping 40 | db = DBQuery(data_dir, cfg) 41 | self.dbs = db.dbs 42 | 43 | def _init_dict(self): 44 | dic = {} 45 | for domain in self.belief_domains: 46 | dic[domain] = {} 47 | return dic 48 | 49 | def _init_dict_booked(self): 50 | dic = {} 51 | for domain in self.belief_domains: 52 | dic[domain] = None 53 | return dic 54 | 55 | def add_goal(self, goal): 56 | """ 57 | init goal and array 58 | args: 59 | goal: dict[domain] dict[slot] value 60 | """ 61 | self.sys_da_array = [] 62 | self.usr_da_array = [] 63 | self.goal = deepcopy(goal) 64 | for domain in self.belief_domains: 65 | if 'final' in self.goal[domain]: 66 | for key in self.goal[domain]['final']: 67 | self.goal[domain][key] = self.goal[domain]['final'][key] 68 | del(self.goal[domain]['final']) 69 | self.cur_domain = '' 70 | self.complete_domain = [] 71 | self.booked = self._init_dict_booked() 72 | 73 | def add_sys_da(self, da_turn): 74 | """ 75 | add sys_da into array 76 | args: 77 | da_turn: dict[domain-intent-slot-p] value 78 | """ 79 | for da_w_p in da_turn: 80 | domain, intent, slot, p = da_w_p.split('-') 81 | value = 
str(da_turn[da_w_p]) 82 | da = '-'.join([domain, intent, slot]) 83 | self.sys_da_array.append(da+'-'+value) 84 | 85 | if value != 'none': 86 | if da == 'booking-book-ref': 87 | book_domain, ref_num = value.split('-') 88 | if not self.booked[book_domain] and re.match(r'^\d{8}$', ref_num): 89 | self.booked[book_domain] = self.dbs[book_domain][int(ref_num)] 90 | elif da == 'train-offerbooked-ref' or da == 'train-inform-ref': 91 | ref_num = value.split('-')[1] 92 | if not self.booked['train'] and re.match(r'^\d{8}$', ref_num): 93 | self.booked['train'] = self.dbs['train'][int(ref_num)] 94 | elif da == 'taxi-inform-car': 95 | if not self.booked['taxi']: 96 | self.booked['taxi'] = 'booked' 97 | 98 | def add_usr_da(self, da_turn): 99 | """ 100 | add usr_da into array 101 | args: 102 | da_turn: dict[domain-intent-slot] value 103 | """ 104 | for da in da_turn: 105 | domain, intent, slot = da.split('-') 106 | value = str(da_turn[da]) 107 | self.usr_da_array.append(da+'-'+value) 108 | if domain in self.belief_domains and domain != self.cur_domain: 109 | self.cur_domain = domain 110 | 111 | def _match_rate_goal(self, goal, booked_entity, domains=None): 112 | """ 113 | judge if the selected entity meets the constraint 114 | """ 115 | if domains is None: 116 | domains = self.belief_domains 117 | score = [] 118 | for domain in domains: 119 | if 'book' in goal[domain]: 120 | tot = 0 121 | for key, value in goal[domain].items(): 122 | if value != '?': 123 | tot += 1 124 | entity = booked_entity[domain] 125 | if entity is None: 126 | score.append(0) 127 | continue 128 | if domain in ['taxi', 'hospital', 'police']: 129 | score.append(1) 130 | continue 131 | match = 0 132 | for k, v in goal[domain].items(): 133 | if v == '?': 134 | continue 135 | if k in ['dest', 'depart', 'name'] or k not in self.mapping[domain]: 136 | tot -= 1 137 | elif k == 'leave': 138 | try: 139 | v_constraint = int(v.split(':')[0]) * 100 + int(v.split(':')[1]) 140 | v_select = int(entity['leaveAt'].split(':')[0]) * 100 + int(entity['leaveAt'].split(':')[1]) 141 | if v_constraint <= v_select: 142 | match += 1 143 | except (ValueError, IndexError): 144 | match += 1 145 | elif k == 'arrive': 146 | try: 147 | v_constraint = int(v.split(':')[0]) * 100 + int(v.split(':')[1]) 148 | v_select = int(entity['arriveBy'].split(':')[0]) * 100 + int(entity['arriveBy'].split(':')[1]) 149 | if v_constraint >= v_select: 150 | match += 1 151 | except (ValueError, IndexError): 152 | match += 1 153 | else: 154 | if v.strip() == entity[self.mapping[domain][k]].strip(): 155 | match += 1 156 | if tot != 0: 157 | score.append(match / tot) 158 | return score 159 | 160 | def _inform_F1_goal(self, goal, sys_history, domains=None): 161 | """ 162 | judge if all the requested information is answered 163 | """ 164 | if domains is None: 165 | domains = self.belief_domains 166 | inform_slot = {} 167 | for domain in domains: 168 | inform_slot[domain] = set() 169 | TP, FP, FN = 0, 0, 0 170 | for da in sys_history: 171 | domain, intent, slot, value = da.split('-', 3) 172 | if intent in ['inform', 'recommend', 'offerbook', 'offerbooked'] and \ 173 | domain in domains and value.strip() not in NUL_VALUE: 174 | inform_slot[domain].add(slot) 175 | for domain in domains: 176 | for k, v in goal[domain].items(): 177 | if v == '?': 178 | if k in inform_slot[domain]: 179 | TP += 1 180 | else: 181 | FN += 1 182 | for k in inform_slot[domain]: 183 | # exclude slots that are informed by users 184 | if k not in goal[domain] \ 185 | and (k in requestable[domain] or k == 'ref'): 186 | FP 
+= 1 187 | return TP, FP, FN 188 | 189 | def _inform_F1_goal_usr(self, goal, usr_history, domains=None): 190 | """ 191 | judge if all the constraint/request information is expressed 192 | """ 193 | if domains is None: 194 | domains = self.belief_domains 195 | inform_slot = {} 196 | request_slot = {} 197 | for domain in domains: 198 | inform_slot[domain] = set() 199 | request_slot[domain] = set() 200 | TP, FP, FN = 0, 0, 0 201 | for da in usr_history: 202 | domain, intent, slot, value = da.split('-', 3) 203 | if intent == 'inform': 204 | inform_slot[domain].add(slot) 205 | elif intent == 'request': 206 | request_slot[domain].add(slot) 207 | for domain in domains: 208 | for k, v in goal[domain].items(): 209 | if v == '?': 210 | if k in request_slot[domain]: 211 | TP += 1 212 | else: 213 | FN += 1 214 | else: 215 | if k in inform_slot[domain]: 216 | TP += 1 217 | else: 218 | FN += 1 219 | for k in inform_slot[domain]: 220 | if k not in goal[domain] \ 221 | and k in informable[domain]: 222 | FP += 1 223 | for k in request_slot[domain]: 224 | if k not in goal[domain] \ 225 | and (k in requestable[domain] or k == 'ref'): 226 | FP += 1 227 | return TP, FP, FN 228 | 229 | def _check_value(self, key, value): 230 | if key == "area": 231 | return value.lower() in ["centre", "east", "south", "west", "north"] 232 | elif key == "arriveBy" or key == "leaveAt": 233 | return time_re.match(value) 234 | elif key == "day": 235 | return value.lower() in ["monday", "tuesday", "wednesday", "thursday", "friday", 236 | "saturday", "sunday"] 237 | elif key == "duration": 238 | return 'minute' in value 239 | elif key == "internet" or key == "parking": 240 | return value in ["yes", "no"] 241 | elif key == "phone": 242 | return re.match(r'^\d{11}$', value) 243 | elif key == "price" or key == "entrance fee": 244 | return 'pound' in value or value in ["free", "?"] 245 | elif key == "pricerange": 246 | return value in ["cheap", "expensive", "moderate", "free"] 247 | elif key == "postcode": 248 | return re.match(r'^cb\d{2,3}[a-z]{2}$', value) 249 | elif key == "stars": 250 | return re.match(r'^\d$', value) 251 | elif key == "trainID": 252 | return re.match(r'^tr\d{4}$', value.lower()) 253 | else: 254 | return True 255 | 256 | def match_rate(self, ref2goal=True, aggregate=True): 257 | if ref2goal: 258 | goal = self.goal 259 | else: 260 | goal = self._init_dict() 261 | for domain in self.belief_domains: 262 | if domain in self.goal and 'book' in self.goal[domain]: 263 | goal[domain]['book'] = True 264 | for da in self.usr_da_array: 265 | d, i, s, v = da.split('-', 3) 266 | if d in self.belief_domains and i == 'inform'\ 267 | and s in informable[d]: 268 | goal[d][s] = v 269 | score = self._match_rate_goal(goal, self.booked) 270 | if aggregate: 271 | return np.mean(score) if score else None 272 | else: 273 | return score 274 | 275 | def inform_F1(self, ref2goal=True, ansbysys=True, aggregate=True): 276 | if ref2goal: 277 | goal = self.goal 278 | else: 279 | goal = self._init_dict() 280 | for da in self.usr_da_array: 281 | d, i, s, v = da.split('-', 3) 282 | if d in self.belief_domains and s in informable[d]: 283 | if i == 'inform': 284 | goal[d][s] = v 285 | elif i == 'request': 286 | goal[d][s] = '?' 
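# --- Illustrative aside (not part of evaluator.py): how the TP/FP/FN counts
# computed below turn into the (prec, rec, F1) triple this method returns,
# using hypothetical counts TP=3, FP=1, FN=1:
#
#   rec  = TP / (TP + FN) = 3 / 4 = 0.75   # requested slots that were answered
#   prec = TP / (TP + FP) = 3 / 4 = 0.75   # informed slots that were requested
#   F1   = 2 * prec * rec / (prec + rec) = 0.75
#
# task_success() further below checks inform_F1()[1], i.e. it requires recall == 1
# (every requested slot answered) alongside a booking match rate of 1.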
287 | if ansbysys: 288 | TP, FP, FN = self._inform_F1_goal(goal, self.sys_da_array) 289 | else: 290 | TP, FP, FN = self._inform_F1_goal_usr(goal, self.usr_da_array) 291 | if aggregate: 292 | try: 293 | rec = TP / (TP + FN) 294 | except ZeroDivisionError: 295 | return None, None, None 296 | try: 297 | prec = TP / (TP + FP) 298 | F1 = 2 * prec * rec / (prec + rec) 299 | except ZeroDivisionError: 300 | return 0, rec, 0 301 | return prec, rec, F1 302 | else: 303 | return [TP, FP, FN] 304 | 305 | def task_success(self, ref2goal=True): 306 | """ 307 | judge if all the domains are successfully completed 308 | """ 309 | book_sess = self.match_rate(ref2goal) 310 | inform_sess = self.inform_F1(ref2goal) 311 | # book rate == 1 & inform recall == 1 312 | if (book_sess == 1 and inform_sess[1] == 1) \ 313 | or (book_sess == 1 and inform_sess[1] is None) \ 314 | or (book_sess is None and inform_sess[1] == 1): 315 | return 1 316 | else: 317 | return 0 318 | 319 | def domain_success(self, domain, ref2goal=True): 320 | """ 321 | judge if the domain (subtask) is successfully completed 322 | """ 323 | if domain not in self.goal: 324 | return None 325 | if domain in self.complete_domain: 326 | return 0 327 | 328 | if ref2goal: 329 | goal = {} 330 | goal[domain] = deepcopy(self.goal[domain]) 331 | else: 332 | goal = self._init_dict() 333 | if 'book' in self.goal[domain]: 334 | goal[domain]['book'] = self.goal[domain]['book'] 335 | for da in self.usr_da_array: 336 | d, i, s, v = da.split('-', 3) 337 | if d != domain: 338 | continue 339 | if s in self.mapping[d]: 340 | if i == 'inform': 341 | goal[d][s] = v 342 | elif i == 'request': 343 | goal[d][s] = '?' 344 | 345 | match_rate = self._match_rate_goal(goal, self.booked, [domain]) 346 | match_rate = np.mean(match_rate) if match_rate else None 347 | 348 | inform = self._inform_F1_goal(goal, self.sys_da_array, [domain]) 349 | try: 350 | inform_rec = inform[0] / (inform[0] + inform[2]) 351 | except ZeroDivisionError: 352 | inform_rec = None 353 | 354 | if (match_rate == 1 and inform_rec == 1) \ 355 | or (match_rate == 1 and inform_rec is None) \ 356 | or (match_rate is None and inform_rec == 1): 357 | self.complete_domain.append(domain) 358 | return 1 359 | else: 360 | return 0 361 | -------------------------------------------------------------------------------- /fetch_data.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | wget --directory-prefix=data/ https://drive.google.com/open?id=1S2RXrXwsajrdzyyvM0ca_BLfGdb0PBgD 3 | unzip data/MADPLdata.zip -d data/ 4 | -------------------------------------------------------------------------------- /hybridv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: ryuichi takanobu 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | class HybridValue(nn.Module): 10 | def __init__(self, cfg): 11 | super(HybridValue, self).__init__() 12 | 13 | self.net_sys_s = nn.Sequential(nn.Linear(cfg.s_dim, cfg.hs_dim), 14 | nn.ReLU(), 15 | nn.Linear(cfg.hs_dim, cfg.hs_dim), 16 | nn.Tanh()) 17 | self.net_usr_s = nn.Sequential(nn.Linear(cfg.s_dim_usr, cfg.hs_dim), 18 | nn.ReLU(), 19 | nn.Linear(cfg.hs_dim, cfg.hs_dim), 20 | nn.Tanh()) 21 | 22 | self.net_sys = nn.Sequential(nn.Linear(cfg.hs_dim, cfg.h_dim), 23 | nn.ReLU(), 24 | nn.Linear(cfg.h_dim, 1)) 25 | self.net_usr = nn.Sequential(nn.Linear(cfg.hs_dim, cfg.h_dim), 26 | nn.ReLU(), 27 | nn.Linear(cfg.h_dim, 1)) 28 | self.net_global = 
nn.Sequential(nn.Linear(cfg.hs_dim+cfg.hs_dim, cfg.h_dim), 29 | nn.ReLU(), 30 | nn.Linear(cfg.h_dim, 1)) 31 | 32 | def forward(self, s, character): 33 | if character == 'sys': 34 | h_s_sys = self.net_sys_s(s) 35 | v = self.net_sys(h_s_sys) 36 | elif character == 'usr': 37 | h_s_usr = self.net_usr_s(s) 38 | v = self.net_usr(h_s_usr) 39 | elif character == 'global': 40 | h_s_usr = self.net_usr_s(s[0]) 41 | h_s_sys = self.net_sys_s(s[1]) 42 | h = torch.cat([h_s_usr, h_s_sys], -1) 43 | v = self.net_global(h) 44 | else: 45 | raise NotImplementedError('Unknown character {}'.format(character)) 46 | return v.squeeze(-1) 47 | -------------------------------------------------------------------------------- /learner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: ryuichi takanobu 4 | """ 5 | 6 | import os 7 | import pickle 8 | import torch 9 | import torch.nn as nn 10 | import logging 11 | import random 12 | import numpy as np 13 | from torch import optim 14 | from policy import MultiDiscretePolicy 15 | from utils import state_vectorize, state_vectorize_user 16 | from hybridv import HybridValue 17 | from torch import multiprocessing as mp 18 | from collections import namedtuple 19 | from torch.utils.tensorboard import SummaryWriter 20 | 21 | try: 22 | mp = mp.get_context('spawn') 23 | except RuntimeError: 24 | pass 25 | 26 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | 28 | Transition = namedtuple('Transition', ('state_usr', 'action_usr', 'reward_usr', 'state_usr_next', \ 29 | 'state_sys', 'action_sys', 'reward_sys', 'state_sys_next', \ 30 | 'mask', 'reward_global')) 31 | 32 | class Memory(object): 33 | 34 | def __init__(self): 35 | self.memory = [] 36 | 37 | def push(self, *args): 38 | """Saves a transition.""" 39 | self.memory.append(Transition(*args)) 40 | 41 | def get_batch(self, batch_size=None): 42 | if batch_size is None: 43 | return Transition(*zip(*self.memory)) 44 | else: 45 | random_batch = random.sample(self.memory, batch_size) 46 | return Transition(*zip(*random_batch)) 47 | 48 | def append(self, new_memory): 49 | self.memory += new_memory.memory 50 | 51 | def __len__(self): 52 | return len(self.memory) 53 | 54 | 55 | def sampler(pid, queue, evt, env, policy_usr, policy_sys, batchsz): 56 | """ 57 | This is a sampler function, and it will be called by multiprocessing.Process to sample data from the environment in multiple 58 | processes. 59 | :param pid: process id 60 | :param queue: multiprocessing.Queue, to collect sampled data 61 | :param evt: multiprocessing.Event, to keep the process alive 62 | :param env: environment instance 63 | :param policy_usr, policy_sys: policy networks, to generate actions from the current policies 64 | :param batchsz: total sampled items 65 | :return: 66 | """ 67 | buff = Memory() 68 | 69 | # we need to sample batchsz of (state, action, next_state, reward, mask) 70 | # each trajectory contains `traj_len` num of items, so we only need to sample 71 | # `batchsz//traj_len` trajectories in total 72 | # the final sampled number may be larger than batchsz. 
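# --- Illustrative aside (not part of learner.py): Memory.get_batch() above relies
# on the zip-transpose idiom to turn a list of Transition tuples into a single
# Transition of field-wise tuples, which Learner.update() later stacks into tensors.
# A minimal self-contained sketch:
#
#   from collections import namedtuple
#   T = namedtuple('T', ('s', 'a'))
#   memory = [T(1, 'x'), T(2, 'y')]
#   batch = T(*zip(*memory))   # -> T(s=(1, 2), a=('x', 'y'))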
73 | 74 | sampled_num = 0 75 | sampled_traj_num = 0 76 | traj_len = 40 77 | real_traj_len = 0 78 | 79 | while sampled_num < batchsz: 80 | # for each trajectory, we reset the env and get initial state 81 | s = env.reset() 82 | 83 | for t in range(traj_len): 84 | 85 | # [s_dim_usr] => [a_dim_usr] 86 | s_vec = torch.Tensor(state_vectorize_user(s, env.cfg, env.evaluator.cur_domain)) 87 | a = policy_usr.select_action(s_vec.to(device=DEVICE)).cpu() 88 | 89 | # interact with env, done is a flag indicates ending or not 90 | next_s, done = env.step_usr(s, a) 91 | 92 | # [s_dim] => [a_dim] 93 | next_s_vec = torch.Tensor(state_vectorize(next_s, env.cfg, env.db)) 94 | next_a = policy_sys.select_action(next_s_vec.to(device=DEVICE)).cpu() 95 | 96 | # interact with env 97 | s = env.step_sys(next_s, next_a) 98 | 99 | # get reward compared to demonstrations 100 | if done: 101 | env.set_rollout(True) 102 | s_vec_next = torch.Tensor(state_vectorize_user(s, env.cfg, env.evaluator.cur_domain)) 103 | a_next = torch.zeros_like(a) 104 | next_s_next, _ = env.step_usr(s, a_next) 105 | next_s_vec_next = torch.Tensor(state_vectorize(next_s_next, env.cfg, env.db)) 106 | env.set_rollout(False) 107 | 108 | r_usr = 20 if env.evaluator.inform_F1(ansbysys=False)[1] == 1. else -5 109 | r_sys = 20 if env.evaluator.task_success(False) else -5 110 | r_global = 20 if env.evaluator.task_success() else -5 111 | else: 112 | # one step roll out 113 | env.set_rollout(True) 114 | s_vec_next = torch.Tensor(state_vectorize_user(s, env.cfg, env.evaluator.cur_domain)) 115 | a_next = policy_usr.select_action(s_vec_next.to(device=DEVICE)).cpu() 116 | next_s_next, _ = env.step_usr(s, a_next) 117 | next_s_vec_next = torch.Tensor(state_vectorize(next_s_next, env.cfg, env.db)) 118 | env.set_rollout(False) 119 | 120 | r_usr = 0 121 | if not s['user_action']: 122 | r_usr -= 5 123 | if env.evaluator.cur_domain: 124 | for da in s['user_action']: 125 | d, i, k = da.split('-') 126 | if i == 'request': 127 | for slot, value in s['goal_state'][d].items(): 128 | if value != '?' and slot in s['user_goal'][d]\ 129 | and s['user_goal'][d][slot] != '?': 130 | # request before express constraint 131 | r_usr -= 1 132 | r_sys = 0 133 | if not next_s['sys_action']: 134 | r_sys -= 5 135 | if env.evaluator.cur_domain: 136 | for slot, value in next_s['belief_state'][env.evaluator.cur_domain].items(): 137 | if value == '?': 138 | for da in next_s['sys_action']: 139 | d, i, k, p = da.split('-') 140 | if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and k == slot: 141 | break 142 | else: 143 | # not answer request 144 | r_sys -= 1 145 | r_global = 5 if env.evaluator.cur_domain and env.evaluator.domain_success(env.evaluator.cur_domain) else -1 146 | 147 | # save to queue 148 | buff.push(s_vec.numpy(), a.numpy(), r_usr, s_vec_next.numpy(), next_s_vec.numpy(), next_a.numpy(), r_sys, next_s_vec_next.numpy(), done, r_global) 149 | 150 | # update per step 151 | real_traj_len = t 152 | 153 | if done: 154 | break 155 | 156 | # this is end of one trajectory 157 | sampled_num += real_traj_len 158 | sampled_traj_num += 1 159 | # t indicates the valid trajectory length 160 | 161 | # this is end of sampling all batchsz of items. 
162 | # when sampling is over, push all buff data into queue 163 | queue.put([pid, buff]) 164 | evt.wait() 165 | 166 | class Learner(): 167 | 168 | def __init__(self, env_cls, args, cfg, process_num, infer=False): 169 | self.policy_sys = MultiDiscretePolicy(cfg).to(device=DEVICE) 170 | self.policy_usr = MultiDiscretePolicy(cfg, 'usr').to(device=DEVICE) 171 | self.vnet = HybridValue(cfg).to(device=DEVICE) 172 | 173 | # initialize envs for each process 174 | self.env_list = [] 175 | for _ in range(process_num): 176 | self.env_list.append(env_cls(args.data_dir, cfg)) 177 | 178 | self.policy_sys.eval() 179 | self.policy_usr.eval() 180 | self.vnet.eval() 181 | self.infer = infer 182 | 183 | if not infer: 184 | self.l2_loss = nn.MSELoss() 185 | self.multi_entropy_loss = nn.BCEWithLogitsLoss() 186 | self.target_vnet = HybridValue(cfg).to(device=DEVICE) 187 | self.episode_num = 0 188 | self.last_target_update_episode = 0 189 | self.target_update_interval = args.interval 190 | 191 | self.policy_sys_optim = optim.RMSprop(self.policy_sys.parameters(), lr=args.lr_policy) 192 | self.policy_usr_optim = optim.RMSprop(self.policy_usr.parameters(), lr=args.lr_policy) 193 | self.vnet_optim = optim.RMSprop(self.vnet.parameters(), lr=args.lr_vnet, weight_decay=args.weight_decay) 194 | 195 | self.gamma = args.gamma 196 | self.grad_norm_clip = args.clip 197 | self.optim_batchsz = args.batchsz 198 | self.save_per_epoch = args.save_per_epoch 199 | self.save_dir = args.save_dir 200 | self.process_num = process_num 201 | self.writer = SummaryWriter() 202 | 203 | def _update_targets(self): 204 | self.target_vnet.load_state_dict(self.vnet.state_dict()) 205 | logging.info('Updated target network') 206 | 207 | def evaluate(self, N): 208 | logging.info('eval: user 2 system') 209 | env = self.env_list[0] 210 | traj_len = 40 211 | turn_tot, inform_tot, match_tot, success_tot = [], [], [], [] 212 | for seed in range(N): 213 | s = env.reset(seed) 214 | print('seed', seed) 215 | print('origin goal', env.goal) 216 | print('goal', env.evaluator.goal) 217 | for t in range(traj_len): 218 | s_vec = torch.Tensor(state_vectorize_user(s, env.cfg, env.evaluator.cur_domain)).to(device=DEVICE) 219 | # mode with policy during evaluation 220 | a = self.policy_usr.select_action(s_vec, False) 221 | next_s, done = env.step_usr(s, a) 222 | 223 | next_s_vec = torch.Tensor(state_vectorize(next_s, env.cfg, env.db)).to(device=DEVICE) 224 | next_a = self.policy_sys.select_action(next_s_vec, False) 225 | s = env.step_sys(next_s, next_a) 226 | 227 | print('usr', s['user_action']) 228 | print('sys', s['sys_action']) 229 | 230 | if done: 231 | break 232 | 233 | turn_tot.append(env.time_step//2) 234 | match_tot += env.evaluator.match_rate(aggregate=False) 235 | inform_tot.append(env.evaluator.inform_F1(aggregate=False)) 236 | print('turn', env.time_step//2) 237 | match_session = env.evaluator.match_rate() 238 | print('match', match_session) 239 | inform_session = env.evaluator.inform_F1() 240 | print('inform', inform_session) 241 | if (match_session == 1 and inform_session[1] == 1) \ 242 | or (match_session == 1 and inform_session[1] is None) \ 243 | or (match_session is None and inform_session[1] == 1): 244 | print('success', 1) 245 | success_tot.append(1) 246 | else: 247 | print('success', 0) 248 | success_tot.append(0) 249 | 250 | logging.info('turn {}'.format(np.mean(turn_tot))) 251 | logging.info('match {}'.format(np.mean(match_tot))) 252 | TP, FP, FN = np.sum(inform_tot, 0) 253 | prec = TP / (TP + FP) 254 | rec = TP / (TP + FN) 255 | F1 = 2 * 
prec * rec / (prec + rec) 256 | logging.info('inform rec {}, F1 {}'.format(rec, F1)) 257 | logging.info('success {}'.format(np.mean(success_tot))) 258 | 259 | def evaluate_with_agenda(self, env, N): 260 | logging.info('eval: agenda 2 system') 261 | traj_len = 40 262 | turn_tot, inform_tot, match_tot, success_tot = [], [], [], [] 263 | for seed in range(N): 264 | s = env.reset(seed) 265 | print('seed', seed) 266 | print('goal', env.goal.domain_goals) 267 | print('usr', s['user_action']) 268 | for t in range(traj_len): 269 | s_vec = torch.Tensor(state_vectorize(s, env.cfg, env.db)).to(device=DEVICE) 270 | # mode with policy during evaluation 271 | a = self.policy_sys.select_action(s_vec, False) 272 | next_s, done = env.step(s, a.cpu()) 273 | s = next_s 274 | print('sys', s['sys_action']) 275 | print('usr', s['user_action']) 276 | if done: 277 | break 278 | s_vec = torch.Tensor(state_vectorize(s, env.cfg, env.db)).to(device=DEVICE) 279 | # mode with policy during evaluation 280 | a = self.policy_sys.select_action(s_vec, False) 281 | s = env.update_belief_sys(s, a.cpu()) 282 | print('sys', s['sys_action']) 283 | 284 | assert(env.time_step % 2 == 0) 285 | turn_tot.append(env.time_step//2) 286 | match_tot += env.evaluator.match_rate(aggregate=False) 287 | inform_tot.append(env.evaluator.inform_F1(aggregate=False)) 288 | print('turn', env.time_step//2) 289 | match_session = env.evaluator.match_rate() 290 | print('match', match_session) 291 | inform_session = env.evaluator.inform_F1() 292 | print('inform', inform_session) 293 | if (match_session == 1 and inform_session[1] == 1) \ 294 | or (match_session == 1 and inform_session[1] is None) \ 295 | or (match_session is None and inform_session[1] == 1): 296 | print('success', 1) 297 | success_tot.append(1) 298 | else: 299 | print('success', 0) 300 | success_tot.append(0) 301 | 302 | logging.info('turn {}'.format(np.mean(turn_tot))) 303 | logging.info('match {}'.format(np.mean(match_tot))) 304 | TP, FP, FN = np.sum(inform_tot, 0) 305 | prec = TP / (TP + FP) 306 | rec = TP / (TP + FN) 307 | F1 = 2 * prec * rec / (prec + rec) 308 | logging.info('inform rec {}, F1 {}'.format(rec, F1)) 309 | logging.info('success {}'.format(np.mean(success_tot))) 310 | 311 | 312 | def evaluate_with_rule(self, env, N): 313 | logging.info('eval: user 2 rule') 314 | traj_len = 40 315 | turn_tot, inform_tot, match_tot, success_tot = [], [], [], [] 316 | for seed in range(N): 317 | s = env.reset(seed) 318 | print('seed', seed) 319 | print('goal', env.evaluator.goal) 320 | for t in range(traj_len): 321 | s_vec = torch.Tensor(state_vectorize_user(s, env.cfg, env.evaluator.cur_domain)).to(device=DEVICE) 322 | # mode with policy during evaluation 323 | a = self.policy_usr.select_action(s_vec, False) 324 | next_s = env.step(s, a.cpu()) 325 | s = next_s 326 | print('usr', s['user_action']) 327 | print('sys', s['sys_action']) 328 | done = s['others']['terminal'] 329 | if done: 330 | break 331 | 332 | assert(env.time_step % 2 == 0) 333 | turn_tot.append(env.time_step//2) 334 | match_tot += env.evaluator.match_rate(aggregate=False) 335 | inform_tot.append(env.evaluator.inform_F1(aggregate=False)) 336 | print('turn', env.time_step//2) 337 | match_session = env.evaluator.match_rate() 338 | print('match', match_session) 339 | inform_session = env.evaluator.inform_F1() 340 | print('inform', inform_session) 341 | if (match_session == 1 and inform_session[1] == 1) \ 342 | or (match_session == 1 and inform_session[1] is None) \ 343 | or (match_session is None and inform_session[1] == 1): 344 
| print('success', 1) 345 | success_tot.append(1) 346 | else: 347 | print('success', 0) 348 | success_tot.append(0) 349 | 350 | logging.info('turn {}'.format(np.mean(turn_tot))) 351 | logging.info('match {}'.format(np.mean(match_tot))) 352 | TP, FP, FN = np.sum(inform_tot, 0) 353 | prec = TP / (TP + FP) 354 | rec = TP / (TP + FN) 355 | F1 = 2 * prec * rec / (prec + rec) 356 | logging.info('inform rec {}, F1 {}'.format(rec, F1)) 357 | logging.info('success {}'.format(np.mean(success_tot))) 358 | 359 | def save(self, directory, epoch): 360 | if not os.path.exists(directory): 361 | os.makedirs(directory) 362 | os.makedirs(directory + '/usr') 363 | os.makedirs(directory + '/sys') 364 | os.makedirs(directory + '/vnet') 365 | 366 | torch.save(self.policy_usr.state_dict(), directory + '/usr/' + str(epoch) + '_pol.mdl') 367 | torch.save(self.policy_sys.state_dict(), directory + '/sys/' + str(epoch) + '_pol.mdl') 368 | torch.save(self.vnet.state_dict(), directory + '/vnet/' + str(epoch) + '_vnet.mdl') 369 | 370 | logging.info('<> epoch {}: saved network to mdl'.format(epoch)) 371 | 372 | def load(self, filename): 373 | 374 | directory, epoch = filename.rsplit('/', 1) 375 | 376 | policy_usr_mdl = directory + '/usr/' + epoch + '_pol.mdl' 377 | if os.path.exists(policy_usr_mdl): 378 | self.policy_usr.load_state_dict(torch.load(policy_usr_mdl)) 379 | logging.info('<> loaded checkpoint from file: {}'.format(policy_usr_mdl)) 380 | 381 | policy_sys_mdl = directory + '/sys/' + epoch + '_pol.mdl' 382 | if os.path.exists(policy_sys_mdl): 383 | self.policy_sys.load_state_dict(torch.load(policy_sys_mdl)) 384 | logging.info('<> loaded checkpoint from file: {}'.format(policy_sys_mdl)) 385 | 386 | if not self.infer: 387 | self._update_targets() 388 | 389 | best_pkl = filename + '.pkl' 390 | if os.path.exists(best_pkl): 391 | with open(best_pkl, 'rb') as f: 392 | best = pickle.load(f) 393 | else: 394 | best = float('-inf') 395 | return best 396 | 397 | def sample(self, batchsz): 398 | """ 399 | Given a total of batchsz samples, the work is split equally among the processes, 400 | and when the processes return, all sampled data are merged and returned 401 | :param batchsz: 402 | :return: batch 403 | """ 404 | 405 | # batchsz will be split into each process, 406 | # the final batchsz may be larger than the batchsz parameter 407 | process_batchsz = np.ceil(batchsz / self.process_num).astype(np.int32) 408 | # buffer to save all data 409 | queue = mp.Queue() 410 | 411 | # start processes for pid in range(1, processnum) 412 | # if processnum = 1, this part will be ignored. 413 | # when save tensor in Queue, the process should keep alive till Queue.get(), 414 | # please refer to : https://discuss.pytorch.org/t/using-torch-tensor-over-multiprocessing-queue-process-fails/2847 415 | # however there are still some problems with CUDA tensors on a multiprocessing queue, 416 | # please refer to : https://discuss.pytorch.org/t/cuda-tensors-on-multiprocessing-queue/28626 417 | # so just transform tensors into numpy, then put them into queue. 418 | evt = mp.Event() 419 | processes = [] 420 | for i in range(self.process_num): 421 | process_args = (i, queue, evt, self.env_list[i], self.policy_usr, self.policy_sys, process_batchsz) 422 | processes.append(mp.Process(target=sampler, args=process_args)) 423 | for p in processes: 424 | # set the process as daemon, and it will be killed once the main process is stopped. 425 | p.daemon = True 426 | p.start() 427 | 428 | # we need to get the first Memory object and then merge the other Memory objects using its append function. 
429 | pid0, buff0 = queue.get() 430 | for _ in range(1, self.process_num): 431 | pid, buff_ = queue.get() 432 | buff0.append(buff_) # merge current Memory into buff0 433 | evt.set() 434 | 435 | # now buff saves all the sampled data 436 | buff = buff0 437 | 438 | return buff.get_batch() 439 | 440 | def update(self, batchsz, epoch, best=None): 441 | """ 442 | first sample batchsz items, then run the optimization algorithms. 443 | :param batchsz: 444 | :param epoch: 445 | :param best: 446 | :return: 447 | """ 448 | backward = True if best is None else False 449 | if backward: 450 | self.policy_usr.train() 451 | self.policy_sys.train() 452 | self.vnet.train() 453 | 454 | # 1. sample data asynchronously 455 | batch = self.sample(batchsz) 456 | 457 | policy_usr_loss, policy_sys_loss, vnet_usr_loss, vnet_sys_loss, vnet_glo_loss = 0., 0., 0., 0., 0. 458 | 459 | # data in batch is : batch.state: ([1, s_dim], [1, s_dim]...) 460 | # batch.action: ([1, a_dim], [1, a_dim]...) 461 | # batch.reward/batch.mask: ([1], [1]...) 462 | s_usr = torch.from_numpy(np.stack(batch.state_usr)).to(device=DEVICE) 463 | a_usr = torch.from_numpy(np.stack(batch.action_usr)).to(device=DEVICE) 464 | r_usr = torch.Tensor(np.stack(batch.reward_usr)).to(device=DEVICE) 465 | s_usr_next = torch.from_numpy(np.stack(batch.state_usr_next)).to(device=DEVICE) 466 | s_sys = torch.from_numpy(np.stack(batch.state_sys)).to(device=DEVICE) 467 | a_sys = torch.from_numpy(np.stack(batch.action_sys)).to(device=DEVICE) 468 | r_sys = torch.Tensor(np.stack(batch.reward_sys)).to(device=DEVICE) 469 | s_sys_next = torch.from_numpy(np.stack(batch.state_sys_next)).to(device=DEVICE) 470 | terminal = torch.Tensor(np.stack(batch.mask)).to(device=DEVICE) 471 | r_glo = torch.Tensor(np.stack(batch.reward_global)).to(device=DEVICE) 472 | batchsz = s_usr.size(0) 473 | 474 | if not backward: 475 | reward = r_usr.mean().item() + r_sys.mean().item() + r_glo.mean().item() 476 | logging.debug('validation, epoch {}, reward {}'.format(epoch, reward)) 477 | self.writer.add_scalar('train/reward', reward, epoch) 478 | if reward > best: 479 | logging.info('best model saved') 480 | best = reward 481 | self.save(self.save_dir, 'best') 482 | with open(self.save_dir+'/best.pkl', 'wb') as f: 483 | pickle.dump(best, f) 484 | return best 485 | else: 486 | logging.debug('epoch {}, reward: usr {}, sys {}, global {}'.format(epoch, r_usr.mean().item(), r_sys.mean().item(), r_glo.mean().item())) 487 | 488 | # update dialog policy 489 | 490 | # 1. shuffle current batch 491 | perm = torch.randperm(batchsz) 492 | # shuffle the variables for multiple optimization passes 493 | s_usr_shuf, a_usr_shuf, r_usr_shuf, s_usr_next_shuf, s_sys_shuf, a_sys_shuf, r_sys_shuf, s_sys_next_shuf, terminal_shuf, r_glo_shuf = \ 494 | s_usr[perm], a_usr[perm], r_usr[perm], s_usr_next[perm], s_sys[perm], a_sys[perm], r_sys[perm], s_sys_next[perm], terminal[perm], r_glo[perm] 495 | 496 | # 2. 
get mini-batch for optimizing 497 | optim_chunk_num = int(np.ceil(batchsz / self.optim_batchsz)) 498 | # chunk the optim_batch for total batch 499 | s_usr_shuf, a_usr_shuf, r_usr_shuf, s_usr_next_shuf, s_sys_shuf, a_sys_shuf, r_sys_shuf, s_sys_next_shuf, terminal_shuf, r_glo_shuf = \ 500 | torch.chunk(s_usr_shuf, optim_chunk_num), torch.chunk(a_usr_shuf, optim_chunk_num), torch.chunk(r_usr_shuf, optim_chunk_num), torch.chunk(s_usr_next_shuf, optim_chunk_num),\ 501 | torch.chunk(s_sys_shuf, optim_chunk_num), torch.chunk(a_sys_shuf, optim_chunk_num), torch.chunk(r_sys_shuf, optim_chunk_num), torch.chunk(s_sys_next_shuf, optim_chunk_num),\ 502 | torch.chunk(terminal_shuf, optim_chunk_num), torch.chunk(r_glo_shuf, optim_chunk_num) 503 | 504 | # 3. iterate all mini-batch to optimize 505 | for s_usr_b, a_usr_b, r_usr_b, s_usr_next_b, s_sys_b, a_sys_b, r_sys_b, s_sys_next_b, t_b, r_glo_b in \ 506 | zip(s_usr_shuf, a_usr_shuf, r_usr_shuf, s_usr_next_shuf,\ 507 | s_sys_shuf, a_sys_shuf, r_sys_shuf, s_sys_next_shuf,\ 508 | terminal_shuf, r_glo_shuf): 509 | 510 | # 1. update critic network 511 | 512 | # update usr vnet 513 | vals_usr = self.vnet(s_usr_b, 'usr') 514 | target_usr = r_usr_b + self.gamma * (1-t_b) * self.target_vnet(s_usr_next_b, 'usr') 515 | loss_usr = self.l2_loss(vals_usr, target_usr) 516 | vnet_usr_loss += loss_usr.item() 517 | 518 | # update sys vnet 519 | vals_sys = self.vnet(s_sys_b, 'sys') 520 | target_sys = r_sys_b + self.gamma * (1-t_b) * self.target_vnet(s_sys_next_b, 'sys') 521 | loss_sys = self.l2_loss(vals_sys, target_sys) 522 | vnet_sys_loss += loss_sys.item() 523 | 524 | # update global vnet 525 | vals_glo = self.vnet((s_usr_b, s_sys_b), 'global') 526 | target_glo = r_glo_b + self.gamma * (1-t_b) * self.target_vnet((s_usr_next_b, s_sys_next_b), 'global') 527 | loss_glo = self.l2_loss(vals_glo, target_glo) 528 | vnet_glo_loss += loss_glo.item() 529 | 530 | self.vnet_optim.zero_grad() 531 | loss = loss_usr + loss_sys + loss_glo 532 | loss.backward() 533 | torch.nn.utils.clip_grad_norm_(self.vnet.parameters(), self.grad_norm_clip) 534 | self.vnet_optim.step() 535 | 536 | self.episode_num += 1 537 | if (self.episode_num - self.last_target_update_episode) / self.target_update_interval >= 1.0: 538 | self._update_targets() 539 | self.last_target_update_episode = self.episode_num 540 | 541 | # 2. update actor network 542 | 543 | # estimate advantage using current critic 544 | td_error_usr = r_usr_b + self.gamma * (1-t_b) * self.vnet(s_usr_next_b, 'usr') - self.vnet(s_usr_b, 'usr') 545 | td_error_sys = r_sys_b + self.gamma * (1-t_b) * self.vnet(s_sys_next_b, 'sys') - self.vnet(s_sys_b, 'sys') 546 | td_error_glo = r_glo_b + self.gamma * (1-t_b) * self.vnet((s_usr_next_b, s_sys_next_b), 'global') - self.vnet((s_usr_b, s_sys_b), 'global') 547 | 548 | self.policy_usr_optim.zero_grad() 549 | # [b, 1] 550 | log_pi_sa = self.policy_usr.get_log_prob(s_usr_b, a_usr_b) 551 | # this is element-wise comparing. 
552 | # we add negative symbol to convert gradient ascent to gradient descent 553 | surrogate = - (log_pi_sa * (td_error_usr + td_error_glo)).mean() 554 | policy_usr_loss += surrogate.item() 555 | 556 | # backprop 557 | surrogate.backward(retain_graph=True) 558 | # gradient clipping, for stability 559 | torch.nn.utils.clip_grad_norm_(self.policy_usr.parameters(), self.grad_norm_clip) 560 | # self.lock.acquire() # retain lock to update weights 561 | self.policy_usr_optim.step() 562 | # self.lock.release() # release lock 563 | 564 | self.policy_sys_optim.zero_grad() 565 | # [b, 1] 566 | log_pi_sa = self.policy_sys.get_log_prob(s_sys_b, a_sys_b) 567 | # this is element-wise comparing. 568 | # we add negative symbol to convert gradient ascent to gradient descent 569 | surrogate = - (log_pi_sa * (td_error_sys + td_error_glo)).mean() 570 | policy_sys_loss += surrogate.item() 571 | 572 | # backprop 573 | surrogate.backward() 574 | # gradient clipping, for stability 575 | torch.nn.utils.clip_grad_norm_(self.policy_sys.parameters(), self.grad_norm_clip) 576 | # self.lock.acquire() # retain lock to update weights 577 | self.policy_sys_optim.step() 578 | # self.lock.release() # release lock 579 | 580 | vnet_usr_loss /= optim_chunk_num 581 | vnet_sys_loss /= optim_chunk_num 582 | vnet_glo_loss /= optim_chunk_num 583 | policy_usr_loss /= optim_chunk_num 584 | policy_sys_loss /= optim_chunk_num 585 | 586 | logging.debug('epoch {}, policy: usr {}, sys {}, value network: usr {}, sys {}, global {}'.format(epoch, \ 587 | policy_usr_loss, policy_sys_loss, vnet_usr_loss, vnet_sys_loss, vnet_glo_loss)) 588 | self.writer.add_scalar('train/usr_policy_loss', policy_usr_loss, epoch) 589 | self.writer.add_scalar('train/sys_policy_loss', policy_sys_loss, epoch) 590 | self.writer.add_scalar('train/vnet_usr_loss', vnet_usr_loss, epoch) 591 | self.writer.add_scalar('train/vnet_sys_loss', vnet_sys_loss, epoch) 592 | self.writer.add_scalar('train/vnet_glo_loss', vnet_glo_loss, epoch) 593 | 594 | if (epoch+1) % self.save_per_epoch == 0: 595 | self.save(self.save_dir, epoch) 596 | with open(self.save_dir+'/'+str(epoch)+'.pkl', 'wb') as f: 597 | pickle.dump(best, f) 598 | self.policy_usr.eval() 599 | self.policy_sys.eval() 600 | self.vnet.eval() 601 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: ryuichi takanobu 4 | """ 5 | import sys 6 | import time 7 | import logging 8 | from utils import get_parser, init_logging_handler 9 | from datamanager import DataManager 10 | from config import MultiWozConfig 11 | from torch import multiprocessing as mp 12 | from policy import Policy 13 | from learner import Learner 14 | from controller import Controller 15 | from agenda import UserAgenda 16 | from rule import SystemRule 17 | 18 | def worker_policy_sys(args, manager, config): 19 | init_logging_handler(args.log_dir, '_policy_sys') 20 | agent = Policy(None, args, manager, config, 0, 'sys', True) 21 | 22 | best = float('inf') 23 | for e in range(args.epoch): 24 | agent.imitating(e) 25 | best = agent.imit_test(e, best) 26 | 27 | def worker_policy_usr(args, manager, config): 28 | init_logging_handler(args.log_dir, '_policy_usr') 29 | agent = Policy(None, args, manager, config, 0, 'usr', True) 30 | 31 | best = float('inf') 32 | for e in range(args.epoch): 33 | agent.imitating(e) 34 | best = agent.imit_test(e, best) 35 | 36 | def make_env(data_dir, config): 37 | 
controller = Controller(data_dir, config) 38 | return controller 39 | 40 | def make_env_rule(data_dir, config): 41 | env = SystemRule(data_dir, config) 42 | return env 43 | 44 | def make_env_agenda(data_dir, config): 45 | env = UserAgenda(data_dir, config) 46 | return env 47 | 48 | if __name__ == '__main__': 49 | parser = get_parser() 50 | argv = sys.argv[1:] 51 | args, _ = parser.parse_known_args(argv) 52 | 53 | if args.config == 'multiwoz': 54 | config = MultiWozConfig() 55 | else: 56 | raise NotImplementedError('Config of the dataset {} not implemented'.format(args.config)) 57 | 58 | init_logging_handler(args.log_dir) 59 | logging.debug(str(args)) 60 | 61 | try: 62 | mp = mp.get_context('spawn') 63 | except RuntimeError: 64 | pass 65 | 66 | if args.pretrain: 67 | logging.debug('pretrain') 68 | 69 | manager = DataManager(args.data_dir, config) 70 | processes = [] 71 | process_args = (args, manager, config) 72 | processes.append(mp.Process(target=worker_policy_sys, args=process_args)) 73 | processes.append(mp.Process(target=worker_policy_usr, args=process_args)) 74 | for p in processes: 75 | p.start() 76 | 77 | for p in processes: 78 | p.join() 79 | 80 | elif args.test: 81 | logging.debug('test') 82 | logging.disable(logging.DEBUG) 83 | 84 | agent = Learner(make_env, args, config, 1, infer=True) 85 | agent.load(args.load) 86 | agent.evaluate(args.test_case) 87 | 88 | # test system policy with agenda 89 | env = make_env_agenda(args.data_dir, config) 90 | agent.evaluate_with_agenda(env, args.test_case) 91 | 92 | # test user policy with rule 93 | env = make_env_rule(args.data_dir, config) 94 | agent.evaluate_with_rule(env, args.test_case) 95 | 96 | else: # training 97 | current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 98 | logging.debug('train {}'.format(current_time)) 99 | 100 | agent = Learner(make_env, args, config, args.process) 101 | best = agent.load(args.load) 102 | 103 | for i in range(args.epoch): 104 | agent.update(args.batchsz_traj, i) 105 | # validation 106 | best = agent.update(args.batchsz, i, best) 107 | current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 108 | logging.debug('epoch {} {}'.format(i, current_time)) 109 | -------------------------------------------------------------------------------- /policy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: ryuichi takanobu 4 | """ 5 | import os 6 | import logging 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch import optim 11 | 12 | from utils import to_device 13 | from evaluator import MultiWozEvaluator 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | class MultiDiscretePolicy(nn.Module): 19 | def __init__(self, cfg, character='sys'): 20 | super(MultiDiscretePolicy, self).__init__() 21 | 22 | if character == 'sys': 23 | self.net = nn.Sequential(nn.Linear(cfg.s_dim, cfg.h_dim), 24 | nn.ReLU(), 25 | nn.Linear(cfg.h_dim, cfg.h_dim), 26 | nn.ReLU(), 27 | nn.Linear(cfg.h_dim, cfg.a_dim)) 28 | elif character == 'usr': 29 | self.net = nn.Sequential(nn.Linear(cfg.s_dim_usr, cfg.h_dim), 30 | nn.ReLU(), 31 | nn.Linear(cfg.h_dim, cfg.h_dim), 32 | nn.ReLU(), 33 | nn.Linear(cfg.h_dim, cfg.a_dim_usr)) 34 | else: 35 | raise NotImplementedError('Unknown character {}'.format(character)) 36 | 37 | def forward(self, s): 38 | # [b, s_dim] => [b, a_dim] 39 | a_weights = self.net(s) 40 | 41 | return a_weights 42 | 43 | def 
select_action(self, s, sample=True):
44 |         """
45 |         :param s: [s_dim]
46 |         :return: [a_dim]
47 |         """
48 |         # forward to get action probs
49 |         # [s_dim] => [a_dim]
50 |         a_weights = self.forward(s)
51 |         a_probs = torch.sigmoid(a_weights)
52 | 
53 |         # [a_dim] => [a_dim, 2]
54 |         a_probs = a_probs.unsqueeze(1)
55 |         a_probs = torch.cat([1-a_probs, a_probs], 1)
56 | 
57 |         # [a_dim, 2] => [a_dim]
58 |         a = a_probs.multinomial(1).squeeze(1) if sample else a_probs.argmax(1)
59 | 
60 |         return a
61 | 
62 |     def batch_select_action(self, s, sample=False):
63 |         """
64 |         :param s: [b, s_dim]
65 |         :return: [b, a_dim]
66 |         """
67 |         # forward to get action probs
68 |         # [b, s_dim] => [b, a_dim]
69 |         a_weights = self.forward(s)
70 |         a_probs = torch.sigmoid(a_weights)
71 | 
72 |         # [b, a_dim] => [b, a_dim, 2]
73 |         a_probs = a_probs.unsqueeze(2)
74 |         a_probs = torch.cat([1-a_probs, a_probs], 2)
75 | 
76 |         # [b, a_dim, 2] => [b*a_dim, 2] => [b*a_dim, 1] => [b*a_dim] => [b, a_dim]
77 |         a = a_probs.reshape(-1, 2).multinomial(1).squeeze(1).reshape(a_weights.shape) if sample else a_probs.argmax(2)
78 | 
79 |         return a
80 | 
81 |     def get_log_prob(self, s, a):
82 |         """
83 |         :param s: [b, s_dim]
84 |         :param a: [b, a_dim]
85 |         :return: [b, 1]
86 |         """
87 |         # forward to get action probs
88 |         # [b, s_dim] => [b, a_dim]
89 |         a_weights = self.forward(s)
90 |         a_probs = torch.sigmoid(a_weights)
91 | 
92 |         # [b, a_dim] => [b, a_dim, 2]
93 |         a_probs = a_probs.unsqueeze(-1)
94 |         a_probs = torch.cat([1-a_probs, a_probs], -1)
95 | 
96 |         # [b, a_dim, 2] => [b, a_dim]
97 |         trg_a_probs = a_probs.gather(-1, a.unsqueeze(-1)).squeeze(-1)
98 |         log_prob = torch.log(trg_a_probs)
99 | 
100 |         return log_prob.sum(-1, keepdim=True)
101 | 
102 | 
103 | class Policy(object):
104 |     def __init__(self, env_cls, args, manager, cfg, process_num, character, pre=False, infer=False):
105 |         """
106 |         :param env_cls: env class or function, not instance, as we need to create several instances in the class.
107 |         :param args:
108 |         :param manager:
109 |         :param cfg:
110 |         :param process_num: process number
111 |         :param character: user or system
112 |         :param pre: set to pretrain mode
113 |         :param infer: set to test mode
114 |         """
115 | 
116 |         self.process_num = process_num
117 |         self.character = character
118 | 
119 |         # initialize envs for each process
120 |         self.env_list = []
121 |         for _ in range(process_num):
122 |             self.env_list.append(env_cls())
123 | 
124 |         # construct policy and value network
125 |         self.policy = MultiDiscretePolicy(cfg, character).to(device=DEVICE)
126 | 
127 |         if pre:
128 |             self.print_per_batch = args.print_per_batch
129 |             from dbquery import DBQuery
130 |             db = DBQuery(args.data_dir, cfg)
131 |             self.data_train = manager.create_dataset_policy('train', args.batchsz, cfg, db, character)
132 |             self.data_valid = manager.create_dataset_policy('valid', args.batchsz, cfg, db, character)
133 |             self.data_test = manager.create_dataset_policy('test', args.batchsz, cfg, db, character)
134 |             if character == 'sys':
135 |                 pos_weight = args.policy_weight_sys * torch.ones([cfg.a_dim]).to(device=DEVICE)
136 |             elif character == 'usr':
137 |                 pos_weight = args.policy_weight_usr * torch.ones([cfg.a_dim_usr]).to(device=DEVICE)
138 |             else:
139 |                 raise Exception('Unknown character')
140 |             self.multi_entropy_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
141 |         else:
142 |             self.evaluator = MultiWozEvaluator(args.data_dir)
143 | 
144 |         self.save_dir = args.save_dir + '/' + character if pre else args.save_dir
145 |         self.save_per_epoch = args.save_per_epoch
146 |         self.optim_batchsz = args.batchsz
147 |         self.policy.eval()
148 | 
149 |         self.gamma = args.gamma
150 |         self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=args.lr_policy, weight_decay=args.weight_decay)
151 |         self.writer = SummaryWriter()
152 | 
153 |     def policy_loop(self, data):
154 |         s, target_a = to_device(data)
155 |         a_weights = self.policy(s)
156 | 
157 |         loss_a = self.multi_entropy_loss(a_weights, target_a)
158 |         return loss_a
159 | 
160 |     def imitating(self, epoch):
161 |         """
162 |         pretrain the policy by simple imitation learning (behavioral cloning)
163 |         """
164 |         self.policy.train()
165 |         a_loss = 0.
166 |         for i, data in enumerate(self.data_train):
167 |             self.policy_optim.zero_grad()
168 |             loss_a = self.policy_loop(data)
169 |             a_loss += loss_a.item()
170 |             loss_a.backward()
171 |             self.policy_optim.step()
172 | 
173 |             if (i+1) % self.print_per_batch == 0:
174 |                 a_loss /= self.print_per_batch
175 |                 logging.debug('<{}> epoch {}, iter {}, loss_a:{}'.format(self.character, epoch, i, a_loss))
176 |                 a_loss = 0.
177 | 
178 |         if (epoch+1) % self.save_per_epoch == 0:
179 |             self.save(self.save_dir, epoch)
180 |         self.policy.eval()
181 | 
182 |     def imit_test(self, epoch, best):
183 |         """
184 |         provide an unbiased evaluation of the policy fit on the training dataset
185 |         """
186 |         a_loss = 0.
187 |         for i, data in enumerate(self.data_valid):
188 |             loss_a = self.policy_loop(data)
189 |             a_loss += loss_a.item()
190 | 
191 |         a_loss /= len(self.data_valid)
192 |         logging.debug('<{}> validation, epoch {}, loss_a:{}'.format(self.character, epoch, a_loss))
193 |         if a_loss < best:
194 |             logging.info('<{}> best model saved'.format(self.character))
195 |             best = a_loss
196 |             self.save(self.save_dir, 'best')
197 | 
198 |         a_loss = 0.
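        # The loss computed by policy_loop() is multi-label BCE over dialog-act
        # bits, with pos_weight countering the sparsity of positive labels. A toy
        # standalone check of the weighting (illustrative numbers, not from the repo):
        #     import torch, torch.nn as nn
        #     loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([2.5]))
        #     loss_fn(torch.tensor([0.]), torch.tensor([1.]))  # = 2.5 * ln(2) ~ 1.733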
199 |         for i, data in enumerate(self.data_test):
200 |             loss_a = self.policy_loop(data)
201 |             a_loss += loss_a.item()
202 | 
203 |         a_loss /= len(self.data_test)
204 |         logging.debug('<{}> test, epoch {}, loss_a:{}'.format(self.character, epoch, a_loss))
205 |         self.writer.add_scalar('pretrain/dialogue_policy_{}/test'.format(self.character), a_loss, epoch)
206 |         return best
207 | 
208 |     def save(self, directory, epoch):
209 |         if not os.path.exists(directory):
210 |             os.makedirs(directory)
211 | 
212 |         torch.save(self.policy.state_dict(), directory + '/' + str(epoch) + '_pol.mdl')
213 | 
214 |         logging.info('<{}> epoch {}: saved network to mdl'.format(self.character, epoch))
215 | 
216 |     def load(self, filename):
217 | 
218 |         policy_mdl = filename + '_pol.mdl'
219 |         if os.path.exists(policy_mdl):
220 |             self.policy.load_state_dict(torch.load(policy_mdl))
221 |             logging.info('<{}> loaded checkpoint from file: {}'.format(self.character, policy_mdl))
222 | 
-------------------------------------------------------------------------------- /rule.py: --------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | import torch
4 | from datamanager import expand_da
5 | from copy import deepcopy
6 | from tracker import StateTracker
7 | from goal_generator import GoalGenerator
8 | from utils import init_goal, init_session
9 | 
10 | REF_USR_DA = {
11 |     'Attraction': {
12 |         'area': 'Area', 'type': 'Type', 'name': 'Name',
13 |         'entrance fee': 'Fee', 'address': 'Addr',
14 |         'postcode': 'Post', 'phone': 'Phone'
15 |     },
16 |     'Hospital': {
17 |         'department': 'Department', 'address': 'Addr', 'postcode': 'Post',
18 |         'phone': 'Phone'
19 |     },
20 |     'Hotel': {
21 |         'type': 'Type', 'parking': 'Parking', 'pricerange': 'Price',
22 |         'internet': 'Internet', 'area': 'Area', 'stars': 'Stars',
23 |         'name': 'Name', 'stay': 'Stay', 'day': 'Day', 'people': 'People',
24 |         'address': 'Addr', 'postcode': 'Post', 'phone': 'Phone'
25 |     },
26 |     'Police': {
27 |         'address': 'Addr', 'postcode': 'Post', 'phone': 'Phone'
28 |     },
29 |     'Restaurant': {
30 |         'food': 'Food', 'pricerange': 'Price', 'area': 'Area',
31 |         'name': 'Name', 'time': 'Time', 'day': 'Day', 'people': 'People',
32 |         'phone': 'Phone', 'postcode': 'Post', 'address': 'Addr'
33 |     },
34 |     'Taxi': {
35 |         'leaveAt': 'Leave', 'destination': 'Dest', 'departure': 'Depart', 'arriveBy': 'Arrive',
36 |         'car type': 'Car', 'phone': 'Phone'
37 |     },
38 |     'Train': {
39 |         'destination': 'Dest', 'day': 'Day', 'arriveBy': 'Arrive',
40 |         'departure': 'Depart', 'leaveAt': 'Leave', 'people': 'People',
41 |         'duration': 'Time', 'price': 'Ticket', 'trainID': 'Id'
42 |     }
43 | }
44 | 
45 | REF_SYS_DA = {
46 |     'Attraction': {
47 |         'Addr': "address", 'Area': "area", 'Choice': "choice",
48 |         'Fee': "entrance fee", 'Name': "name", 'Phone': "phone",
49 |         'Post': "postcode", 'Price': "pricerange", 'Type': "type",
50 |         'none': None, 'Open': None
51 |     },
52 |     'Hospital': {
53 |         'Department': 'department', 'Addr': 'address', 'Post': 'postcode',
54 |         'Phone': 'phone', 'none': None
55 |     },
56 |     'Booking': {
57 |         'Day': 'day', 'Name': 'name', 'People': 'people',
58 |         'Ref': 'ref', 'Stay': 'stay', 'Time': 'time',
59 |         'none': None
60 |     },
61 |     'Hotel': {
62 |         'Addr': "address", 'Area': "area", 'Choice': "choice",
63 |         'Internet': "internet", 'Name': "name", 'Parking': "parking",
64 |         'Phone': "phone", 'Post': "postcode", 'Price': "pricerange",
65 |         'Ref': "ref", 'Stars': "stars", 'Type': "type",
66 |         'none': None
67 |     },
68 |     'Restaurant': {
69 |         'Addr': "address", 'Area': 
"area", 'Choice': "choice", 70 | 'Name': "name", 'Food': "food", 'Phone': "phone", 71 | 'Post': "postcode", 'Price': "pricerange", 'Ref': "ref", 72 | 'none': None 73 | }, 74 | 'Taxi': { 75 | 'Arrive': "arriveBy", 'Car': "car type", 'Depart': "departure", 76 | 'Dest': "destination", 'Leave': "leaveAt", 'Phone': "phone", 77 | 'none': None 78 | }, 79 | 'Train': { 80 | 'Arrive': "arriveBy", 'Choice': "choice", 'Day': "day", 81 | 'Depart': "departure", 'Dest': "destination", 'Id': "trainID", 82 | 'Leave': "leaveAt", 'People': "people", 'Ref': "ref", 83 | 'Time': "duration", 'none': None, 'Ticket': 'price', 84 | }, 85 | 'Police': { 86 | 'Addr': "address", 'Post': "postcode", 'Phone': "phone" 87 | }, 88 | } 89 | 90 | SELECTABLE_SLOTS = { 91 | 'Attraction': ['area', 'entrance fee', 'name', 'type'], 92 | 'Hospital': ['department'], 93 | 'Hotel': ['area', 'internet', 'name', 'parking', 'pricerange', 'stars', 'type'], 94 | 'Restaurant': ['area', 'name', 'food', 'pricerange'], 95 | 'Taxi': [], 96 | 'Train': [], 97 | 'Police': [], 98 | } 99 | 100 | INFORMABLE_SLOTS = ["Fee", "Addr", "Area", "Stars", "Internet", "Department", "Choice", "Ref", "Food", "Type", "Price",\ 101 | "Stay", "Phone", "Post", "Day", "Name", "Car", "Leave", "Time", "Arrive", "Ticket", None, "Depart",\ 102 | "People", "Dest", "Parking", "Open", "Id"] 103 | 104 | REQUESTABLE_SLOTS = ['Food', 'Area', 'Fee', 'Price', 'Type', 'Department', 'Internet', 'Parking', 'Stars', 'Type'] 105 | 106 | # Information required to finish booking, according to different domain. 107 | booking_info = {'Train': ['People'], 108 | 'Restaurant': ['Time', 'Day', 'People'], 109 | 'Hotel': ['Stay', 'Day', 'People']} 110 | 111 | # Alphabet used to generate phone number 112 | digit = '0123456789' 113 | 114 | 115 | class SystemRule(StateTracker): 116 | ''' Rule-based bot. Implemented for Multiwoz dataset. 
''' 117 | 118 | recommend_flag = -1 119 | choice = "" 120 | 121 | def __init__(self, data_dir, cfg): 122 | super(SystemRule, self).__init__(data_dir, cfg) 123 | self.last_state = {} 124 | self.goal_gen = GoalGenerator(data_dir, cfg, 125 | goal_model_path='processed_data/goal_model.pkl', 126 | corpus_path=cfg.data_file) 127 | 128 | def reset(self, random_seed=None): 129 | self.last_state = init_belief_state() 130 | self.time_step = 0 131 | self.topic = '' 132 | self.goal = self.goal_gen.get_user_goal(random_seed) 133 | 134 | dummy_state, dummy_goal = init_session(-1, self.cfg) 135 | init_goal(dummy_goal, dummy_state['goal_state'], self.goal, self.cfg) 136 | 137 | domain_ordering = self.goal['domain_ordering'] 138 | dummy_state['next_available_domain'] = domain_ordering[0] 139 | dummy_state['invisible_domains'] = domain_ordering[1:] 140 | 141 | dummy_state['user_goal'] = dummy_goal 142 | self.evaluator.add_goal(dummy_goal) 143 | 144 | return dummy_state 145 | 146 | def _action_to_dict(self, das): 147 | da_dict = {} 148 | for da, value in das.items(): 149 | domain, intent, slot = da.split('-') 150 | if domain != 'general': 151 | domain = domain.capitalize() 152 | if intent in ['inform', 'request']: 153 | intent = intent.capitalize() 154 | domint = '-'.join((domain, intent)) 155 | if domint not in da_dict: 156 | da_dict[domint] = [] 157 | da_dict[domint].append([slot.capitalize(), value]) 158 | return da_dict 159 | 160 | def _dict_to_vec(self, das): 161 | da_vector = torch.zeros(self.cfg.a_dim, dtype=torch.int32) 162 | expand_da(das) 163 | for domint in das: 164 | pairs = das[domint] 165 | for slot, p, value in pairs: 166 | da = '-'.join((domint, slot, p)).lower() 167 | if da in self.cfg.da2idx: 168 | idx = self.cfg.da2idx[da] 169 | da_vector[idx] = 1 170 | return da_vector 171 | 172 | def step(self, s, usr_a): 173 | """ 174 | interact with simulator for one user-sys turn 175 | """ 176 | # update state with user_act 177 | current_s = self.update_belief_usr(s, usr_a) 178 | da_dict = self._action_to_dict(current_s['user_action']) 179 | state = self._update_state(da_dict) 180 | sys_a = self.predict(state) 181 | sys_a = self._dict_to_vec(sys_a) 182 | 183 | # update state with sys_act 184 | next_s = self.update_belief_sys(current_s, sys_a) 185 | return next_s 186 | 187 | def predict(self, state): 188 | """ 189 | Args: 190 | State, please refer to util/state.py 191 | Output: 192 | DA(Dialog Act), in the form of {act_type1: [[slot_name_1, value_1], [slot_name_2, value_2], ...], ...} 193 | """ 194 | 195 | if self.recommend_flag != -1: 196 | self.recommend_flag += 1 197 | 198 | self.kb_result = {} 199 | 200 | DA = {} 201 | 202 | if 'user_action' in state and (len(state['user_action']) > 0): 203 | user_action = state['user_action'] 204 | else: 205 | user_action = check_diff(self.last_state, state) 206 | 207 | # Debug info for check_diff function 208 | 209 | self.last_state = state 210 | 211 | for user_act in user_action: 212 | domain, intent_type = user_act.split('-') 213 | 214 | # Respond to general greetings 215 | if domain == 'general': 216 | self._update_greeting(user_act, state, DA) 217 | 218 | # Book taxi for user 219 | elif domain == 'Taxi': 220 | self._book_taxi(user_act, state, DA) 221 | 222 | elif domain == 'Booking': 223 | self._update_booking(user_act, state, DA) 224 | 225 | # User's talking about other domain 226 | elif domain != "Train": 227 | self._update_DA(user_act, user_action, state, DA) 228 | 229 | # Info about train 230 | else: 231 | self._update_train(user_act, user_action, state, 
DA)
232 | 
233 |             # Judge if the user wants to book
234 |             self._judge_booking(user_act, user_action, DA)
235 | 
236 |         if 'Booking-Book' in DA:
237 |             if random.random() < 0.5:
238 |                 DA['general-reqmore'] = []
239 |             user_acts = []
240 |             for user_act in DA:
241 |                 if user_act != 'Booking-Book':
242 |                     user_acts.append(user_act)
243 |             for user_act in user_acts:
244 |                 del DA[user_act]
245 | 
246 |         if DA == {}:
247 |             return {'general-greet': [['none', 'none']]}
248 |         return DA
249 | 
250 |     def _update_state(self, user_act=None):
251 |         if not isinstance(user_act, dict):
252 |             raise Exception("Expect user_act to be <class 'dict'> type but get {}.".format(type(user_act)))
253 |         previous_state = self.last_state
254 |         new_belief_state = copy.deepcopy(previous_state['belief_state'])
255 |         new_request_state = copy.deepcopy(previous_state['request_state'])
256 |         for domain_type in user_act.keys():
257 |             domain, tpe = domain_type.lower().split('-')
258 |             if domain in ['unk', 'general', 'booking']:
259 |                 continue
260 |             if tpe == 'inform':
261 |                 for k, v in user_act[domain_type]:
262 |                     k = REF_SYS_DA[domain.capitalize()].get(k, k)
263 |                     if k is None:
264 |                         continue
265 |                     try:
266 |                         assert domain in new_belief_state
267 |                     except:
268 |                         raise Exception('Error: domain <{}> not in new belief state'.format(domain))
269 |                     domain_dic = new_belief_state[domain]
270 |                     assert 'semi' in domain_dic
271 |                     assert 'book' in domain_dic
272 | 
273 |                     if k in domain_dic['semi']:
274 |                         nvalue = v
275 |                         new_belief_state[domain]['semi'][k] = nvalue
276 |                     elif k in domain_dic['book']:
277 |                         new_belief_state[domain]['book'][k] = v
278 |                     elif k.lower() in domain_dic['book']:
279 |                         new_belief_state[domain]['book'][k.lower()] = v
280 |                     elif k == 'trainID' and domain == 'train':
281 |                         new_belief_state[domain]['book'][k] = v
282 |                     else:
283 |                         # raise Exception('unknown slot name <{}> of domain <{}>'.format(k, domain))
284 |                         with open('unknown_slot.log', 'a+') as f:
285 |                             f.write('unknown slot name <{}> of domain <{}>\n'.format(k, domain))
286 |             elif tpe == 'request':
287 |                 for k, v in user_act[domain_type]:
288 |                     k = REF_SYS_DA[domain.capitalize()].get(k, k)
289 |                     if domain not in new_request_state:
290 |                         new_request_state[domain] = {}
291 |                     if k not in new_request_state[domain]:
292 |                         new_request_state[domain][k] = 0
293 | 
294 |         new_state = copy.deepcopy(previous_state)
295 |         new_state['belief_state'] = new_belief_state
296 |         new_state['request_state'] = new_request_state
297 |         new_state['user_action'] = user_act
298 | 
299 |         return new_state
300 | 
301 | 
302 |     def _update_greeting(self, user_act, state, DA):
303 |         """ General request / inform. """
304 |         _, intent_type = user_act.split('-')
305 | 
306 |         # Respond to goodbye
307 |         if intent_type == 'bye':
308 |             if 'general-bye' not in DA:
309 |                 DA['general-bye'] = []
310 |             if random.random() < 0.3:
311 |                 if 'general-welcome' not in DA:
312 |                     DA['general-welcome'] = []
313 |         elif intent_type == 'thank':
314 |             DA['general-welcome'] = []
315 | 
316 |     def _book_taxi(self, user_act, state, DA):
317 |         """ Book a taxi for user.
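        Missing departure/destination and leave/arrive constraints are
        requested first; once all are known, a generated car type and
        phone number are informed.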
""" 318 | 319 | blank_info = [] 320 | for info in ['departure', 'destination']: 321 | if state['belief_state']['taxi']['semi'] == "": 322 | info = REF_USR_DA['Taxi'].get(info, info) 323 | blank_info.append(info) 324 | if state['belief_state']['taxi']['semi']['leaveAt'] == "" and state['belief_state']['taxi']['semi']['arriveBy'] == "": 325 | blank_info += ['Leave', 'Arrive'] 326 | 327 | 328 | # Finish booking, tell user car type and phone number 329 | if len(blank_info) == 0: 330 | if 'Taxi-Inform' not in DA: 331 | DA['Taxi-Inform'] = [] 332 | car = generate_car() 333 | phone_num = generate_phone_num(11) 334 | DA['Taxi-Inform'].append(['Car', car]) 335 | DA['Taxi-Inform'].append(['Phone', phone_num]) 336 | return 337 | 338 | # Need essential info to finish booking 339 | request_num = random.randint(0, 999999) % len(blank_info) + 1 340 | if 'Taxi-Request' not in DA: 341 | DA['Taxi-Request'] = [] 342 | for i in range(request_num): 343 | slot = REF_USR_DA.get(blank_info[i], blank_info[i]) 344 | DA['Taxi-Request'].append([slot, '?']) 345 | 346 | def _update_booking(self, user_act, state, DA): 347 | pass 348 | 349 | def _update_DA(self, user_act, user_action, state, DA): 350 | """ Answer user's utterance about any domain other than taxi or train. """ 351 | 352 | domain, intent_type = user_act.split('-') 353 | 354 | constraints = [] 355 | for slot in state['belief_state'][domain.lower()]['semi']: 356 | if state['belief_state'][domain.lower()]['semi'][slot] != "": 357 | constraints.append([slot, state['belief_state'][domain.lower()]['semi'][slot]]) 358 | 359 | kb_result = self.db.query(domain.lower(), constraints) 360 | self.kb_result[domain] = deepcopy(kb_result) 361 | 362 | # Respond to user's request 363 | if intent_type == 'Request': 364 | if self.recommend_flag > 1: 365 | self.recommend_flag = -1 366 | self.choice = "" 367 | elif self.recommend_flag == 1: 368 | self.recommend_flag == 0 369 | if (domain + "-Inform") not in DA: 370 | DA[domain + "-Inform"] = [] 371 | for slot in user_action[user_act]: 372 | if len(kb_result) > 0: 373 | kb_slot_name = REF_SYS_DA[domain].get(slot[0], slot[0]) 374 | if kb_slot_name in kb_result[0]: 375 | DA[domain + "-Inform"].append([slot[0], kb_result[0][kb_slot_name]]) 376 | else: 377 | DA[domain + "-Inform"].append([slot[0], "unknown"]) 378 | 379 | else: 380 | # There's no result matching user's constraint 381 | if len(kb_result) == 0: 382 | if (domain + "-NoOffer") not in DA: 383 | DA[domain + "-NoOffer"] = [] 384 | 385 | for slot in state['belief_state'][domain.lower()]['semi']: 386 | if state['belief_state'][domain.lower()]['semi'][slot] != "" and \ 387 | state['belief_state'][domain.lower()]['semi'][slot] != "do n't care": 388 | slot_name = REF_USR_DA[domain].get(slot, slot) 389 | DA[domain + "-NoOffer"].append([slot_name, state['belief_state'][domain.lower()]['semi'][slot]]) 390 | 391 | p = random.random() 392 | 393 | # Ask user if he wants to change constraint 394 | if p < 0.3: 395 | req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3) 396 | if domain + "-Request" not in DA: 397 | DA[domain + "-Request"] = [] 398 | for i in range(req_num): 399 | slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0], DA[domain + "-NoOffer"][i][0]) 400 | DA[domain + "-Request"].append([slot_name, "?"]) 401 | 402 | # There's exactly one result matching user's constraint 403 | elif len(kb_result) == 1: 404 | 405 | # Inform user about this result 406 | if (domain + "-Inform") not in DA: 407 | DA[domain + "-Inform"] = [] 408 | props = [] 
409 |                 for prop in state['belief_state'][domain.lower()]['semi']:
410 |                     props.append(prop)
411 |                 property_num = len(props)
412 |                 if property_num > 0:
413 |                     info_num = random.randint(0, 999999) % property_num + 1
414 |                     random.shuffle(props)
415 |                     for i in range(info_num):
416 |                         slot_name = REF_USR_DA[domain].get(props[i], props[i])
417 |                         DA[domain + "-Inform"].append([slot_name, kb_result[0][props[i]]])
418 | 
419 |             # There are multiple results matching user's constraint
420 |             else:
421 |                 p = random.random()
422 | 
423 |                 # Recommend a choice from kb_list
424 |                 if True: #p < 0.3:
425 |                     if (domain + "-Inform") not in DA:
426 |                         DA[domain + "-Inform"] = []
427 |                     if (domain + "-Recommend") not in DA:
428 |                         DA[domain + "-Recommend"] = []
429 |                     DA[domain + "-Inform"].append(["Choice", str(len(kb_result))])
430 |                     idx = random.randint(0, 999999) % len(kb_result)
431 |                     choice = kb_result[idx]
432 |                     if domain in ["Hotel", "Attraction", "Police", "Restaurant"]:
433 |                         DA[domain + "-Recommend"].append(['Name', choice['name']])
434 |                     self.recommend_flag = 0
435 |                     self.candidate = choice
436 |                     props = []
437 |                     for prop in choice:
438 |                         props.append([prop, choice[prop]])
439 |                     prop_num = min(random.randint(0, 999999) % 3, len(props))
440 |                     random.shuffle(props)
441 |                     for i in range(prop_num):
442 |                         slot = props[i][0]
443 |                         string = REF_USR_DA[domain].get(slot, slot)
444 |                         if string in INFORMABLE_SLOTS:
445 |                             DA[domain + "-Recommend"].append([string, str(props[i][1])])
446 | 
447 |                 # Ask user to choose a candidate.
448 |                 elif p < 0.5:
449 |                     prop_values = []
450 |                     props = []
451 |                     for prop in kb_result[0]:
452 |                         for candidate in kb_result:
453 |                             if prop not in candidate:
454 |                                 continue
455 |                             if candidate[prop] not in prop_values:
456 |                                 prop_values.append(candidate[prop])
457 |                         if len(prop_values) > 1:
458 |                             props.append([prop, prop_values])
459 |                         prop_values = []
460 |                     random.shuffle(props)
461 |                     idx = 0
462 |                     while idx < len(props):
463 |                         if props[idx][0] not in SELECTABLE_SLOTS[domain]:
464 |                             props.pop(idx)
465 |                             idx -= 1
466 |                         idx += 1
467 |                     if domain + "-Select" not in DA:
468 |                         DA[domain + "-Select"] = []
469 |                     for i in range(min(len(props[0][1]), 5)):
470 |                         prop_value = REF_USR_DA[domain].get(props[0][0], props[0][0])
471 |                         DA[domain + "-Select"].append([prop_value, props[0][1][i]])
472 | 
473 |                 # Ask user for more constraint
474 |                 else:
475 |                     reqs = []
476 |                     for prop in state['belief_state'][domain.lower()]['semi']:
477 |                         if state['belief_state'][domain.lower()]['semi'][prop] == "":
478 |                             prop_value = REF_USR_DA[domain].get(prop, prop)
479 |                             reqs.append([prop_value, "?"])
480 |                     i = 0
481 |                     while i < len(reqs):
482 |                         if reqs[i][0] not in REQUESTABLE_SLOTS:
483 |                             reqs.pop(i)
484 |                             i -= 1
485 |                         i += 1
486 |                     random.shuffle(reqs)
487 |                     if len(reqs) == 0:
488 |                         return
489 |                     req_num = min(random.randint(0, 999999) % len(reqs) + 1, 2)
490 |                     if (domain + "-Request") not in DA:
491 |                         DA[domain + "-Request"] = []
492 |                     for i in range(req_num):
493 |                         req = reqs[i]
494 |                         req[0] = REF_USR_DA[domain].get(req[0], req[0])
495 |                         DA[domain + "-Request"].append(req)
496 | 
497 |     def _update_train(self, user_act, user_action, state, DA):
498 |         constraints = []
499 |         for time in ['leaveAt', 'arriveBy']:
500 |             if state['belief_state']['train']['semi'][time] != "":
501 |                 constraints.append([time, state['belief_state']['train']['semi'][time]])
502 | 
503 |         if len(constraints) == 0:
504 |             p = random.random()
505 |             if 'Train-Request' not in DA:
506 |                 DA['Train-Request'] = []
507 |             if p < 0.33:
508 | 
DA['Train-Request'].append(['Leave', '?'])
509 |             elif p < 0.66:
510 |                 DA['Train-Request'].append(['Arrive', '?'])
511 |             else:
512 |                 DA['Train-Request'].append(['Leave', '?'])
513 |                 DA['Train-Request'].append(['Arrive', '?'])
514 | 
515 |         if 'Train-Request' not in DA:
516 |             DA['Train-Request'] = []
517 |         for prop in ['day', 'destination', 'departure']:
518 |             if state['belief_state']['train']['semi'][prop] == "":
519 |                 slot = REF_USR_DA['Train'].get(prop, prop)
520 |                 DA["Train-Request"].append([slot, '?'])
521 |             else:
522 |                 constraints.append([prop, state['belief_state']['train']['semi'][prop]])
523 | 
524 |         kb_result = self.db.query('train', constraints)
525 |         self.kb_result['Train'] = deepcopy(kb_result)
526 | 
527 |         if user_act == 'Train-Request':
528 |             del(DA['Train-Request'])
529 |             if 'Train-Inform' not in DA:
530 |                 DA['Train-Inform'] = []
531 |             for slot in user_action[user_act]:
532 |                 slot_name = REF_SYS_DA['Train'].get(slot[0], slot[0])
533 |                 try:
534 |                     DA['Train-Inform'].append([slot[0], kb_result[0][slot_name]])
535 |                 except:
536 |                     pass
537 |             return
538 |         if len(kb_result) == 0:
539 |             if 'Train-NoOffer' not in DA:
540 |                 DA['Train-NoOffer'] = []
541 |             for prop in constraints:
542 |                 DA['Train-NoOffer'].append([REF_USR_DA['Train'].get(prop[0], prop[0]), prop[1]])
543 |             if 'Train-Request' in DA:
544 |                 del DA['Train-Request']
545 |         elif len(kb_result) >= 1:
546 |             if len(constraints) < 4:
547 |                 return
548 |             if 'Train-Request' in DA:
549 |                 del DA['Train-Request']
550 |             if 'Train-OfferBook' not in DA:
551 |                 DA['Train-OfferBook'] = []
552 |             for prop in constraints:
553 |                 DA['Train-OfferBook'].append([REF_USR_DA['Train'].get(prop[0], prop[0]), prop[1]])
554 | 
555 |     def _judge_booking(self, user_act, user_action, DA):
556 |         """ If the user wants to book, return a ref number.
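        A Booking-Book act carrying the first matched entity's reference
        number is emitted once the user has supplied the booking slots
        required for the domain (see booking_info above).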
""" 557 | if self.recommend_flag > 1: 558 | self.recommend_flag = -1 559 | self.choice = "" 560 | elif self.recommend_flag == 1: 561 | self.recommend_flag == 0 562 | domain, _ = user_act.split('-') 563 | for slot in user_action[user_act]: 564 | if domain in booking_info and slot[0] in booking_info[domain]: 565 | if 'Booking-Book' not in DA: 566 | if domain in self.kb_result and len(self.kb_result[domain]) > 0: 567 | if 'Ref' in self.kb_result[domain][0]: 568 | DA['Booking-Book'] = [["Ref", self.kb_result[domain][0]['Ref']]] 569 | else: 570 | DA['Booking-Book'] = [["Ref", "N/A"]] 571 | # TODO handle booking between multi turn 572 | 573 | def check_diff(last_state, state): 574 | user_action = {} 575 | if last_state == {}: 576 | for domain in state['belief_state']: 577 | for slot in state['belief_state'][domain]['book']: 578 | if slot != 'booked' and state['belief_state'][domain]['book'][slot] != '': 579 | if (domain.capitalize() + "-Inform") not in user_action: 580 | user_action[domain.capitalize() + "-Inform"] = [] 581 | if [REF_USR_DA[domain.capitalize()].get(slot, slot), state['belief_state'][domain]['book'][slot]] \ 582 | not in user_action[domain.capitalize() + "-Inform"]: 583 | user_action[domain.capitalize() + "-Inform"].append([REF_USR_DA[domain.capitalize()].get(slot, slot), \ 584 | state['belief_state'][domain]['book'][slot]]) 585 | for slot in state['belief_state'][domain]['semi']: 586 | if state['belief_state'][domain]['semi'][slot] != "": 587 | if (domain.capitalize() + "-Inform") not in user_action: 588 | user_action[domain.capitalize() + "-Inform"] = [] 589 | if [REF_USR_DA[domain.capitalize()].get(slot, slot), state['belief_state'][domain]['semi'][slot]] \ 590 | not in user_action[domain.capitalize() + "-Inform"]: 591 | user_action[domain.capitalize() + "-Inform"].append([REF_USR_DA[domain.capitalize()].get(slot, slot), \ 592 | state['belief_state'][domain]['semi'][slot]]) 593 | for domain in state['request_state']: 594 | for slot in state['request_state'][domain]: 595 | if (domain.capitalize() + "-Request") not in user_action: 596 | user_action[domain.capitalize() + "-Request"] = [] 597 | if [REF_USR_DA[domain].get(slot, slot), '?'] not in user_action[domain.capitalize() + "-Request"]: 598 | user_action[domain.capitalize() + "-Request"].append([REF_USR_DA[domain].get(slot, slot), '?']) 599 | 600 | else: 601 | for domain in state['belief_state']: 602 | for slot in state['belief_state'][domain]['book']: 603 | if slot != 'booked' and state['belief_state'][domain]['book'][slot] != last_state['belief_state'][domain]['book'][slot]: 604 | if (domain.capitalize() + "-Inform") not in user_action: 605 | user_action[domain.capitalize() + "-Inform"] = [] 606 | if [REF_USR_DA[domain.capitalize()].get(slot, slot), 607 | state['belief_state'][domain]['book'][slot]] \ 608 | not in user_action[domain.capitalize() + "-Inform"]: 609 | user_action[domain.capitalize() + "-Inform"].append( 610 | [REF_USR_DA[domain.capitalize()].get(slot, slot), \ 611 | state['belief_state'][domain]['book'][slot]]) 612 | for slot in state['belief_state'][domain]['semi']: 613 | if state['belief_state'][domain]['semi'][slot] != last_state['belief_state'][domain]['semi'][slot] and \ 614 | state['belief_state'][domain]['semi'][slot] != '': 615 | if (domain.capitalize() + "-Inform") not in user_action: 616 | user_action[domain.capitalize() + "-Inform"] = [] 617 | if [REF_USR_DA[domain.capitalize()].get(slot, slot), state['belief_state'][domain]['semi'][slot]] \ 618 | not in user_action[domain.capitalize() + "-Inform"]: 
619 | user_action[domain.capitalize() + "-Inform"].append([REF_USR_DA[domain.capitalize()].get(slot, slot), \ 620 | state['belief_state'][domain]['semi'][slot]]) 621 | for domain in state['request_state']: 622 | for slot in state['request_state'][domain]: 623 | if (domain not in last_state['request_state']) or (slot not in last_state['request_state'][domain]): 624 | if (domain.capitalize() + "-Request") not in user_action: 625 | user_action[domain.capitalize() + "-Request"] = [] 626 | if [REF_USR_DA[domain.capitalize()].get(slot, slot), '?'] not in user_action[domain.capitalize() + "-Request"]: 627 | user_action[domain.capitalize() + "-Request"].append([REF_USR_DA[domain.capitalize()].get(slot, slot), '?']) 628 | return user_action 629 | 630 | 631 | def deduplicate(lst): 632 | i = 0 633 | while i < len(lst): 634 | if lst[i] in lst[0 : i]: 635 | lst.pop(i) 636 | i -= 1 637 | i += 1 638 | return lst 639 | 640 | def generate_phone_num(length): 641 | """ Generate a phone num. """ 642 | string = "" 643 | while len(string) < length: 644 | string += digit[random.randint(0, 999999) % 10] 645 | return string 646 | 647 | def generate_car(): 648 | """ Generate a car for taxi booking. """ 649 | car_types = ["toyota", "skoda", "bmw", "honda", "ford", "audi", "lexus", "volvo", "volkswagen", "tesla"] 650 | p = random.randint(0, 999999) % len(car_types) 651 | return car_types[p] 652 | 653 | def init_belief_state(): 654 | belief_state = { 655 | "police": { 656 | "book": { 657 | "booked": [] 658 | }, 659 | "semi": {} 660 | }, 661 | "hotel": { 662 | "book": { 663 | "booked": [], 664 | "people": "", 665 | "day": "", 666 | "stay": "" 667 | }, 668 | "semi": { 669 | "name": "", 670 | "area": "", 671 | "parking": "", 672 | "pricerange": "", 673 | "stars": "", 674 | "internet": "", 675 | "type": "" 676 | } 677 | }, 678 | "attraction": { 679 | "book": { 680 | "booked": [] 681 | }, 682 | "semi": { 683 | "type": "", 684 | "name": "", 685 | "area": "" 686 | } 687 | }, 688 | "restaurant": { 689 | "book": { 690 | "booked": [], 691 | "people": "", 692 | "day": "", 693 | "time": "" 694 | }, 695 | "semi": { 696 | "food": "", 697 | "pricerange": "", 698 | "name": "", 699 | "area": "", 700 | } 701 | }, 702 | "hospital": { 703 | "book": { 704 | "booked": [] 705 | }, 706 | "semi": { 707 | "department": "" 708 | } 709 | }, 710 | "taxi": { 711 | "book": { 712 | "booked": [] 713 | }, 714 | "semi": { 715 | "leaveAt": "", 716 | "destination": "", 717 | "departure": "", 718 | "arriveBy": "" 719 | } 720 | }, 721 | "train": { 722 | "book": { 723 | "booked": [], 724 | "people": "" 725 | }, 726 | "semi": { 727 | "leaveAt": "", 728 | "destination": "", 729 | "day": "", 730 | "arriveBy": "", 731 | "departure": "" 732 | } 733 | } 734 | } 735 | state = {'user_action': {}, 736 | 'belief_state': belief_state, 737 | 'request_state': {}} 738 | return state 739 | -------------------------------------------------------------------------------- /tracker.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: ryuichi takanobu 4 | """ 5 | import random 6 | import torch 7 | from copy import deepcopy 8 | from dbquery import DBQuery 9 | from utils import discard, reload 10 | from evaluator import MultiWozEvaluator 11 | 12 | class StateTracker(object): 13 | def __init__(self, data_dir, config): 14 | self.time_step = 0 15 | self.cfg = config 16 | self.db = DBQuery(data_dir, config) 17 | self.topic = '' 18 | self.evaluator = MultiWozEvaluator(data_dir) 19 | self.lock_evalutor = 
False
20 | 
21 |     def set_rollout(self, rollout):
22 |         if rollout:
23 |             self.save_time_step = self.time_step
24 |             self.save_topic = self.topic
25 |             self.lock_evalutor = True
26 |         else:
27 |             self.time_step = self.save_time_step
28 |             self.topic = self.save_topic
29 |             self.lock_evalutor = False
30 | 
31 |     def get_entities(self, s, domain):
32 |         origin = s['belief_state'][domain].items()
33 |         constraint = []
34 |         for k, v in origin:
35 |             if v != '?' and k in self.cfg.mapping[domain]:
36 |                 constraint.append((self.cfg.mapping[domain][k], v))
37 |         entities = self.db.query(domain, constraint)
38 |         random.shuffle(entities)
39 |         return entities
40 | 
41 |     def update_belief_sys(self, old_s, a):
42 |         """
43 |         update belief/goal state with sys action
44 |         """
45 |         s = deepcopy(old_s)
46 |         a_index = torch.nonzero(a) # get multiple da indices
47 | 
48 |         self.time_step += 1
49 |         s['others']['turn'] = self.time_step
50 | 
51 |         # update sys/user dialog act
52 |         s['sys_action'] = dict()
53 | 
54 |         # update belief part
55 |         das = [self.cfg.idx2da[idx.item()] for idx in a_index]
56 |         das = [da.split('-') for da in das]
57 |         das.sort(key=lambda x:x[0]) # sort by domain
58 | 
59 |         entities = [] if self.topic == '' else self.get_entities(s, self.topic)
60 |         return_flag = False
61 |         for domain, intent, slot, p in das:
62 |             if domain in self.cfg.belief_domains and domain != self.topic:
63 |                 self.topic = domain
64 |                 entities = self.get_entities(s, domain)
65 | 
66 |             da = '-'.join((domain, intent, slot, p))
67 |             if intent == 'request':
68 |                 s['sys_action'][da] = '?'
69 |             elif intent in ['nooffer', 'nobook'] and self.topic != '':
70 |                 return_flag = True
71 |                 if slot in s['belief_state'][self.topic] and s['belief_state'][self.topic][slot] != '?':
72 |                     s['sys_action'][da] = s['belief_state'][self.topic][slot]
73 |                 else:
74 |                     s['sys_action'][da] = 'none'
75 |             elif slot == 'choice':
76 |                 s['sys_action'][da] = str(len(entities))
77 |             elif slot == 'none':
78 |                 s['sys_action'][da] = 'none'
79 |             else:
80 |                 num = int(p) - 1
81 |                 if self.topic and len(entities) > num and slot in self.cfg.mapping[self.topic]:
82 |                     typ = self.cfg.mapping[self.topic][slot]
83 |                     if typ in entities[num]:
84 |                         s['sys_action'][da] = entities[num][typ]
85 |                     else:
86 |                         s['sys_action'][da] = 'none'
87 |                 else:
88 |                     s['sys_action'][da] = 'none'
89 | 
90 |             if not self.topic:
91 |                 continue
92 |             if intent in ['inform', 'recommend', 'offerbook', 'offerbooked', 'book']:
93 |                 discard(s['belief_state'][self.topic], slot, '?')
94 |                 if slot in s['user_goal'][self.topic] and s['user_goal'][self.topic][slot] == '?':
95 |                     s['goal_state'][self.topic][slot] = s['sys_action'][da]
96 | 
97 |             # booked
98 |             if intent == 'inform' and slot == 'car': # taxi
99 |                 if 'booked' not in s['belief_state']['taxi']:
100 |                     s['belief_state']['taxi']['booked'] = 'taxi-booked'
101 |             elif intent in ['offerbooked', 'book'] and slot == 'ref': # train
102 |                 if self.topic in ['taxi', 'hospital', 'police']:
103 |                     s['belief_state'][self.topic]['booked'] = f'{self.topic}-booked'
104 |                     s['sys_action'][da] = f'{self.topic}-booked'
105 |                 elif entities:
106 |                     book_domain = entities[0]['ref'].split('-')[0]
107 |                     if 'booked' not in s['belief_state'][book_domain] and entities:
108 |                         s['belief_state'][book_domain]['booked'] = entities[0]['ref']
109 |                         s['sys_action'][da] = entities[0]['ref']
110 | 
111 |         if return_flag:
112 |             for da in s['user_action']:
113 |                 d_usr, i_usr, s_usr = da.split('-')
114 |                 if i_usr == 'inform' and d_usr == self.topic:
115 |                     discard(s['belief_state'][d_usr], s_usr)
116 | 
reload(s['goal_state'], s['user_goal'], self.topic)
117 | 
118 |         if not self.lock_evalutor:
119 |             self.evaluator.add_sys_da(s['sys_action'])
120 | 
121 |         return s
122 | 
123 |     def update_belief_usr(self, old_s, a):
124 |         """
125 |         update belief/goal state with user action
126 |         """
127 |         s = deepcopy(old_s)
128 |         a_index = torch.nonzero(a) # get multiple da indices
129 | 
130 |         self.time_step += 1
131 |         s['others']['turn'] = self.time_step
132 |         s['others']['terminal'] = 1 if (self.cfg.a_dim_usr-1) in a_index else 0
133 | 
134 |         # update sys/user dialog act
135 |         s['user_action'] = dict()
136 | 
137 |         # update belief part
138 |         das = [self.cfg.idx2da_u[idx.item()] for idx in a_index if idx.item() != self.cfg.a_dim_usr-1]
139 |         das = [da.split('-') for da in das]
140 |         if s['invisible_domains']:
141 |             for da in das:
142 |                 if da[0] == s['next_available_domain']:
143 |                     s['next_available_domain'] = s['invisible_domains'][0]
144 |                     s['invisible_domains'].remove(s['next_available_domain'])
145 |                     break
146 |         das.sort(key=lambda x:x[0]) # sort by domain
147 | 
148 |         for domain, intent, slot in das:
149 |             if domain in self.cfg.belief_domains and domain != self.topic:
150 |                 self.topic = domain
151 | 
152 |             da = '-'.join((domain, intent, slot))
153 |             if intent == 'request':
154 |                 s['user_action'][da] = '?'
155 |                 s['belief_state'][self.topic][slot] = '?'
156 |             elif slot == 'none':
157 |                 s['user_action'][da] = 'none'
158 |             else:
159 |                 if self.topic and slot in s['user_goal'][self.topic] and s['user_goal'][domain][slot] != '?':
160 |                     s['user_action'][da] = s['user_goal'][domain][slot]
161 |                 else:
162 |                     s['user_action'][da] = 'dont care'
163 | 
164 |             if not self.topic:
165 |                 continue
166 |             if intent == 'inform':
167 |                 s['belief_state'][domain][slot] = s['user_action'][da]
168 |                 if slot in s['user_goal'][self.topic] and s['user_goal'][self.topic][slot] != '?':
169 |                     discard(s['goal_state'][self.topic], slot)
170 | 
171 |         if not self.lock_evalutor:
172 |             self.evaluator.add_usr_da(s['user_action'])
173 | 
174 |         return s
175 | 
176 |     def reset(self, random_seed=None):
177 |         """
178 |         Args:
179 |             random_seed (int):
180 |         Returns:
181 |             init_state (dict):
182 |         """
183 |         pass
184 | 
185 |     def step(self, s, sys_a):
186 |         """
187 |         Args:
188 |             s (dict):
189 |             sys_a (vector):
190 |         Returns:
191 |             next_s (dict):
192 |             terminal (bool):
193 |         """
194 |         pass
195 | 
-------------------------------------------------------------------------------- /utils.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: ryuichi takanobu
4 | """
5 | import time
6 | import logging
7 | import os
8 | import numpy as np
9 | import argparse
10 | import torch
11 | 
12 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 | 
14 | def get_parser():
15 |     parser = argparse.ArgumentParser()
16 | 
17 |     parser.add_argument('--log_dir', type=str, default='log', help='Logging directory')
18 |     parser.add_argument('--data_dir', type=str, default='data', help='Data directory')
19 |     parser.add_argument('--save_dir', type=str, default='model_multi', help='Directory to store model')
20 |     parser.add_argument('--load', type=str, default='', help='File name to load trained model')
21 |     parser.add_argument('--pretrain', type=bool, default=False, help='Set to pretrain')
22 |     parser.add_argument('--test', type=bool, default=False, help='Set to inference')
23 |     parser.add_argument('--config', type=str, default='multiwoz', help='Dataset to use')
24 | 
parser.add_argument('--test_case', type=int, default=1000, help='Number of test cases')
25 |     parser.add_argument('--save_per_epoch', type=int, default=4, help="Save model every XXX epochs")
26 |     parser.add_argument('--print_per_batch', type=int, default=200, help="Print log every XXX batches")
27 | 
28 |     parser.add_argument('--epoch', type=int, default=48, help='Max number of epochs')
29 |     parser.add_argument('--process', type=int, default=8, help='Process number')
30 |     parser.add_argument('--batchsz', type=int, default=32, help='Batch size')
31 |     parser.add_argument('--batchsz_traj', type=int, default=512, help='Batch size to collect trajectories')
32 |     parser.add_argument('--policy_weight_sys', type=float, default=2.5, help='Pos weight on system policy pretraining')
33 |     parser.add_argument('--policy_weight_usr', type=float, default=4, help='Pos weight on user policy pretraining')
34 |     parser.add_argument('--lr_policy', type=float, default=1e-3, help='Learning rate of dialog policy')
35 |     parser.add_argument('--lr_vnet', type=float, default=3e-5, help='Learning rate of value network')
36 |     parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay (L2 penalty)')
37 |     parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor')
38 |     parser.add_argument('--clip', type=float, default=10, help='Gradient clipping')
39 |     parser.add_argument('--interval', type=int, default=400, help='Update interval of target network')
40 | 
41 |     return parser
42 | 
43 | def discard(dic, key, value=None):
44 |     if key in dic:
45 |         if value is None or dic[key] == value:
46 |             del(dic[key])
47 | 
48 | def init_session(key, cfg):
49 |     # shared info
50 |     turn_data = {}
51 |     turn_data['others'] = {'session_id':key, 'turn':0, 'terminal':False, 'change':False}
52 |     turn_data['sys_action'] = dict()
53 |     turn_data['user_action'] = dict()
54 | 
55 |     # belief & goal state
56 |     turn_data['belief_state'] = {}
57 |     turn_data['goal_state'] = {}
58 |     for domain in cfg.belief_domains:
59 |         turn_data['belief_state'][domain] = {}
60 |         turn_data['goal_state'][domain] = {}
61 | 
62 |     # user goal
63 |     session_data = {}
64 |     for domain in cfg.belief_domains:
65 |         session_data[domain] = {}
66 | 
67 |     return turn_data, session_data
68 | 
69 | def init_goal(goal, state, off_goal, cfg):
70 |     for domain in cfg.belief_domains:
71 |         if domain in off_goal and off_goal[domain]:
72 |             domain_data = off_goal[domain]
73 |             # constraint
74 |             if 'info' in domain_data:
75 |                 for slot, value in domain_data['info'].items():
76 |                     slot = cfg.map_inverse[domain][slot]
77 |                     # single slot value for user goal
78 |                     inform_da = domain+'-'+slot
79 |                     if inform_da in cfg.inform_da_usr:
80 |                         goal[domain][slot] = value
81 |                         state[domain][slot] = value
82 |             if 'fail_info' in domain_data and domain_data['fail_info']:
83 |                 goal[domain]['final'] = {}
84 |                 for slot, value in domain_data['fail_info'].items():
85 |                     slot = cfg.map_inverse[domain][slot]
86 |                     # single slot value for user goal
87 |                     inform_da = domain+'-'+slot
88 |                     if inform_da in cfg.inform_da_usr:
89 |                         goal[domain]['final'][slot] = goal[domain][slot]
90 |                         goal[domain][slot] = value
91 |                         state[domain][slot] = value
92 | 
93 |             # booking
94 |             if 'book' in domain_data:
95 |                 goal[domain]['book'] = True
96 |                 for slot, value in domain_data['book'].items():
97 |                     if slot in cfg.map_inverse[domain]:
98 |                         slot = cfg.map_inverse[domain][slot]
99 |                     # single slot value for user goal
100 |                     inform_da = domain+'-'+slot
101 |                     if inform_da in cfg.inform_da_usr:
102 |                         goal[domain][slot] = value
103 |                         state[domain][slot] = value
104 |             if 'fail_book' in domain_data and domain_data['fail_book']:
105 |                 if 'final' not in goal[domain]:
106 |                     goal[domain]['final'] = {}
107 |                 for slot, value in domain_data['fail_book'].items():
108 |                     if slot in cfg.map_inverse[domain]:
109 |                         slot = cfg.map_inverse[domain][slot]
110 |                     # single slot value for user goal
111 |                     inform_da = domain+'-'+slot
112 |                     if inform_da in cfg.inform_da_usr:
113 |                         goal[domain]['final'][slot] = goal[domain][slot]
114 |                         goal[domain][slot] = value
115 |                         state[domain][slot] = value
116 | 
117 |             # request
118 |             if 'reqt' in domain_data:
119 |                 for slot in domain_data['reqt']:
120 |                     slot = cfg.map_inverse[domain][slot]
121 |                     request_da = domain+'-'+slot
122 |                     if request_da in cfg.request_da_usr:
123 |                         goal[domain][slot] = '?'
124 |                         state[domain][slot] = '?'
125 | 
126 | def reload(state, goal, domain):
127 |     state[domain] = {}
128 |     for key in goal[domain]:
129 |         if key != 'final':
130 |             state[domain][key] = goal[domain][key]
131 |     if 'final' in goal[domain]:
132 |         for key in goal[domain]['final']:
133 |             goal[domain][key] = goal[domain]['final'][key]
134 |             state[domain][key] = goal[domain][key]
135 |         del(goal[domain]['final'])
136 | 
137 | def init_logging_handler(log_dir, extra=''):
138 |     if not os.path.exists(log_dir):
139 |         os.makedirs(log_dir)
140 |     current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
141 | 
142 |     stderr_handler = logging.StreamHandler()
143 |     file_handler = logging.FileHandler('{}/log_{}.txt'.format(log_dir, current_time+extra))
144 |     logging.basicConfig(handlers=[stderr_handler, file_handler])
145 |     logger = logging.getLogger()
146 |     logger.setLevel(logging.DEBUG)
147 | 
148 | def to_device(data):
149 |     if type(data) == dict:
150 |         for k, v in data.items():
151 |             data[k] = v.to(device=DEVICE)
152 |     else:
153 |         for idx, item in enumerate(data):
154 |             data[idx] = item.to(device=DEVICE)
155 |     return data
156 | 
157 | def check_constraint(slot, val_usr, val_sys):
158 |     try:
159 |         if slot == 'arrive':
160 |             val1 = int(val_usr.split(':')[0]) * 100 + int(val_usr.split(':')[1])
161 |             val2 = int(val_sys.split(':')[0]) * 100 + int(val_sys.split(':')[1])
162 |             if val1 < val2:
163 |                 return True
164 |         elif slot == 'leave':
165 |             val1 = int(val_usr.split(':')[0]) * 100 + int(val_usr.split(':')[1])
166 |             val2 = int(val_sys.split(':')[0]) * 100 + int(val_sys.split(':')[1])
167 |             if val1 > val2:
168 |                 return True
169 |         else:
170 |             if val_usr != val_sys:
171 |                 return True
172 |         return False
173 |     except:
174 |         return False
175 | 
176 | def state_vectorize(state, config, db, noisy=False):
177 |     """
178 |     state: dict_keys(['user_action', 'sys_action', 'select_entity', 'belief_state', 'others'])
179 |     state_vec: [user_act, last_sys_act, inform, request, book, degree, final]
180 |     """
181 |     user_act = np.zeros(len(config.da_usr))
182 |     for da in state['user_action']:
183 |         user_act[config.da2idx_u[da]] = 1.
184 | 
185 |     last_sys_act = np.zeros(len(config.da))
186 |     for da in state['sys_action']:
187 |         last_sys_act[config.da2idx[da]] = 1.
188 | 
189 |     inform = np.zeros(len(config.inform_da))
190 |     request = np.zeros(len(config.request_da))
191 |     for domain in state['belief_state']:
192 |         for slot, value in state['belief_state'][domain].items():
193 |             key = domain+'-'+slot
194 |             if value == '?':
195 |                 if key in config.request2idx:
196 |                     request[config.request2idx[key]] = 1.
197 |             else:
198 |                 if key in config.inform2idx:
199 |                     inform[config.inform2idx[key]] = 1.
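    # The remaining features below are per-domain booked indicators, the DB
    # match degree from db.pointer(), and the terminal flag; np.r_ then
    # concatenates everything into one flat vector. A toy illustration of the
    # np.r_ assembly (made-up sizes, not the real config dimensions):
    #     np.r_[np.zeros(3), np.zeros(4), 1.]   # -> flat array of length 8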
200 | 201 | # select entity 202 | book = np.zeros(len(config.belief_domains)) 203 | for domain in state['belief_state']: 204 | if 'booked' in state['belief_state'][domain]: 205 | book[config.domain2idx[domain]] = 1. 206 | 207 | degree, entropy = db.pointer(state['belief_state'], config.mapping, config.db_domains, config.requestable, noisy) 208 | 209 | final = 1. if state['others']['terminal'] else 0. 210 | 211 | state_vec = np.r_[user_act, last_sys_act, inform, request, book, degree, final] 212 | assert len(state_vec) == config.s_dim 213 | return state_vec 214 | 215 | def action_vectorize(action, config): 216 | act_vec = np.zeros(config.a_dim) 217 | for da in action: 218 | act_vec[config.da2idx[da]] = 1 219 | return act_vec 220 | 221 | def state_vectorize_user(state, config, current_domain): 222 | """ 223 | state: dict_keys(['user_action', 'sys_action', 'user_goal', 'goal_state', 'others']) 224 | state_vec: [sys_act, last_user_act, inform, request, focus, inconsistency, nooffer] 225 | """ 226 | sys_act = np.zeros(len(config.da)) 227 | for da in state['sys_action']: 228 | sys_act[config.da2idx[da]] = 1. 229 | 230 | last_user_act = np.zeros(len(config.da_usr)) 231 | for da in state['user_action']: 232 | last_user_act[config.da2idx_u[da]] = 1. 233 | 234 | inform = np.zeros(len(config.inform_da_usr)) 235 | request = np.zeros(len(config.request_da_usr)) 236 | for domain in state['goal_state']: 237 | if domain in state['invisible_domains']: 238 | continue 239 | for slot, value in state['goal_state'][domain].items(): 240 | key = domain+'-'+slot 241 | if value == '?': 242 | if key in config.request2idx_u: 243 | request[config.request2idx_u[key]] = 1. 244 | else: 245 | if key in config.inform2idx_u and slot in state['user_goal'][domain]\ 246 | and state['user_goal'][domain][slot] != '?': 247 | inform[config.inform2idx_u[key]] = 1. 248 | 249 | focus = np.zeros(len(config.belief_domains)) 250 | if current_domain: 251 | focus[config.domain2idx[current_domain]] = 1. 252 | 253 | inconsistency = np.zeros(len(config.inform_da_usr)) 254 | nooffer = np.zeros(len(config.belief_domains)) 255 | for da, value in state['sys_action'].items(): 256 | domain, intent, slot, p = da.split('-') 257 | if intent in ['inform', 'recommend', 'offerbook', 'offerbooked']: 258 | key = domain+'-'+slot 259 | if key in config.inform2idx_u and slot in state['user_goal'][domain]: 260 | refer = state['user_goal'][domain][slot] 261 | if refer != '?' and check_constraint(slot, refer, value): 262 | inconsistency[config.inform2idx_u[key]] = 1. 263 | if intent in ['nooffer', 'nobook'] and current_domain: 264 | nooffer[config.domain2idx[current_domain]] = 1. 265 | 266 | state_vec = np.r_[sys_act, last_user_act, inform, request, inconsistency, nooffer] 267 | assert len(state_vec) == config.s_dim_usr 268 | return state_vec 269 | 270 | def action_vectorize_user(action, terminal, config): 271 | act_vec = np.zeros(config.a_dim_usr) 272 | for da in action: 273 | act_vec[config.da2idx_u[da]] = 1 274 | if terminal: 275 | act_vec[-1] = 1 276 | return act_vec 277 | 278 | def reparameterize(mu, logvar): 279 | std = (0.5*logvar).exp() 280 | eps = torch.randn_like(std) 281 | return eps.mul(std) + mu 282 | --------------------------------------------------------------------------------
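# A closing note on reparameterize() in utils.py (toy tensors, illustrative
# only): it is the standard Gaussian reparameterization trick,
#     sigma = exp(0.5 * logvar),  z = mu + sigma * eps,  eps ~ N(0, I),
# so sampled z stays differentiable w.r.t. mu and logvar:
#     mu = torch.zeros(4, requires_grad=True)
#     logvar = torch.zeros(4, requires_grad=True)
#     z = reparameterize(mu, logvar)   # z ~ N(0, I)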