├── .gitignore ├── LICENSE ├── README.md ├── deep_dialog ├── __init__.py ├── agents │ ├── __init__.py │ ├── agent.py │ ├── agent_act_rule.py │ ├── agent_e2eRL_allact.py │ ├── agent_lu_rl.py │ ├── agent_nl_rule_hard.py │ ├── agent_nl_rule_no.py │ ├── agent_nl_rule_soft.py │ ├── agent_rl.py │ ├── agent_simpleRL_allact.py │ ├── agent_simpleRL_allact_hardDB.py │ ├── agent_simpleRL_allact_noDB.py │ ├── belief_tracker.py │ ├── feature_extractor.py │ ├── hardDB.py │ ├── softDB.py │ └── utils.py ├── dialog_config.py ├── dialog_system │ ├── __init__.py │ ├── database.py │ ├── dialog_manager.py │ ├── dict_reader.py │ └── movie_dict.py ├── objects │ ├── __init__.py │ └── slot_reader.py ├── tools.py └── usersims │ ├── NLG │ ├── __init__.py │ ├── decoders │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── lstm_decoder_tanh.py │ │ └── utils.py │ ├── eval │ │ ├── bleu.py │ │ └── multi-bleu.perl │ ├── fileio │ │ ├── __init__.py │ │ └── data_set.py │ └── predict.py │ ├── __init__.py │ ├── s2s_nlg.py │ ├── template_nlg.py │ ├── user_cmd.py │ └── usersim_rule.py ├── interact.py ├── requirements.txt ├── settings ├── __init__.py ├── config_imdb-L.py ├── config_imdb-M.py ├── config_imdb-S.py └── config_imdb-XL.py ├── sim.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | data/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Bhuwan Dhingra 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | KB-InfoBot 2 | ================================================== 3 | 4 | This repository contains all the code and data accompanying the paper [Towards End-to-End Reinforcement Learning of Dialogue Agents for Information Access](https://arxiv.org/abs/1609.00777). 5 | 6 | Prerequisites 7 | -------------------------------------------------- 8 | See [requirements.txt](./requirements.txt) for required packages.
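For example, assuming `pip` and a Python 2.7 environment (the code base is Python 2), the packages can be installed from the repository root with:
```sh
pip install -r requirements.txt
```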
Also download nltk data: 9 | ```sh 10 | python -m nltk.downloader all 11 | ``` 12 | IMPORTANT: Download the data and pretrained models from [here](https://drive.google.com/file/d/0B7aCzQIaRTDUMDF1S3NVajlHTFk/view?usp=sharing), unpack the tar and place it at the root of the repository. 13 | 14 | Code Organization 15 | -------------------------------------------------- 16 | * All agents are in the [deep_dialog/agents/](./deep_dialog/agents/) directory 17 | * The user-simulator, along with a template-based and a seq2seq NLG, is in the [deep_dialog/usersims/](./deep_dialog/usersims/) directory 18 | * [deep_dialog/dialog_system/](./deep_dialog/dialog_system/) contains classes for the dialog manager and the database 19 | 20 | Interact with the pre-trained InfoBot! 21 | -------------------------------------------------- 22 | ```sh 23 | $ python interact.py 24 | ``` 25 | 26 | This will launch the command line tool running the RL-SoftKB infobot trained on the "Medium-KB" split. Instructions on how to interact with the system are displayed within the tool itself. You can also specify other agents to test: 27 | 28 | ```sh 29 | $ python interact.py --help 30 | usage: interact.py [-h] [--agent AGENT] 31 | 32 | optional arguments: 33 | -h, --help show this help message and exit 34 | --agent AGENT Agent to run -- (rule-no / rl-no / rule-hard / rl-hard / 35 | rule-soft / rl-soft / e2e-soft) 36 | ``` 37 | 38 | Training 39 | -------------------------------------------------- 40 | To train the RL agents, call `train.py` with the following options: 41 | ```sh 42 | $ python train.py --help 43 | usage: train.py [-h] [--agent AGENT_TYPE] [--db DB] [--model_name MODEL_NAME] 44 | [--N N] [--max_turn MAX_TURN] [--nlg_temp NLG_TEMP] 45 | [--max_first_turn MAX_FIRST_TURN] [--err_prob ERR_PROB] 46 | [--dontknow_prob DONTKNOW_PROB] [--sub_prob SUB_PROB] 47 | [--reload RELOAD] 48 | 49 | optional arguments: 50 | -h, --help show this help message and exit 51 | --agent AGENT_TYPE agent to use (rl-no / rl-hard / rl-soft / e2e-soft) 52 | --db DB imdb-(S/M/L/XL) -- This is the KB split to use, e.g.
53 | imdb-M 54 | --model_name MODEL_NAME 55 | model name to save 56 | --N N Number of simulations 57 | --max_turn MAX_TURN maximum length of each dialog (default=20, 0=no 58 | maximum length) 59 | --nlg_temp NLG_TEMP Natural Language Generator softmax temperature (to 60 | control noise) 61 | --max_first_turn MAX_FIRST_TURN 62 | Maximum number of slots informed by user in first turn 63 | --err_prob ERR_PROB the probability of the user simulator corrupting a 64 | slot value 65 | --dontknow_prob DONTKNOW_PROB 66 | the probability that user simulator does not know a 67 | slot value 68 | --sub_prob SUB_PROB the probability that user simulator substitutes a slot 69 | value 70 | --reload RELOAD Reload previously saved model (0-no, 1-yes) 71 | ``` 72 | Example: 73 | ```sh 74 | python train.py --agent e2e-soft --db imdb-M --model_name e2e_soft_example.m 75 | ``` 76 | 77 | Testing 78 | ---------------------------------------------------- 79 | To evaluate both RL and Rule agents, call `sim.py` with the following options: 80 | ```sh 81 | $ python sim.py --help 82 | usage: sim.py [-h] [--agent AGENT_TYPE] [--N N] [--db DB] 83 | [--max_turn MAX_TURN] [--err_prob ERR_PROB] 84 | [--dontknow_prob DONTKNOW_PROB] [--sub_prob SUB_PROB] 85 | [--nlg_temp NLG_TEMP] [--max_first_turn MAX_FIRST_TURN] 86 | [--model_name MODEL_NAME] 87 | 88 | optional arguments: 89 | -h, --help show this help message and exit 90 | --agent AGENT_TYPE agent to use (rule-no / rl-no / rule-hard / rl-hard / 91 | rule-soft / rl-soft / e2e-soft) 92 | --N N Number of simulations 93 | --db DB imdb-(S/M/L/XL) -- This is the KB split to use, e.g. 94 | imdb-M 95 | --max_turn MAX_TURN maximum length of each dialog (default=20, 0=no 96 | maximum length) 97 | --err_prob ERR_PROB the probability of the user simulator corrupting a 98 | slot value 99 | --dontknow_prob DONTKNOW_PROB 100 | the probability that user simulator does not know a 101 | slot value 102 | --sub_prob SUB_PROB the probability that user simulator substitutes a slot 103 | value 104 | --nlg_temp NLG_TEMP Natural Language Generator softmax temperature (to 105 | control noise) 106 | --max_first_turn MAX_FIRST_TURN 107 | Maximum number of slots informed by user in first turn 108 | --model_name MODEL_NAME 109 | model name to evaluate (This should be the same as 110 | what you gave for training). Pass "pretrained" to use 111 | pretrained models. 112 | ``` 113 | Run without the `--model_name` argument to test on pre-trained models. Example: 114 | ```sh 115 | python sim.py --agent rl-soft --db imdb-M 116 | ``` 117 | 118 | Hyperparameters 119 | ------------------------------------------------- 120 | The default hyperparameters for each KB split are in `settings/config_imdb-(S/M/L/XL).py`. These include: 121 | 1. RL agent options- 122 | * `nhid`: Number of hidden units 123 | * `batch`: Batch size 124 | * `ment`: Entropy regularization parameter 125 | * `lr`: Learning rate for initial supervised learning of policy. RL learning rate is fixed to 0.005. 126 | * `featN`: Only for end-to-end RL agent, *n* for n-gram feature extraction 127 | * `pol_start`: Number of supervised learning updates before switching to RL 128 | * `input`: Input type to the policy network - full/entropy 129 | * `sl`: Only for end-to-end RL agent, type of supervised learning (bel - belief tracker only, pol - policy only, e2e (default) - both) 130 | * `rl`: Only for end-to-end RL agent, type of reinforcement learning (bel - belief tracker only, pol - policy only, e2e (default) - both) 131 | 2.
Rule agent options- 132 | * `tr`: Threshold for database entropy to inform 133 | * `ts`: Threshold for slot entropy to request 134 | * `max_req`: Maximum requests allowed per slot 135 | * `frac`: Ratio to initial slot entropy, below which if the slot entropy falls it is not requested anymore 136 | * `upd`: Update count for Bayesian belief tracking 137 | 138 | ## Note 139 | Make sure to add `THEANO_FLAGS=device=cpu,floatX=float32` before any command if you are running on a CPU. 140 | 141 | ## Contributors 142 | If you use this code, please cite the following: 143 | 144 | Dhingra, B., Li, L., Li, X., Gao, J., Chen, Y. N., Ahmed, F., & Deng, L. (2017). Towards End-to-end reinforcement learning of dialogue agents for information access. ACL. 145 | ``` 146 | @inproceedings{dhingra2017towards, 147 | title={Towards End-to-end reinforcement learning of dialogue agents for information access}, 148 | author={Dhingra, Bhuwan and Li, Lihong and Li, Xiujun and Gao, Jianfeng and Chen, Yun-Nung and Ahmed, Faisal and Deng, Li}, 149 | booktitle={Proceedings of ACL}, 150 | year={2017} 151 | } 152 | ``` 153 | 154 | Report bugs and missing info to bdhingraATandrewDOTcmuDOTedu (replace AT, DOT appropriately). 155 | -------------------------------------------------------------------------------- /deep_dialog/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiuLab/KB-InfoBot/f472695fa083020825f799919c90a37235a5bb28/deep_dialog/__init__.py -------------------------------------------------------------------------------- /deep_dialog/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_act_rule import * 2 | from .agent_nl_rule_soft import * 3 | from .agent_nl_rule_hard import * 4 | from .agent_nl_rule_no import * 5 | from .agent_simpleRL_allact import * 6 | from .agent_simpleRL_allact_noDB import * 7 | from .agent_simpleRL_allact_hardDB import * 8 | from .agent_e2eRL_allact import * 9 | -------------------------------------------------------------------------------- /deep_dialog/agents/agent.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | from deep_dialog import tools 5 | from feature_extractor import FeatureExtractor 6 | 7 | class Agent: 8 | def __init__(self, movie_dict=None, act_set=None, slot_set=None, db=None, corpus=None, \ 9 | tr=None, ts=None, frac=None, max_req=None, upd=None): 10 | self.movie_dict = movie_dict 11 | self.act_set = act_set 12 | self.slot_set = slot_set 13 | self.database = db 14 | self.tr = tr 15 | self.ts = ts 16 | self.frac = frac 17 | self.max_req = max_req 18 | self.upd = upd 19 | 20 | def init(self): 21 | self.state = {} 22 | 23 | def next(self, usr_action): 24 | pass 25 | -------------------------------------------------------------------------------- /deep_dialog/agents/agent_act_rule.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | from deep_dialog import dialog_config, tools 5 | from agent import Agent 6 | 7 | import operator 8 | import random 9 | import numpy as np 10 | class AgentActRule(Agent): 11 | 12 | def initialize_episode(self): 13 | self.state = {} 14 | self.state['diaact'] = 'UNK' 15 | self.state['inform_slots'] = {} 16 | self.state['turn'] = 0 17 | 18 | ''' update agent state ''' 19 | def _update_state(self, user_action): 20 | for s in user_action['inform_slots'].keys(): 21 | self.state['inform_slots'][s] =
user_action['inform_slots'][s] 22 | 23 | ''' get next action based on rules ''' 24 | def next(self, user_action, verbose=False): 25 | self._update_state(user_action) 26 | 27 | act = {} 28 | act['diaact'] = 'UNK' 29 | act['request_slots'] = {} 30 | act['target'] = [] 31 | 32 | db_status, db_index = self._check_db() 33 | 34 | if not db_status: 35 | # no match, some error, re-ask some slot 36 | act['diaact'] = 'request' 37 | request_slot = random.choice(self.state['inform_slots'].keys()) 38 | act['request_slots'][request_slot] = 'UNK' 39 | 40 | elif len(self.state['inform_slots']) == len(dialog_config.sys_request_slots) \ 41 | or len(db_status)==1: 42 | act['diaact'] = 'inform' 43 | act['target'] = self._inform(db_index) 44 | 45 | else: 46 | # request a slot not known with max entropy 47 | known_slots = self.state['inform_slots'].keys() 48 | unknown_slots = [s for s in dialog_config.sys_request_slots if s not in known_slots] 49 | slot_entropy = {} 50 | for s in unknown_slots: 51 | db_idx = self.database.slots.index(s) 52 | db_matches = [m[db_idx] for m in db_status] 53 | slot_entropy[s] = tools.entropy(db_matches) 54 | request_slot, max_ent = max(slot_entropy.iteritems(), key=operator.itemgetter(1)) 55 | if max_ent > 0.: 56 | act['diaact'] = 'request' 57 | act['request_slots'][request_slot] = 'UNK' 58 | else: 59 | act['diaact'] = 'inform' 60 | act['target'] = self._inform(db_index) 61 | 62 | act['posterior'] = np.zeros((len(self.database.labels),)) 63 | act['posterior'][db_index] = 1./len(db_index) 64 | 65 | return act 66 | 67 | def terminate_episode(self, user_action): 68 | return 69 | 70 | def _inform(self, db_index): 71 | target = db_index 72 | if len(target) > 1: random.shuffle(target) 73 | full_range = range(self.database.N) 74 | random.shuffle(full_range) 75 | for i in full_range: 76 | if i not in db_index: target.append(i) 77 | return target 78 | 79 | ''' query DB based on current known slots ''' 80 | def _check_db(self): 81 | # from query to db form current inform_slots 82 | db_query = [] 83 | for s in self.database.slots: 84 | if s in self.state['inform_slots']: 85 | db_query.append(self.state['inform_slots'][s]) 86 | else: 87 | db_query.append(None) 88 | matches, index = self.database.lookup(db_query) 89 | return matches, index 90 | 91 | ''' sample value from current state of database ''' 92 | def _sample_slot(self, slot, matches): 93 | if not matches: 94 | return None 95 | index = self.database.slots.index(slot) 96 | return random.choice([m[index] for m in matches]) 97 | -------------------------------------------------------------------------------- /deep_dialog/agents/agent_e2eRL_allact.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | import numpy as np 5 | import cPickle as pkl 6 | 7 | from deep_dialog import dialog_config, tools 8 | from collections import Counter, defaultdict, deque 9 | from agent_lu_rl import E2ERLAgent, aggregate_rewards 10 | from belief_tracker import BeliefTracker 11 | from softDB import SoftDB 12 | from feature_extractor import FeatureExtractor 13 | from utils import * 14 | 15 | import operator 16 | import random 17 | import math 18 | import copy 19 | import re 20 | import nltk 21 | import time 22 | 23 | # params 24 | DISPF = 1 25 | SAVEF = 100 26 | ANNEAL = 800 27 | 28 | class AgentE2ERLAllAct(E2ERLAgent,SoftDB,BeliefTracker): 29 | def __init__(self, movie_dict=None, act_set=None, slot_set=None, db=None, corpus=None, 30 | train=True, _reload=False, n_hid=100, batch=128, ment=0., 
inputtype='full', upd=10, 31 | sl='e2e', rl='e2e', pol_start=600, lr=0.005, N=1, tr=2.0, ts=0.5, max_req=2, frac=0.5, 32 | name=None): 33 | self.movie_dict = movie_dict 34 | self.act_set = act_set 35 | self.slot_set = slot_set 36 | self.database = db 37 | self.max_turn = dialog_config.MAX_TURN 38 | self.training = train 39 | self.feat_extractor = FeatureExtractor(corpus,self.database.path,N=N) 40 | out_size = len(dialog_config.inform_slots)+1 41 | in_size = len(self.feat_extractor.grams) + len(dialog_config.inform_slots) 42 | slot_sizes = [self.movie_dict.lengths[s] for s in dialog_config.inform_slots] 43 | self._init_model(in_size, out_size, slot_sizes, self.database, \ 44 | n_hid=n_hid, learning_rate_sl=lr, batch_size=batch, ment=ment, inputtype=inputtype, \ 45 | sl=sl, rl=rl) 46 | self._name = name 47 | if _reload: self.load_model(dialog_config.MODEL_PATH+self._name) 48 | if train: self.save_model(dialog_config.MODEL_PATH+self._name) 49 | self._init_experience_pool(batch) 50 | self.episode_count = 0 51 | self.recent_rewards = deque([], 1000) 52 | self.recent_successes = deque([], 1000) 53 | self.recent_turns = deque([], 1000) 54 | self.recent_loss = deque([], 10) 55 | self.discount = 0.99 56 | self.num_updates = 0 57 | self.pol_start = pol_start 58 | self.tr = tr 59 | self.ts = ts 60 | self.max_req = max_req 61 | self.frac = frac 62 | self.upd = upd 63 | 64 | def _print_progress(self,loss,te,*args): 65 | self.recent_loss.append(loss) 66 | avg_ret = float(sum(self.recent_rewards))/len(self.recent_rewards) 67 | avg_turn = float(sum(self.recent_turns))/len(self.recent_turns) 68 | avg_loss = float(sum(self.recent_loss))/len(self.recent_loss) 69 | n_suc, n_fail, n_inc, tot = 0, 0, 0, 0 70 | for s in self.recent_successes: 71 | if s==-1: n_fail += 1 72 | elif s==0: n_inc += 1 73 | else: n_suc += 1 74 | tot += 1 75 | if len(args)>0: 76 | print 'Update %d. Avg turns = %.2f . Avg Reward = %.2f . Success Rate = %.2f . Fail Rate = %.2f . Incomplete Rate = %.2f . Loss = %.3f . Time = %.2f' % \ 77 | (self.num_updates, avg_turn, avg_ret, \ 78 | float(n_suc)/tot, float(n_fail)/tot, float(n_inc)/tot, avg_loss, te) 79 | #print 'kl loss = {}'.format(args[0]) 80 | #print 'x_loss = {}'.format(args[1]) 81 | else: 82 | print 'Update %d. Avg turns = %.2f . Avg Reward = %.2f . Success Rate = %.2f . Fail Rate = %.2f . Incomplete Rate = %.2f . Loss = %.3f . 
Time = %.2f' % \ 83 | (self.num_updates, avg_turn, avg_ret, \ 84 | float(n_suc)/tot, float(n_fail)/tot, float(n_inc)/tot, avg_loss, te) 85 | 86 | def initialize_episode(self): 87 | self.episode_count += 1 88 | if self.training and self.episode_count%self.batch_size==0: 89 | self.num_updates += 1 90 | if self.num_updates>self.pol_start and self.num_updates%ANNEAL==0: self.anneal_lr() 91 | tst = time.time() 92 | if self.num_updates < self.pol_start: 93 | all_loss = self.update(regime='SL') 94 | loss = all_loss[0] 95 | kl_loss = all_loss[1:len(dialog_config.inform_slots)+1] 96 | x_loss = all_loss[len(dialog_config.inform_slots)+1:] 97 | t_elap = time.time() - tst 98 | if self.num_updates%DISPF==0: self._print_progress(loss, t_elap, kl_loss, x_loss) 99 | else: 100 | loss = self.update(regime='RL') 101 | t_elap = time.time() - tst 102 | if self.num_updates%DISPF==0: self._print_progress(loss, t_elap) 103 | if self.num_updates%SAVEF==0: self.save_model(dialog_config.MODEL_PATH+self._name) 104 | 105 | self.state = {} 106 | self.state['database'] = pkl.loads(pkl.dumps(self.database,-1)) 107 | self.state['prevact'] = 'begin@begin' 108 | self.state['inform_slots'] = self._init_beliefs() 109 | self.state['turn'] = 0 110 | self.state['num_requests'] = {s:0 for s in self.state['database'].slots} 111 | self.state['slot_tracker'] = set() 112 | self.state['dont_care'] = set() 113 | p_db_i = (1./self.state['database'].N)*np.ones((self.state['database'].N,)) 114 | self.state['init_entropy'] = calc_entropies(self.state['inform_slots'], p_db_i, 115 | self.state['database']) 116 | self.state['inputs'] = [] 117 | self.state['actions'] = [] 118 | self.state['rewards'] = [] 119 | self.state['indices'] = [] 120 | self.state['ptargets'] = [] 121 | self.state['phitargets'] = [] 122 | self.state['hid_state'] = [np.zeros((1,self.r_hid)).astype('float32') \ 123 | for s in dialog_config.inform_slots] 124 | self.state['pol_state'] = np.zeros((1,self.n_hid)).astype('float32') 125 | 126 | ''' get next action based on rules ''' 127 | def next(self, user_action, verbose=False): 128 | self.state['turn'] += 1 129 | 130 | p_vector = np.zeros((self.in_size,)).astype('float32') 131 | p_vector[:self.feat_extractor.n] = self.feat_extractor.featurize( \ 132 | user_action['nl_sentence']) 133 | if self.state['turn']>1: 134 | pr_act = self.state['prevact'].split('@') 135 | assert pr_act[0]!='inform', 'Agent called after informing!' 136 | act_id = dialog_config.inform_slots.index(pr_act[1]) 137 | p_vector[self.feat_extractor.n+act_id] = 1 138 | p_vector = np.expand_dims(np.expand_dims(p_vector, axis=0), axis=0) 139 | p_vector = standardize(p_vector) 140 | 141 | p_targets = [] 142 | phi_targets = [] 143 | if self.training and self.num_updates= self.max_req: 232 | continue 233 | act['diaact'] = 'request' 234 | act['request_slots'][s] = 'UNK' 235 | action = dialog_config.inform_slots.index(s) 236 | req = True 237 | break 238 | if not req: 239 | # agent confident about all slots, inform 240 | act['diaact'] = 'inform' 241 | act['target'] = self._inform(db_probs) 242 | action = len(dialog_config.inform_slots) 243 | return act, action 244 | 245 | def terminate_episode(self, user_action): 246 | assert self.state['turn'] <= self.max_turn, "More turn than MAX_TURN!!" 
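# append the user's terminal reward to the stored per-turn rewards and aggregate them into a single discounted return (discount factor self.discount = 0.99)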
247 | total_reward = aggregate_rewards(self.state['rewards']+[user_action['reward']],self.discount) 248 | 249 | if self.state['turn']==self.max_turn: 250 | db_index = np.arange(dialog_config.SUCCESS_MAX_RANK).astype('int32') 251 | db_switch = 0 252 | else: 253 | db_index = self.state['indices'] 254 | db_switch = 1 255 | 256 | inp = np.zeros((self.max_turn,self.in_size)).astype('float32') 257 | actmask = np.zeros((self.max_turn,self.out_size)).astype('int8') 258 | turnmask = np.zeros((self.max_turn,)).astype('int8') 259 | p_targets = [np.zeros((self.max_turn,self.slot_sizes[i])).astype('float32') \ 260 | for i in range(len(dialog_config.inform_slots))] 261 | phi_targets = [np.zeros((self.max_turn,)).astype('float32') \ 262 | for i in range(len(dialog_config.inform_slots))] 263 | for t in xrange(0,self.state['turn']): 264 | actmask[t,self.state['actions'][t]] = 1 265 | inp[t,:] = self.state['inputs'][t] 266 | turnmask[t] = 1 267 | if self.training and self.num_updates0: self.recent_successes.append(1) 278 | else: self.recent_successes.append(-1) 279 | 280 | -------------------------------------------------------------------------------- /deep_dialog/agents/agent_nl_rule_hard.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | from deep_dialog import dialog_config, tools 5 | from agent import Agent 6 | from hardDB import HardDB 7 | from belief_tracker import BeliefTracker 8 | from utils import * 9 | 10 | from collections import Counter, defaultdict 11 | 12 | import operator 13 | import random 14 | import math 15 | import numpy as np 16 | import cPickle as pkl 17 | import copy 18 | import re 19 | import nltk 20 | 21 | class AgentNLRuleHard(Agent,HardDB,BeliefTracker): 22 | 23 | def initialize_episode(self): 24 | self.state = {} 25 | self.state['database'] = pkl.loads(pkl.dumps(self.database,-1)) 26 | for slot in self.state['database'].slots: 27 | if slot not in dialog_config.inform_slots: self.state['database'].delete_slot(slot) 28 | self.state['prevact'] = 'begin@begin' 29 | self.state['inform_slots'] = self._init_beliefs() 30 | self.state['turn'] = 0 31 | p_db_i = np.ones((self.state['database'].N,))/self.state['database'].N 32 | self.state['init_entropy'] = calc_entropies(self.state['inform_slots'], p_db_i, 33 | self.state['database']) 34 | self.state['num_requests'] = {s:0 for s in self.state['inform_slots'].keys()} 35 | self.state['slot_tracker'] = set() 36 | self.state['dont_care'] = set() 37 | 38 | ''' get next action based on rules ''' 39 | def next(self, user_action, verbose=False): 40 | self._update_state(user_action['nl_sentence'], upd=self.upd, verbose=verbose) 41 | self.state['turn'] += 1 42 | 43 | act = {} 44 | act['diaact'] = 'UNK' 45 | act['request_slots'] = {} 46 | act['target'] = [] 47 | 48 | db_status, db_index = self._check_db() 49 | H_slots = {} 50 | for s in dialog_config.inform_slots: 51 | s_p = self.state['inform_slots'][s]/self.state['inform_slots'][s].sum() 52 | H_slots[s] = tools.entropy_p(s_p) 53 | sorted_entropies = sorted(H_slots.items(), key=operator.itemgetter(1), reverse=True) 54 | if verbose: 55 | print 'Agent slot belief entropies - ' 56 | print ' '.join(['%s:%.2f' %(k,v) for k,v in H_slots.iteritems()]) 57 | 58 | if not db_status: 59 | # no match, some error, re-ask some slot 60 | act['diaact'] = 'request' 61 | request_slot = random.choice(self.state['inform_slots'].keys()) 62 | act['request_slots'][request_slot] = 'UNK' 63 | self.state['prevact'] = 'request@%s' %request_slot 64 | 
self.state['num_requests'][request_slot] += 1 65 | elif len(db_status)==1: 66 | act['diaact'] = 'inform' 67 | act['target'] = self._inform(db_index) 68 | self.state['prevact'] = 'inform@inform' 69 | else: 70 | req = False 71 | for (s,h) in sorted_entropies: 72 | if H_slots[s]= self.max_req: 74 | continue 75 | act['diaact'] = 'request' 76 | act['request_slots'][s] = 'UNK' 77 | self.state['prevact'] = 'request@%s' %s 78 | self.state['num_requests'][s] += 1 79 | req = True 80 | break 81 | if not req: 82 | # agent confident about all slots, inform 83 | act['diaact'] = 'inform' 84 | act['target'] = self._inform(db_index) 85 | self.state['prevact'] = 'inform@inform' 86 | 87 | act['posterior'] = np.zeros((len(self.database.labels),)) 88 | act['posterior'][db_index] = 1./len(db_index) 89 | 90 | return act 91 | 92 | def terminate_episode(self, user_action): 93 | return 94 | -------------------------------------------------------------------------------- /deep_dialog/agents/agent_nl_rule_no.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | from deep_dialog import dialog_config, tools 5 | from agent import Agent 6 | from softDB import SoftDB 7 | from belief_tracker import BeliefTracker 8 | from utils import * 9 | 10 | from collections import Counter, defaultdict 11 | 12 | import operator 13 | import random 14 | import math 15 | import numpy as np 16 | import copy 17 | import re 18 | import nltk 19 | import cPickle as pkl 20 | 21 | class AgentNLRuleNoDB(Agent,SoftDB,BeliefTracker): 22 | 23 | def initialize_episode(self): 24 | self.state = {} 25 | self.state['database'] = pkl.loads(pkl.dumps(self.database,-1)) 26 | self.state['prevact'] = 'begin@begin' 27 | self.state['inform_slots'] = self._init_beliefs() 28 | self.state['turn'] = 0 29 | self.state['init_entropy'] = {} 30 | for s in dialog_config.inform_slots: 31 | s_p = self.state['inform_slots'][s]/self.state['inform_slots'][s].sum() 32 | self.state['init_entropy'][s] = tools.entropy_p(s_p) 33 | self.state['num_requests'] = {s:0 for s in self.state['inform_slots'].keys()} 34 | self.state['slot_tracker'] = set() 35 | self.state['dont_care'] = set() 36 | 37 | ''' get next action based on rules ''' 38 | def next(self, user_action, verbose=False): 39 | self._update_state(user_action['nl_sentence'], upd=self.upd, verbose=verbose) 40 | self.state['turn'] += 1 41 | 42 | act = {} 43 | act['diaact'] = 'UNK' 44 | act['request_slots'] = {} 45 | act['target'] = [] 46 | 47 | db_probs = self._check_db() 48 | H_slots = {} 49 | for s in dialog_config.inform_slots: 50 | s_p = self.state['inform_slots'][s]/self.state['inform_slots'][s].sum() 51 | H_slots[s] = tools.entropy_p(s_p) 52 | if verbose: 53 | print 'Agent slot belief entropies - ' 54 | print ' '.join(['%s:%.2f' %(k,v) for k,v in H_slots.iteritems()]) 55 | 56 | sorted_entropies = sorted(H_slots.items(), key=operator.itemgetter(1), reverse=True) 57 | req = False 58 | for (s,h) in sorted_entropies: 59 | if H_slots[s]= self.max_req: 61 | continue 62 | act['diaact'] = 'request' 63 | act['request_slots'][s] = 'UNK' 64 | self.state['prevact'] = 'request@%s' %s 65 | self.state['num_requests'][s] += 1 66 | req = True 67 | break 68 | if not req: 69 | # agent confident about all slots, inform 70 | act['diaact'] = 'inform' 71 | act['target'] = self._inform(db_probs) 72 | self.state['prevact'] = 'inform@inform' 73 | 74 | act['probs'] = [np.concatenate([self.state['inform_slots'][s]/ \ 75 | self.state['inform_slots'][s].sum(), \ 76 | 
np.asarray([float(self.state['database'].inv_counts[s][-1])/ \ 77 | self.state['database'].N])]) \ 78 | for s in dialog_config.inform_slots] 79 | act['phis'] = [1. if s in self.state['dont_care'] else 0. for s in dialog_config.inform_slots] 80 | act['posterior'] = db_probs 81 | return act 82 | 83 | def terminate_episode(self, user_action): 84 | return 85 | -------------------------------------------------------------------------------- /deep_dialog/agents/agent_nl_rule_soft.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | from deep_dialog import dialog_config, tools 5 | from agent import Agent 6 | from softDB import SoftDB 7 | from belief_tracker import BeliefTracker 8 | from utils import * 9 | 10 | from collections import Counter, defaultdict 11 | 12 | import operator 13 | import random 14 | import math 15 | import numpy as np 16 | import copy 17 | import re 18 | import nltk 19 | import cPickle as pkl 20 | 21 | class AgentNLRuleSoft(Agent,SoftDB,BeliefTracker): 22 | 23 | def initialize_episode(self): 24 | self.state = {} 25 | self.state['database'] = pkl.loads(pkl.dumps(self.database,-1)) 26 | for slot in self.state['database'].slots: 27 | if slot not in dialog_config.inform_slots: self.state['database'].delete_slot(slot) 28 | self.state['prevact'] = 'begin@begin' 29 | self.state['inform_slots'] = self._init_beliefs() 30 | self.state['turn'] = 0 31 | p_db_i = (1./self.state['database'].N)*np.ones((self.state['database'].N,)) 32 | self.state['init_entropy'] = calc_entropies(self.state['inform_slots'], p_db_i, 33 | self.state['database']) 34 | self.state['num_requests'] = {s:0 for s in self.state['inform_slots'].keys()} 35 | self.state['slot_tracker'] = set() 36 | self.state['dont_care'] = set() 37 | 38 | ''' get next action based on rules ''' 39 | def next(self, user_action, verbose=False): 40 | self._update_state(user_action['nl_sentence'], upd=self.upd, verbose=verbose) 41 | self.state['turn'] += 1 42 | 43 | act = {} 44 | act['diaact'] = 'UNK' 45 | act['request_slots'] = {} 46 | act['target'] = [] 47 | 48 | db_probs = self._check_db() 49 | H_db = tools.entropy_p(db_probs) 50 | H_slots = calc_entropies(self.state['inform_slots'], db_probs, self.state['database']) 51 | if verbose: 52 | print 'Agent DB entropy = ', H_db 53 | print 'Agent slot belief entropies - ' 54 | print ' '.join(['%s:%.2f' %(k,v) for k,v in H_slots.iteritems()]) 55 | 56 | if H_db < self.tr: 57 | # agent reasonable confident, inform 58 | act['diaact'] = 'inform' 59 | act['target'] = self._inform(db_probs) 60 | else: 61 | sorted_entropies = sorted(H_slots.items(), key=operator.itemgetter(1), reverse=True) 62 | req = False 63 | for (s,h) in sorted_entropies: 64 | if H_slots[s]= self.max_req: 66 | continue 67 | act['diaact'] = 'request' 68 | act['request_slots'][s] = 'UNK' 69 | self.state['prevact'] = 'request@%s' %s 70 | self.state['num_requests'][s] += 1 71 | req = True 72 | break 73 | if not req: 74 | # agent confident about all slots, inform 75 | act['diaact'] = 'inform' 76 | act['target'] = self._inform(db_probs) 77 | self.state['prevact'] = 'inform@inform' 78 | 79 | act['probs'] = [np.concatenate([self.state['inform_slots'][s]/self.state['inform_slots'][s].sum(), \ 80 | np.asarray([float(self.state['database'].inv_counts[s][-1])/self.state['database'].N])]) \ 81 | for s in dialog_config.inform_slots] 82 | act['phis'] = [1. if s in self.state['dont_care'] else 0. 
for s in dialog_config.inform_slots] 83 | act['posterior'] = db_probs 84 | return act 85 | 86 | def terminate_episode(self, user_action): 87 | return 88 | -------------------------------------------------------------------------------- /deep_dialog/agents/agent_rl.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | import lasagne 5 | import theano 6 | import lasagne.layers as L 7 | import theano.tensor as T 8 | import numpy as np 9 | import sys 10 | 11 | from collections import Counter, defaultdict, deque 12 | 13 | import random 14 | import cPickle as pkl 15 | 16 | def categorical_sample(probs, mode='sample'): 17 | if mode=='max': 18 | return np.argmax(probs) 19 | else: 20 | x = np.random.uniform() 21 | s = probs[0] 22 | i = 0 23 | while sself.pol_start and self.num_updates%ANNEAL==0: self.anneal_lr() 95 | if self.num_updates < self.pol_start: loss = self.update(regime='SL') 96 | else: loss = self.update(regime='RL') 97 | if self.num_updates%DISPF==0: self._print_progress(loss) 98 | if self.num_updates%SAVEF==0: self.save_model(dialog_config.MODEL_PATH+self._name) 99 | 100 | self.state = {} 101 | self.state['database'] = pkl.loads(pkl.dumps(self.database,-1)) 102 | self.state['prevact'] = 'begin@begin' 103 | self.state['inform_slots'] = self._init_beliefs() 104 | self.state['turn'] = 0 105 | self.state['num_requests'] = {s:0 for s in self.state['database'].slots} 106 | self.state['slot_tracker'] = set() 107 | self.state['dont_care'] = set() 108 | p_db_i = (1./self.state['database'].N)*np.ones((self.state['database'].N,)) 109 | self.state['init_entropy'] = calc_entropies(self.state['inform_slots'], p_db_i, 110 | self.state['database']) 111 | self.state['inputs'] = [] 112 | self.state['actions'] = [] 113 | self.state['rewards'] = [] 114 | self.state['pol_state'] = np.zeros((1,self.n_hid)).astype('float32') 115 | 116 | ''' get next action based on rules ''' 117 | def next(self, user_action, verbose=False): 118 | self._update_state(user_action['nl_sentence'], upd=self.upd, verbose=verbose) 119 | self.state['turn'] += 1 120 | 121 | db_probs = self._check_db() 122 | H_db = tools.entropy_p(db_probs) 123 | H_slots = calc_entropies(self.state['inform_slots'], db_probs, self.state['database']) 124 | p_vector = np.zeros((self.in_size,)).astype('float32') 125 | if self.inputtype=='entropy': 126 | for i,s in enumerate(dialog_config.inform_slots): 127 | if s in H_slots: p_vector[i] = H_slots[s] 128 | p_vector[i+len(dialog_config.inform_slots)] = 1. if s in self.state['dont_care'] \ 129 | else 0. 130 | if self.state['turn']>1: 131 | pr_act = self.state['prevact'].split('@') 132 | act_id = dialog_config.inform_slots.index(pr_act[1]) 133 | p_vector[2*len(dialog_config.inform_slots)+act_id] = 1. 134 | p_vector[-1] = H_db 135 | else: 136 | p_slots = self._dict2vec(self.state['inform_slots']) 137 | p_vector[:p_slots.shape[0]] = p_slots 138 | if self.state['turn']>1: 139 | pr_act = self.state['prevact'].split('@') 140 | act_id = dialog_config.inform_slots.index(pr_act[1]) 141 | p_vector[p_slots.shape[0]+act_id] = 1. 
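# the final database.N entries of the policy input hold the current soft posterior over DB rows (db_probs returned by _check_db above)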
142 | p_vector[-self.database.N:] = db_probs 143 | p_vector = np.expand_dims(np.expand_dims(p_vector, axis=0), axis=0) 144 | p_vector = standardize(p_vector) 145 | 146 | if self.training and self.num_updates= self.max_req: 204 | continue 205 | act['diaact'] = 'request' 206 | act['request_slots'][s] = 'UNK' 207 | self.state['prevact'] = 'request@%s' %s 208 | self.state['num_requests'][s] += 1 209 | action = dialog_config.inform_slots.index(s) 210 | req = True 211 | break 212 | if not req: 213 | # agent confident about all slots, inform 214 | act['diaact'] = 'inform' 215 | act['target'] = self._inform(db_probs) 216 | self.state['prevact'] = 'inform@inform' 217 | action = len(dialog_config.inform_slots) 218 | return act, action 219 | 220 | def terminate_episode(self, user_action): 221 | assert self.state['turn'] <= self.max_turn, "More turn than MAX_TURN!!" 222 | total_reward = aggregate_rewards(self.state['rewards']+[user_action['reward']],self.discount) 223 | inp = np.zeros((self.max_turn,self.in_size)).astype('float32') 224 | actmask = np.zeros((self.max_turn,self.out_size)).astype('int32') 225 | turnmask = np.zeros((self.max_turn,)).astype('int32') 226 | for t in xrange(0,self.state['turn']): 227 | actmask[t,self.state['actions'][t]] = 1 228 | inp[t,:] = self.state['inputs'][t] 229 | turnmask[t] = 1 230 | self.add_to_pool(inp, turnmask, actmask, total_reward) 231 | self.recent_rewards.append(total_reward) 232 | self.recent_turns.append(self.state['turn']) 233 | if self.state['turn'] == self.max_turn: self.recent_successes.append(0) 234 | elif user_action['reward']>0: self.recent_successes.append(1) 235 | else: self.recent_successes.append(-1) 236 | 237 | -------------------------------------------------------------------------------- /deep_dialog/agents/agent_simpleRL_allact_hardDB.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | import numpy as np 5 | import cPickle as pkl 6 | 7 | from deep_dialog import dialog_config, tools 8 | from collections import Counter, defaultdict, deque 9 | from agent_rl import RLAgent, aggregate_rewards 10 | from belief_tracker import BeliefTracker 11 | from hardDB import HardDB 12 | from utils import * 13 | 14 | import operator 15 | import random 16 | import math 17 | import copy 18 | import re 19 | import nltk 20 | 21 | # params 22 | DISPF = 1 23 | SAVEF = 100 24 | ANNEAL = 800 25 | 26 | class AgentSimpleRLAllActHardDB(RLAgent,HardDB,BeliefTracker): 27 | def __init__(self, movie_dict=None, act_set=None, slot_set=None, db=None, 28 | train=True, _reload=False, n_hid=100, batch=128, ment=0., 29 | inputtype='full', pol_start=0, upd=10, tr=2.0, ts=0.5, 30 | max_req=2, frac=0.5, lr=0.005, name=None): 31 | self.movie_dict = movie_dict 32 | self.act_set = act_set 33 | self.slot_set = slot_set 34 | self.database = db 35 | self.max_turn = dialog_config.MAX_TURN 36 | self.training = train 37 | self.inputtype = inputtype 38 | self.pol_start = pol_start 39 | self.upd = upd 40 | if inputtype=='entropy': 41 | #in_size = 3*len(dialog_config.inform_slots)+1 42 | in_size = 3*len(dialog_config.inform_slots)+6 # 6 bins for number of retrieved results 43 | else: 44 | in_size = sum([len(self.movie_dict.dict[s])+2 for s in dialog_config.inform_slots]) + \ 45 | self.database.N 46 | out_size = len(dialog_config.inform_slots)+1 47 | self._init_model(in_size, out_size, n_hid=n_hid, learning_rate_sl=lr, batch_size=batch, \ 48 | ment=ment) 49 | self._name = name 50 | if _reload: 
self.load_model(dialog_config.MODEL_PATH+self._name) 51 | if train: self.save_model(dialog_config.MODEL_PATH+self._name) 52 | self._init_experience_pool(batch) 53 | self.episode_count = 0 54 | self.recent_rewards = deque([], 1000) 55 | self.recent_successes = deque([], 1000) 56 | self.recent_turns = deque([], 1000) 57 | self.recent_loss = deque([], 10) 58 | self.discount = 0.99 59 | self.num_updates = 0 60 | self.tr = tr 61 | self.ts = ts 62 | self.frac = frac 63 | self.max_req = max_req 64 | 65 | def _dict2vec(self, p_dict): 66 | p_vec = [] 67 | for s in dialog_config.inform_slots: 68 | s_np = p_dict[s]/p_dict[s].sum() 69 | if s in self.state['dont_care']: 70 | np.append(s_np,1.) 71 | else: 72 | np.append(s_np,0.) 73 | p_vec.append(s_np) 74 | return np.concatenate(p_vec).astype('float32') 75 | 76 | def _print_progress(self,loss): 77 | self.recent_loss.append(loss) 78 | avg_ret = float(sum(self.recent_rewards))/len(self.recent_rewards) 79 | avg_turn = float(sum(self.recent_turns))/len(self.recent_turns) 80 | avg_loss = float(sum(self.recent_loss))/len(self.recent_loss) 81 | n_suc, n_fail, n_inc, tot = 0, 0, 0, 0 82 | for s in self.recent_successes: 83 | if s==-1: n_fail += 1 84 | elif s==0: n_inc += 1 85 | else: n_suc += 1 86 | tot += 1 87 | print 'Update %d. Avg turns = %.2f . Avg Reward = %.2f . Success Rate = %.2f . Fail Rate = %.2f . Incomplete Rate = %.2f . Loss = %.3f' % \ 88 | (self.num_updates, avg_turn, avg_ret, \ 89 | float(n_suc)/tot, float(n_fail)/tot, float(n_inc)/tot, avg_loss) 90 | 91 | def initialize_episode(self): 92 | self.episode_count += 1 93 | if self.training and self.episode_count%self.batch_size==0: 94 | self.num_updates += 1 95 | if self.num_updates>self.pol_start and self.num_updates%ANNEAL==0: self.anneal_lr() 96 | if self.num_updates < self.pol_start: loss = self.update(regime='SL') 97 | else: loss = self.update(regime='RL') 98 | if self.num_updates%DISPF==0: self._print_progress(loss) 99 | if self.num_updates%SAVEF==0: self.save_model(dialog_config.MODEL_PATH+self._name) 100 | 101 | self.state = {} 102 | self.state['database'] = pkl.loads(pkl.dumps(self.database,-1)) 103 | self.state['prevact'] = 'begin@begin' 104 | self.state['inform_slots'] = self._init_beliefs() 105 | self.state['turn'] = 0 106 | self.state['num_requests'] = {s:0 for s in self.state['inform_slots'].keys()} 107 | self.state['slot_tracker'] = set() 108 | self.state['dont_care'] = set() 109 | p_db_i = (1./self.state['database'].N)*np.ones((self.state['database'].N,)) 110 | self.state['init_entropy'] = calc_entropies(self.state['inform_slots'], p_db_i, 111 | self.state['database']) 112 | self.state['inputs'] = [] 113 | self.state['actions'] = [] 114 | self.state['rewards'] = [] 115 | self.state['pol_state'] = np.zeros((1,self.n_hid)).astype('float32') 116 | 117 | ''' get next action based on rules ''' 118 | def next(self, user_action, verbose=False): 119 | self._update_state(user_action['nl_sentence'], upd=self.upd, verbose=verbose) 120 | self.state['turn'] += 1 121 | 122 | db_status, db_index = self._check_db() 123 | N_db = len(db_index) 124 | H_slots = {} 125 | for s in dialog_config.inform_slots: 126 | s_p = self.state['inform_slots'][s]/self.state['inform_slots'][s].sum() 127 | H_slots[s] = tools.entropy_p(s_p) 128 | p_vector = np.zeros((self.in_size,)).astype('float32') 129 | if self.inputtype=='entropy': 130 | for i,s in enumerate(dialog_config.inform_slots): 131 | if s in H_slots: p_vector[i] = H_slots[s] 132 | p_vector[i+len(dialog_config.inform_slots)] = 1. 
if s in self.state['dont_care'] \ 133 | else 0. 134 | if self.state['turn']>1: 135 | pr_act = self.state['prevact'].split('@') 136 | act_id = dialog_config.inform_slots.index(pr_act[1]) 137 | p_vector[2*len(dialog_config.inform_slots)+act_id] = 1. 138 | #p_vector[-1] = N_db/self.state['database'].N 139 | if N_db<=5: p_vector[N_db-6] = 1. 140 | else: p_vector[-1] = 1. 141 | else: 142 | p_slots = self._dict2vec(self.state['inform_slots']) 143 | p_vector[:p_slots.shape[0]] = p_slots 144 | if self.state['turn']>1: 145 | pr_act = self.state['prevact'].split('@') 146 | act_id = dialog_config.inform_slots.index(pr_act[1]) 147 | p_vector[p_slots.shape[0]+act_id] = 1. 148 | db_i_vector = np.zeros((self.database.N,)).astype('float32') 149 | db_i_vector[db_index] = 1. 150 | p_vector[-self.database.N:] = db_i_vector 151 | p_vector = np.expand_dims(np.expand_dims(p_vector, axis=0), axis=0) 152 | p_vector = standardize(p_vector) 153 | 154 | if self.training and self.num_updates0: 173 | act['posterior'][db_index] = 1./len(db_index) 174 | else: 175 | act['posterior'] = 1./len(self.database.labels) 176 | 177 | return act 178 | 179 | def _prob_act(self, p, db_index, mode='sample'): 180 | act = {} 181 | act['diaact'] = 'UNK' 182 | act['request_slots'] = {} 183 | act['target'] = [] 184 | 185 | action, probs, p_out = self.act(p, self.state['pol_state'], mode=mode) 186 | if action==self.out_size-1: 187 | act['diaact'] = 'inform' 188 | act['target'] = self._inform(db_index) 189 | self.state['prevact'] = 'inform@inform' 190 | else: 191 | act['diaact'] = 'request' 192 | s = dialog_config.inform_slots[action] 193 | act['request_slots'][s] = 'UNK' 194 | self.state['prevact'] = 'request@%s' %s 195 | self.state['num_requests'][s] += 1 196 | return act, action, p_out 197 | 198 | def _rule_act(self, p, db_index): 199 | act = {} 200 | act['diaact'] = 'UNK' 201 | act['request_slots'] = {} 202 | act['target'] = [] 203 | 204 | if p[-1] == 0: 205 | # no match, some error, re-ask some slot 206 | act['diaact'] = 'request' 207 | request_slot = random.choice(self.state['inform_slots'].keys()) 208 | act['request_slots'][request_slot] = 'UNK' 209 | self.state['prevact'] = 'request@%s' %request_slot 210 | self.state['num_requests'][request_slot] += 1 211 | action = dialog_config.inform_slots.index(request_slot) 212 | elif p[-1] == 1: 213 | # agent reasonable confident, inform 214 | act['diaact'] = 'inform' 215 | act['target'] = self._inform(db_index) 216 | action = len(dialog_config.inform_slots) 217 | self.state['prevact'] = 'inform@inform' 218 | else: 219 | H_slots = {s:p[i] for i,s in enumerate(dialog_config.inform_slots)} 220 | sorted_entropies = sorted(H_slots.items(), key=operator.itemgetter(1), reverse=True) 221 | req = False 222 | for (s,h) in sorted_entropies: 223 | if H_slots[s]= self.max_req: 225 | continue 226 | act['diaact'] = 'request' 227 | act['request_slots'][s] = 'UNK' 228 | self.state['prevact'] = 'request@%s' %s 229 | self.state['num_requests'][s] += 1 230 | action = dialog_config.inform_slots.index(s) 231 | req = True 232 | break 233 | if not req: 234 | # agent confident about all slots, inform 235 | act['diaact'] = 'inform' 236 | act['target'] = self._inform(db_index) 237 | self.state['prevact'] = 'inform@inform' 238 | action = len(dialog_config.inform_slots) 239 | return act, action 240 | 241 | def terminate_episode(self, user_action): 242 | assert self.state['turn'] <= self.max_turn, "More turn than MAX_TURN!!" 
243 | total_reward = aggregate_rewards(self.state['rewards']+[user_action['reward']],self.discount) 244 | inp = np.zeros((self.max_turn,self.in_size)).astype('float32') 245 | actmask = np.zeros((self.max_turn,self.out_size)).astype('int32') 246 | turnmask = np.zeros((self.max_turn,)).astype('int32') 247 | for t in xrange(0,self.state['turn']): 248 | actmask[t,self.state['actions'][t]] = 1 249 | inp[t,:] = self.state['inputs'][t] 250 | turnmask[t] = 1 251 | self.add_to_pool(inp, turnmask, actmask, total_reward) 252 | self.recent_rewards.append(total_reward) 253 | self.recent_turns.append(self.state['turn']) 254 | if self.state['turn'] == self.max_turn: self.recent_successes.append(0) 255 | elif user_action['reward']>0: self.recent_successes.append(1) 256 | else: self.recent_successes.append(-1) 257 | 258 | -------------------------------------------------------------------------------- /deep_dialog/agents/agent_simpleRL_allact_noDB.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | import numpy as np 5 | import cPickle as pkl 6 | 7 | from deep_dialog import dialog_config, tools 8 | from collections import Counter, defaultdict, deque 9 | from agent_rl import RLAgent, aggregate_rewards 10 | from belief_tracker import BeliefTracker 11 | from softDB import SoftDB 12 | from utils import * 13 | 14 | import operator 15 | import random 16 | import math 17 | import copy 18 | import re 19 | import nltk 20 | 21 | # params 22 | DISPF = 1 23 | SAVEF = 100 24 | ANNEAL = 800 25 | 26 | class AgentSimpleRLAllActNoDB(RLAgent,SoftDB,BeliefTracker): 27 | def __init__(self, movie_dict=None, act_set=None, slot_set=None, db=None, 28 | train=True, _reload=False, n_hid=100, batch=128, ment=0., 29 | inputtype='full', pol_start=0, upd=10, tr=2.0, ts=0.5, 30 | max_req=2, frac=0.5, lr=0.005, name=None): 31 | self.movie_dict = movie_dict 32 | self.act_set = act_set 33 | self.slot_set = slot_set 34 | self.database = db 35 | self.max_turn = dialog_config.MAX_TURN 36 | self.training = train 37 | self.inputtype = inputtype 38 | self.pol_start = pol_start 39 | self.upd = upd 40 | if inputtype=='entropy': 41 | in_size = 3*len(dialog_config.inform_slots) 42 | else: 43 | in_size = sum([len(self.movie_dict.dict[s])+2 for s in dialog_config.inform_slots]) 44 | out_size = len(dialog_config.inform_slots)+1 45 | self._init_model(in_size, out_size, n_hid=n_hid, learning_rate_sl=lr, batch_size=batch, \ 46 | ment=ment) 47 | self._name = name 48 | if _reload: self.load_model(dialog_config.MODEL_PATH+self._name) 49 | if train: self.save_model(dialog_config.MODEL_PATH+self._name) 50 | self._init_experience_pool(batch) 51 | self.episode_count = 0 52 | self.recent_rewards = deque([], 1000) 53 | self.recent_successes = deque([], 1000) 54 | self.recent_turns = deque([], 1000) 55 | self.recent_loss = deque([], 10) 56 | self.discount = 0.99 57 | self.num_updates = 0 58 | self.tr = tr 59 | self.ts = ts 60 | self.frac = frac 61 | self.max_req = max_req 62 | 63 | def _dict2vec(self, p_dict): 64 | p_vec = [] 65 | for s in dialog_config.inform_slots: 66 | s_np = p_dict[s]/p_dict[s].sum() 67 | if s in self.state['dont_care']: 68 | np.append(s_np,1.) 69 | else: 70 | np.append(s_np,0.) 
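# note: np.append returns a new array rather than modifying s_np in place; since the result is not assigned, s_np is appended below without the dont_care flag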
71 | p_vec.append(s_np) 72 | return np.concatenate(p_vec).astype('float32') 73 | 74 | def _print_progress(self,loss): 75 | self.recent_loss.append(loss) 76 | avg_ret = float(sum(self.recent_rewards))/len(self.recent_rewards) 77 | avg_turn = float(sum(self.recent_turns))/len(self.recent_turns) 78 | avg_loss = float(sum(self.recent_loss))/len(self.recent_loss) 79 | n_suc, n_fail, n_inc, tot = 0, 0, 0, 0 80 | for s in self.recent_successes: 81 | if s==-1: n_fail += 1 82 | elif s==0: n_inc += 1 83 | else: n_suc += 1 84 | tot += 1 85 | print 'Update %d. Avg turns = %.2f . Avg Reward = %.2f . Success Rate = %.2f . Fail Rate = %.2f . Incomplete Rate = %.2f . Loss = %.3f' % \ 86 | (self.num_updates, avg_turn, avg_ret, \ 87 | float(n_suc)/tot, float(n_fail)/tot, float(n_inc)/tot, avg_loss) 88 | 89 | def initialize_episode(self): 90 | self.episode_count += 1 91 | if self.training and self.episode_count%self.batch_size==0: 92 | self.num_updates += 1 93 | if self.num_updates>self.pol_start and self.num_updates%ANNEAL==0: self.anneal_lr() 94 | if self.num_updates < self.pol_start: loss = self.update(regime='SL') 95 | else: loss = self.update(regime='RL') 96 | if self.num_updates%DISPF==0: self._print_progress(loss) 97 | if self.num_updates%SAVEF==0: self.save_model(dialog_config.MODEL_PATH+self._name) 98 | 99 | self.state = {} 100 | self.state['database'] = pkl.loads(pkl.dumps(self.database,-1)) 101 | self.state['prevact'] = 'begin@begin' 102 | self.state['inform_slots'] = self._init_beliefs() 103 | self.state['turn'] = 0 104 | self.state['num_requests'] = {s:0 for s in self.state['database'].slots} 105 | self.state['slot_tracker'] = set() 106 | self.state['dont_care'] = set() 107 | self.state['init_entropy'] = {} 108 | for s in dialog_config.inform_slots: 109 | s_p = self.state['inform_slots'][s]/self.state['inform_slots'][s].sum() 110 | self.state['init_entropy'][s] = tools.entropy_p(s_p) 111 | self.state['inputs'] = [] 112 | self.state['actions'] = [] 113 | self.state['rewards'] = [] 114 | self.state['pol_state'] = np.zeros((1,self.n_hid)).astype('float32') 115 | 116 | ''' get next action based on rules ''' 117 | def next(self, user_action, verbose=False): 118 | self._update_state(user_action['nl_sentence'], upd=self.upd, verbose=verbose) 119 | self.state['turn'] += 1 120 | 121 | db_probs = self._check_db() 122 | H_slots = {} 123 | for s in dialog_config.inform_slots: 124 | s_p = self.state['inform_slots'][s]/self.state['inform_slots'][s].sum() 125 | H_slots[s] = tools.entropy_p(s_p) 126 | p_vector = np.zeros((self.in_size,)).astype('float32') 127 | if self.inputtype=='entropy': 128 | for i,s in enumerate(dialog_config.inform_slots): 129 | if s in H_slots: p_vector[i] = H_slots[s] 130 | p_vector[i+len(dialog_config.inform_slots)] = 1. if s in self.state['dont_care'] \ 131 | else 0. 132 | if self.state['turn']>1: 133 | pr_act = self.state['prevact'].split('@') 134 | act_id = dialog_config.inform_slots.index(pr_act[1]) 135 | p_vector[2*len(dialog_config.inform_slots)+act_id] = 1. 136 | else: 137 | p_slots = self._dict2vec(self.state['inform_slots']) 138 | p_vector[:p_slots.shape[0]] = p_slots 139 | if self.state['turn']>1: 140 | pr_act = self.state['prevact'].split('@') 141 | act_id = dialog_config.inform_slots.index(pr_act[1]) 142 | p_vector[p_slots.shape[0]+act_id] = 1. 
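# add singleton batch and time axes so p_vector has shape (1, 1, in_size) before standardization and the policy network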
143 | p_vector = np.expand_dims(np.expand_dims(p_vector, axis=0), axis=0) 144 | p_vector = standardize(p_vector) 145 | 146 | if self.training and self.num_updates= self.max_req: 197 | continue 198 | act['diaact'] = 'request' 199 | act['request_slots'][s] = 'UNK' 200 | self.state['prevact'] = 'request@%s' %s 201 | self.state['num_requests'][s] += 1 202 | action = dialog_config.inform_slots.index(s) 203 | req = True 204 | break 205 | if not req: 206 | # agent confident about all slots, inform 207 | act['diaact'] = 'inform' 208 | act['target'] = self._inform(db_probs) 209 | self.state['prevact'] = 'inform@inform' 210 | action = len(dialog_config.inform_slots) 211 | return act, action 212 | 213 | def terminate_episode(self, user_action): 214 | assert self.state['turn'] <= self.max_turn, "More turn than MAX_TURN!!" 215 | total_reward = aggregate_rewards(self.state['rewards']+[user_action['reward']],self.discount) 216 | inp = np.zeros((self.max_turn,self.in_size)).astype('float32') 217 | actmask = np.zeros((self.max_turn,self.out_size)).astype('int32') 218 | turnmask = np.zeros((self.max_turn,)).astype('int32') 219 | for t in xrange(0,self.state['turn']): 220 | actmask[t,self.state['actions'][t]] = 1 221 | inp[t,:] = self.state['inputs'][t] 222 | turnmask[t] = 1 223 | self.add_to_pool(inp, turnmask, actmask, total_reward) 224 | self.recent_rewards.append(total_reward) 225 | self.recent_turns.append(self.state['turn']) 226 | if self.state['turn'] == self.max_turn: self.recent_successes.append(0) 227 | elif user_action['reward']>0: self.recent_successes.append(1) 228 | else: self.recent_successes.append(-1) 229 | -------------------------------------------------------------------------------- /deep_dialog/agents/belief_tracker.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | import nltk 5 | import numpy as np 6 | import time 7 | 8 | from collections import Counter, defaultdict 9 | from deep_dialog.tools import to_tokens 10 | 11 | UPD = 10 12 | 13 | class BeliefTracker: 14 | def _search(self,w_t,s_t): 15 | #w_t = to_tokens(w) 16 | return float(sum([ww in s_t for ww in w_t]))/len(w_t) 17 | 18 | def _search_slots(self, s_t): 19 | matches = {} 20 | for slot,slot_t in self.state['database'].slot_tokens.iteritems(): 21 | m = self._search(slot_t,s_t) 22 | if m>0.: 23 | matches[slot] = m 24 | return matches 25 | 26 | def _search_values(self, s_t): 27 | matches = {} 28 | for slot in self.state['database'].slots: 29 | matches[slot] = defaultdict(float) 30 | for ss in s_t: 31 | if ss in self.movie_dict.tokens[slot]: 32 | for vi in self.movie_dict.tokens[slot][ss]: 33 | matches[slot][vi] += 1. 
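# normalize each candidate value's token-hit count by the number of tokens in the full value string, giving a match fraction per value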
34 | for vi,f in matches[slot].iteritems(): 35 | val = self.movie_dict.dict[slot][vi] 36 | matches[slot][vi] = f/len(nltk.word_tokenize(val)) 37 | return matches 38 | 39 | ''' update agent state ''' 40 | def _update_state(self, user_utterance, upd=UPD, verbose=False): 41 | prev_act, prev_slot = self.state['prevact'].split('@') 42 | 43 | s_t = to_tokens(user_utterance) 44 | slot_match = self._search_slots(s_t) # search slots 45 | val_match = self._search_values(s_t) # search values 46 | 47 | for slot, values in val_match.iteritems(): 48 | requested = (prev_act=='request') and (prev_slot==slot) 49 | matched = (slot in slot_match) 50 | if not values: 51 | if requested: # asked for value but did not get it 52 | self.state['database'].delete_slot(slot) 53 | self.state['num_requests'][slot] = 1000 54 | self.state['dont_care'].add(slot) 55 | else: 56 | for y, match in values.iteritems(): 57 | #y = self.movie_dict.dict[slot].index(val) 58 | if verbose: 59 | print 'Detected %s' %self.movie_dict.dict[slot][y], ' update = ', match 60 | if matched and requested: 61 | alpha = upd*(match + 1. + slot_match[slot]) 62 | elif matched and not requested: 63 | alpha = upd*(match + slot_match[slot]) 64 | elif not matched and requested: 65 | alpha = upd*(match + 1.) 66 | else: 67 | alpha = upd*match 68 | self.state['inform_slots'][slot][y] += alpha 69 | self.state['slot_tracker'].add(slot) 70 | 71 | def _init_beliefs(self): 72 | beliefs = {s:np.copy(self.state['database'].priors[s]) 73 | for s in self.state['database'].slots} 74 | return beliefs 75 | -------------------------------------------------------------------------------- /deep_dialog/agents/feature_extractor.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | import io 5 | import nltk 6 | import numpy as np 7 | import cPickle as pkl 8 | import os.path 9 | import string 10 | 11 | from deep_dialog.tools import to_tokens 12 | 13 | class FeatureExtractor: 14 | def __init__(self, corpus_path, db_path, N=1): 15 | self.N = N 16 | save_path = db_path.rsplit('/',1)[0] + '/fdict_%d.p'%N 17 | if os.path.isfile(save_path): 18 | f = open(save_path, 'rb') 19 | self.grams = pkl.load(f) 20 | self.n = pkl.load(f) 21 | f.close() 22 | else: 23 | self.grams = {} 24 | self.n = 0 25 | if corpus_path is not None: self._build_vocab_from_corpus(corpus_path) 26 | if db_path is not None: self._build_vocab_from_db(db_path) 27 | f = open(save_path, 'wb') 28 | pkl.dump(self.grams, f) 29 | pkl.dump(self.n, f) 30 | f.close() 31 | print 'Vocab Size = %d' %self.n 32 | 33 | def _build_vocab_from_db(self, corpus): 34 | try: 35 | f = io.open(corpus, 'r') 36 | for line in f: 37 | elements = line.rstrip().split('\t')[1:] 38 | for ele in elements: 39 | tokens = to_tokens(ele) 40 | for i in range(len(tokens)): 41 | for t in range(self.N): 42 | if i-t<0: continue 43 | ngram = '_'.join(tokens[i-t:i+1]) 44 | if ngram not in self.grams: 45 | self.grams[ngram] = self.n 46 | self.n += 1 47 | f.close() 48 | except UnicodeDecodeError: 49 | f = open(corpus, 'r') 50 | for line in f: 51 | elements = line.rstrip().split('\t')[1:] 52 | for ele in elements: 53 | tokens = to_tokens(ele) 54 | for i in range(len(tokens)): 55 | for t in range(self.N): 56 | if i-t<0: continue 57 | ngram = '_'.join(tokens[i-t:i+1]) 58 | if ngram not in self.grams: 59 | self.grams[ngram] = self.n 60 | self.n += 1 61 | f.close() 62 | 63 | def _build_vocab_from_corpus(self, corpus): 64 | if not os.path.isfile(corpus): return 65 | try: 66 | f = io.open(corpus, 'r') 67 | for 
line in f: 68 | tokens = to_tokens(line.rstrip()) 69 | for i in range(len(tokens)): 70 | for t in range(self.N): 71 | if i-t<0: continue 72 | ngram = '_'.join(tokens[i-t:i+1]) 73 | if ngram not in self.grams: 74 | self.grams[ngram] = self.n 75 | self.n += 1 76 | f.close() 77 | except UnicodeDecodeError: 78 | f = open(corpus, 'r') 79 | for line in f: 80 | tokens = to_tokens(line.rstrip()) 81 | for i in range(len(tokens)): 82 | for t in range(self.N): 83 | if i-t<0: continue 84 | ngram = '_'.join(tokens[i-t:i+1]) 85 | if ngram not in self.grams: 86 | self.grams[ngram] = self.n 87 | self.n += 1 88 | f.close() 89 | 90 | def featurize(self, text): 91 | vec = np.zeros((len(self.grams),)).astype('float32') 92 | tokens = to_tokens(text) 93 | for i in range(len(tokens)): 94 | for t in range(self.N): 95 | if i-t<0: continue 96 | ngram = '_'.join(tokens[i-t:i+1]) 97 | if ngram in self.grams: 98 | vec[self.grams[ngram]] += 1. 99 | return vec 100 | 101 | if __name__=='__main__': 102 | F = FeatureExtractor('../data/corpora/selected_medium_corpus.txt','../data/selected_medium/db.txt') 103 | print '\n'.join(F.grams.keys()) 104 | print F.featurize('Please search for the movie with Matthew Saville as director') 105 | print F.featurize('I would like to see the movie with drama as genre') 106 | -------------------------------------------------------------------------------- /deep_dialog/agents/hardDB.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | import random 5 | import operator 6 | import numpy as np 7 | 8 | class HardDB: 9 | ''' get dist over DB based on current beliefs ''' 10 | def _check_db_soft(self): 11 | # induce disttribution over DB based on current beliefs over slots 12 | probs = {} 13 | p_s = np.zeros((self.state['database'].N, \ 14 | len(self.state['database'].slots))).astype('float32') 15 | for i,s in enumerate(self.state['database'].slots): 16 | p = self.state['inform_slots'][s]/self.state['inform_slots'][s].sum() 17 | n = self.state['database'].inv_counts[s] 18 | p_unk = float(n[-1])/self.state['database'].N 19 | p_tilde = p*(1.-p_unk) 20 | p_tilde = np.concatenate([p_tilde,np.asarray([p_unk])]) 21 | p_s[:,i] = p_tilde[self.state['database'].table[:,i]]/ \ 22 | n[self.state['database'].table[:,i]] 23 | p_db = np.sum(np.log(p_s), axis=1) 24 | p_db = np.exp(p_db - np.min(p_db)) 25 | p_db = p_db/p_db.sum() 26 | return p_db 27 | 28 | def _inform(self, db_index): 29 | probs = self._check_db_soft() 30 | return np.argsort(probs)[::-1].tolist() 31 | 32 | ''' query DB based on current known slots ''' 33 | def _check_db(self): 34 | # from query to db form current inform_slots 35 | db_query = [] 36 | for s in self.state['database'].slots: 37 | if s in self.state['slot_tracker'] and s in self.state['inform_slots']: 38 | max_i = np.argmax(self.state['inform_slots'][s]) 39 | max_key = self.movie_dict.dict[s][max_i] 40 | db_query.append(max_key) 41 | else: 42 | db_query.append(None) 43 | matches, index = self.state['database'].lookup(db_query) 44 | return matches, index 45 | 46 | -------------------------------------------------------------------------------- /deep_dialog/agents/softDB.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | import numpy as np 5 | 6 | class SoftDB: 7 | def _inform(self, probs): 8 | return np.argsort(probs)[::-1].tolist() 9 | 10 | ''' get dist over DB based on current beliefs ''' 11 | def _check_db(self): 12 | # induce disttribution over DB based on current 
beliefs over slots 13 | probs = {} 14 | p_s = np.zeros((self.state['database'].N, \ 15 | len(self.state['database'].slots))).astype('float32') 16 | for i,s in enumerate(self.state['database'].slots): 17 | p = self.state['inform_slots'][s]/self.state['inform_slots'][s].sum() 18 | n = self.state['database'].inv_counts[s] 19 | p_unk = float(n[-1])/self.state['database'].N 20 | p_tilde = p*(1.-p_unk) 21 | p_tilde = np.concatenate([p_tilde,np.asarray([p_unk])]) 22 | p_s[:,i] = p_tilde[self.state['database'].table[:,i]]/ \ 23 | n[self.state['database'].table[:,i]] 24 | p_db = np.sum(np.log(p_s), axis=1) 25 | p_db = np.exp(p_db - np.min(p_db)) 26 | p_db = p_db/p_db.sum() 27 | return p_db 28 | -------------------------------------------------------------------------------- /deep_dialog/agents/utils.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from deep_dialog import tools 3 | import numpy as np 4 | import time 5 | 6 | def standardize(arr): 7 | return arr 8 | 9 | def calc_entropies(state, q, db): 10 | entropies = {} 11 | for s,c in state.iteritems(): 12 | if s not in db.slots: 13 | entropies[s] = 0. 14 | else: 15 | p = (db.ids[s]*q).sum(axis=1) 16 | u = db.priors[s]*q[db.unks[s]].sum() 17 | c_tilde = p+u 18 | c_tilde = c_tilde/c_tilde.sum() 19 | entropies[s] = tools.entropy_p(c_tilde) 20 | return entropies 21 | -------------------------------------------------------------------------------- /deep_dialog/dialog_config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | all_acts = ['request', 'inform'] 5 | inform_slots = ['actor','critic_rating','genre','mpaa_rating','director','release_year'] 6 | 7 | sys_request_slots = ['actor', 'critic_rating', 'genre', 'mpaa_rating', 'director', 'release_year'] 8 | 9 | start_dia_acts = { 10 | #'greeting':[], 11 | 'request':['moviename', 'starttime', 'theater', 'city', 'state', 'date', 'genre', 'ticket', 'numberofpeople', 'numberofkids'] 12 | } 13 | 14 | #reward information 15 | FAILED_DIALOG_REWARD = -1 16 | SUCCESS_DIALOG_REWARD = 2 17 | PER_TURN_REWARD = -0.1 18 | SUCCESS_MAX_RANK = 5 19 | MAX_TURN = 10 20 | 21 | MODEL_PATH = './data/pretrained/' 22 | -------------------------------------------------------------------------------- /deep_dialog/dialog_system/__init__.py: -------------------------------------------------------------------------------- 1 | from .movie_dict import * 2 | from .database import * 3 | from .dialog_manager import * 4 | from .dict_reader import * 5 | -------------------------------------------------------------------------------- /deep_dialog/dialog_system/database.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | Class for database 4 | ''' 5 | 6 | import csv 7 | import io 8 | import numpy as np 9 | import nltk 10 | import time 11 | 12 | from collections import defaultdict 13 | from deep_dialog import dialog_config 14 | from deep_dialog.tools import to_tokens 15 | 16 | class Database: 17 | def __init__(self, path, dicts, name=''): 18 | self.path = path 19 | self.name = name 20 | self._load_db(path) 21 | self._shuffle() 22 | self._build_inv_index(dicts) 23 | self._build_table(dicts) 24 | self._get_priors() 25 | self._prepare_for_entropy(dicts) 26 | self._prepare_for_search() 27 | 28 | def _load_db(self, path): 29 | try: 30 | fi = io.open(path,'r') 31 | self.slots = fi.readline().rstrip().split('\t')[1:] 32 | tupl = [line.rstrip().split('\t') for line in 
fi] 33 | self.labels = [t[0] for t in tupl] 34 | self.tuples = [t[1:] for t in tupl] 35 | fi.close() 36 | except UnicodeDecodeError: 37 | fi = open(path,'r') 38 | self.slots = fi.readline().rstrip().split('\t')[1:] 39 | tupl = [line.rstrip().split('\t') for line in fi] 40 | self.labels = [t[0] for t in tupl] 41 | self.tuples = [t[1:] for t in tupl] 42 | fi.close() 43 | self.N = len(self.tuples) 44 | 45 | def _shuffle(self): 46 | # match slot order to config 47 | index = [self.slots.index(s) for s in dialog_config.inform_slots] 48 | self.slots = [self.slots[ii] for ii in index] 49 | self.tuples = [[row[ii] for ii in index] for row in self.tuples] 50 | 51 | def lookup(self, query, match_unk=True): 52 | def _iseq(t1, t2): 53 | for i in range(len(t1)): 54 | if t1[i]!=t2[i] and t1[i]!='UNK' and t2[i]!='UNK': 55 | return False 56 | return True 57 | col_idx = [ii for ii,vv in enumerate(query) if vv is not None] 58 | c_db = [[row[ii] for ii in col_idx] for row in self.tuples] 59 | c_q = [query[ii] for ii in col_idx] 60 | if match_unk: row_match_idx = [ii for ii,ll in enumerate(c_db) if _iseq(ll,c_q)] 61 | else: row_match_idx = [ii for ii,ll in enumerate(c_db) if ll==c_q] 62 | results = [self.tuples[ii] for ii in row_match_idx] 63 | return results, row_match_idx 64 | 65 | def delete_slot(self, slot): 66 | try: 67 | slot_index = self.slots.index(slot) 68 | except ValueError: 69 | print 'Slot not found!!!' 70 | return 71 | for row in self.tuples: del row[slot_index] 72 | self.table = np.delete(self.table, slot_index, axis=1) 73 | self.counts = np.delete(self.counts, slot_index, axis=1) 74 | del self.slots[slot_index] 75 | 76 | def _build_inv_index(self, dicts): 77 | self.inv_index = {} 78 | self.inv_counts = {} 79 | for i,slot in enumerate(self.slots): 80 | V = dicts.lengths[slot] 81 | self.inv_index[slot] = defaultdict(list) 82 | self.inv_counts[slot] = np.zeros((V+1,)).astype('float32') 83 | values = [t[i] for t in self.tuples] 84 | for j,v in enumerate(values): 85 | v_id = dicts.dict[slot].index(v) if v!='UNK' else V 86 | self.inv_index[slot][v].append(j) 87 | self.inv_counts[slot][v_id] += 1 88 | 89 | def _build_table(self, dicts): 90 | self.table = np.zeros((len(self.tuples),len(self.slots))).astype('int16') 91 | self.counts = np.zeros((len(self.tuples),len(self.slots))).astype('float32') 92 | for i,t in enumerate(self.tuples): 93 | for j,v in enumerate(t): 94 | s = self.slots[j] 95 | self.table[i,j] = dicts.dict[s].index(v) if v!='UNK' else dicts.lengths[s] 96 | self.counts[i,j] = self.inv_counts[s][self.table[i,j]] 97 | 98 | def _get_priors(self): 99 | self.priors = {slot:self.inv_counts[slot][:-1]/self.inv_counts[slot][:-1].sum() \ 100 | for slot in self.slots} 101 | 102 | def _prepare_for_entropy(self, dicts): 103 | self.ids = {} 104 | self.ns = {} 105 | self.non0 = {} 106 | self.unks = {} 107 | for i,s in enumerate(self.slots): 108 | V = dicts.lengths[s] 109 | db_c = self.table[:,i] 110 | self.unks[s] = np.where(db_c==V)[0] 111 | self.ids[s] = (np.mgrid[:self.priors[s].shape[0],:self.N]==db_c)[0] 112 | self.ns[s] = self.ids[s].sum(axis=1) 113 | self.non0[s] = np.nonzero(self.ns[s])[0] 114 | 115 | def _prepare_for_search(self): 116 | self.slot_tokens = {} 117 | for slot in self.slots: 118 | self.slot_tokens[slot] = to_tokens(slot) 119 | -------------------------------------------------------------------------------- /deep_dialog/dialog_system/dialog_manager.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | import random, time 4 | from . 
import MovieDict, Database 5 | from deep_dialog import dialog_config 6 | 7 | 8 | class DialogManager: 9 | def __init__(self, agent, user, db_full, db_inc, movie_kb, verbose=True): 10 | self.agent = agent 11 | self.user = user 12 | self.user_action = None 13 | self.database = db_full 14 | self.database_incomplete = db_inc 15 | self.verbose = verbose 16 | self.movie_dict = movie_kb 17 | 18 | def initialize_episode(self): 19 | while True: 20 | self.user_action = self.user.initialize_episode() 21 | if self._check_user_goal()<=dialog_config.SUCCESS_MAX_RANK: break 22 | self.agent.initialize_episode() 23 | if self.verbose: self.user.print_goal() 24 | return self.user_action 25 | 26 | def next_turn(self): 27 | if self.verbose: 28 | print 'Turn', self.user_action['turn'], 'user action:', self.user_action['diaact'], \ 29 | '\t', 'inform slots:', self.user_action['inform_slots'] 30 | print 'Utterance:', self.user_action['nl_sentence'], '\n' 31 | 32 | self.sys_actions = self.agent.next(self.user_action, verbose=self.verbose) 33 | 34 | self.sys_actions['turn'] = self.user_action['turn'] + 1 35 | if self.verbose: 36 | print("Turn %d sys action: %s, request slots: %s" % \ 37 | (self.sys_actions['turn'], self.sys_actions['diaact'], \ 38 | self.sys_actions['request_slots']) + '\n') 39 | 40 | self.user_action, episode_over, reward = self.user.next(self.sys_actions) 41 | if episode_over: self.agent.terminate_episode(self.user_action) 42 | if episode_over and self.verbose: 43 | print("Agent Results:") 44 | if 'phis' in self.sys_actions: print '\t'.join(['dont-care:']+['%.3f'%s for s in self.sys_actions['phis']]) 45 | if self.sys_actions['target']: 46 | for ii in self.sys_actions['target'][:dialog_config.SUCCESS_MAX_RANK]: 47 | out = [self.database_incomplete.labels[ii]] 48 | for it,slot in enumerate(self.database_incomplete.slots): 49 | if 'probs' in self.sys_actions: 50 | sidx = dialog_config.inform_slots.index(slot) 51 | val = self.database_incomplete.tuples[ii][it] 52 | idx = self.movie_dict.dict[slot].index(val) if val!='UNK' else \ 53 | len(self.movie_dict.dict[slot]) 54 | count = self.database_incomplete.inv_counts[slot][idx] 55 | out.append('%s(%.3f/%d)'%(val,self.sys_actions['probs'][sidx].flatten()[idx], \ 56 | count)) 57 | else: 58 | val = self.database_incomplete.tuples[ii][it] 59 | out.append('%s'%val) 60 | print('\t'.join([o.encode('latin-1', 'replace') for o in out])) 61 | 62 | return (episode_over, reward, self.user_action, self.sys_actions) 63 | 64 | def check_db(self): 65 | db_query = [] 66 | for s in self.database.slots: 67 | if s in self.sys_actions['inform_slots']: 68 | db_query.append(self.sys_actions['inform_slots'][s]) 69 | elif s in self.user.goal['inform_slots']: 70 | db_query.append(self.user.goal['inform_slots'][s]) 71 | else: 72 | db_query.append(None) 73 | matches = self.database.lookup(db_query) 74 | if len(matches) > 0: 75 | return True 76 | else: 77 | return False 78 | 79 | def _check_user_goal(self): 80 | db_query = [] 81 | for s in self.database.slots: 82 | if s in self.user.goal['inform_slots']: 83 | db_query.append(self.user.goal['inform_slots'][s]) 84 | else: 85 | db_query.append(None) 86 | matches,_ = self.database.lookup(db_query, match_unk=False) 87 | return len(matches) 88 | -------------------------------------------------------------------------------- /deep_dialog/dialog_system/dict_reader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | 5 | class DictReader: 6 | def __init__(self): 7 | pass 8 | 9 | def 
load_dict_from_file(self, path): 10 | slot_set = {} 11 | 12 | file = open(path, 'r') 13 | index = 0 14 | for line in file: 15 | slot_set[line.strip('\n').strip('\r')] = index 16 | index += 1 17 | 18 | self.dict = slot_set 19 | 20 | 21 | def load_dict_from_array(self, array): 22 | slot_set = {} 23 | for index, ele in enumerate(array): 24 | slot_set[ele.strip('\n').strip('\r')] = index 25 | 26 | self.dict = slot_set 27 | -------------------------------------------------------------------------------- /deep_dialog/dialog_system/movie_dict.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | ''' 4 | 5 | import cPickle as pickle 6 | import copy 7 | import nltk 8 | import string 9 | 10 | from collections import defaultdict 11 | from deep_dialog.tools import to_tokens 12 | 13 | class MovieDict: 14 | def __init__(self, path): 15 | self.load_dict(path) 16 | self.count_values() 17 | self._build_token_index() 18 | 19 | def load_dict(self, path): 20 | dict_data = pickle.load(open(path, 'rb')) 21 | self.dict = copy.deepcopy(dict_data) 22 | 23 | def count_values(self): 24 | self.lengths = {} 25 | for k,v in self.dict.iteritems(): 26 | self.lengths[k] = len(v) 27 | 28 | def _build_token_index(self): 29 | self.tokens = {} 30 | for slot,vals in self.dict.iteritems(): 31 | self.tokens[slot] = defaultdict(list) 32 | for vi,vv in enumerate(vals): 33 | w_v = to_tokens(vv) 34 | for w in w_v: self.tokens[slot][w].append(vi) 35 | -------------------------------------------------------------------------------- /deep_dialog/objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .slot_reader import * 2 | -------------------------------------------------------------------------------- /deep_dialog/objects/slot_reader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created - June 3, 2016 3 | Author - t-bhdhi 4 | ''' 5 | 6 | class SlotReader: 7 | def __init__(self, path): 8 | self._load(path) 9 | self._invert() 10 | self.num_slots = len(self.slot_ids) 11 | 12 | def _load(self, path): 13 | self.slot_groups = {} 14 | self.slot_ids = {} 15 | n = 0 16 | f = open(path,'r') 17 | for line in f: 18 | sl = line.rstrip().split() 19 | i = 0 20 | for s in sl: 21 | self.slot_groups[s] = n 22 | self.slot_ids[s] = i # 0-head, 1-nonhead 23 | i = 1 24 | n += 1 25 | f.close() 26 | 27 | def _invert(self): 28 | # create inverted index of groups to slots 29 | self.group_slots = {} 30 | for s in self.slot_ids.keys(): 31 | if self.slot_groups[s] not in self.group_slots: 32 | self.group_slots[self.slot_groups[s]] = [s] 33 | else: 34 | self.group_slots[self.slot_groups[s]].append(s) 35 | -------------------------------------------------------------------------------- /deep_dialog/tools.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | from collections import Counter 5 | import math 6 | import numpy as np 7 | import sys 8 | import string 9 | import nltk 10 | from nltk.corpus import stopwords 11 | 12 | EXC = set(string.punctuation) 13 | 14 | def to_tokens(text): 15 | utt = ''.join(ch for ch in text if ch not in EXC) 16 | tokens = nltk.word_tokenize(utt.lower()) 17 | return [w for w in tokens if w not in stopwords.words('english')] 18 | 19 | def entropy(items): 20 | if type(items) is Counter or type(items) is dict: 21 | P = items 22 | elif type(items) is list: 23 | P = Counter(items) 24 | if not P: 25 | # empty distribution 26 | return -1 
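`entropy` accepts raw (unnormalized) counts: with N = sum(v), the loop that follows accumulates -sum(v*log2 v), divides by N and adds log2(N), which is exactly the Shannon entropy of the normalized distribution (the same quantity `entropy_p` computes from a probability vector). A quick sketch checking that identity on hypothetical counts:

```python
# Sketch: entropy computed from unnormalized counts equals the entropy of the
# normalized distribution (toy counts; mirrors tools.entropy / tools.entropy_p).
import math
import numpy as np

counts = [6.0, 3.0, 1.0]          # hypothetical value counts for one slot
N = sum(counts)

h_counts = math.log(N, 2) - sum(v * math.log(v, 2) for v in counts) / N

p = np.asarray(counts) / N
h_probs = float(np.sum(-p * np.log2(p)))

print(abs(h_counts - h_probs) < 1e-9)  # True
```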
27 | H = 0. 28 | N = 0. 29 | for v in P.values(): 30 | if v==0: 31 | continue 32 | H -= v*math.log(v,2) 33 | N += v 34 | if N==0: 35 | return -1 36 | H = (H/N) + math.log(N,2) 37 | if math.isnan(H): 38 | print '\n'.join(['%s:%.7f' %(k,v) for k,v in items.iteritems()]) 39 | sys.exit() 40 | return H 41 | 42 | def entropy_p(p): 43 | return np.sum(-p*np.nan_to_num(np.log2(p))) 44 | 45 | def categorical_sample(probs): 46 | x = np.random.uniform() 47 | s = probs[0] 48 | i = 0 49 | while s 0: 88 | for p in self.regularize: 89 | mat = self.model[p] 90 | reg_cost += 0.5*regc*np.sum(mat*mat) 91 | grads[p] += regc*mat 92 | 93 | # normalize the cost and gradient by the batch size 94 | batch_size = len(batch) 95 | reg_cost /= batch_size 96 | loss_cost /= batch_size 97 | for k in grads: grads[k] /= batch_size 98 | 99 | out = {} 100 | out['cost'] = {'reg_cost' : reg_cost, 'loss_cost' : loss_cost, 'total_cost' : loss_cost + reg_cost} 101 | out['grads'] = grads 102 | return out 103 | 104 | 105 | """ A single batch """ 106 | def singleBatch(self, ds, batch, params): 107 | learning_rate = params.get('learning_rate', 0.0) 108 | decay_rate = params.get('decay_rate', 0.999) 109 | momentum = params.get('momentum', 0) 110 | grad_clip = params.get('grad_clip', 1) 111 | smooth_eps = params.get('smooth_eps', 1e-8) 112 | sdg_type = params.get('sdgtype', 'rmsprop') 113 | 114 | for u in self.update: 115 | if not u in self.step_cache: 116 | self.step_cache[u] = np.zeros(self.model[u].shape) 117 | 118 | cg = self.costFunc(ds, batch, params) 119 | 120 | cost = cg['cost'] 121 | grads = cg['grads'] 122 | 123 | # clip gradients if needed 124 | if params['activation_func'] == 'relu': 125 | if grad_clip > 0: 126 | for p in self.update: 127 | if p in grads: 128 | grads[p] = np.minimum(grads[p], grad_clip) 129 | grads[p] = np.maximum(grads[p], -grad_clip) 130 | 131 | # perform parameter update 132 | for p in self.update: 133 | if p in grads: 134 | if sdg_type == 'vanilla': 135 | if momentum > 0: dx = momentum*self.step_cache[p] - learning_rate*grads[p] 136 | else: dx = -learning_rate*grads[p] 137 | self.step_cache[p] = dx 138 | elif sdg_type == 'rmsprop': 139 | self.step_cache[p] = self.step_cache[p]*decay_rate + (1.0-decay_rate)*grads[p]**2 140 | dx = -(learning_rate*grads[p])/np.sqrt(self.step_cache[p] + smooth_eps) 141 | elif sdg_type == 'adgrad': 142 | self.step_cache[p] += grads[p]**2 143 | dx = -(learning_rate*grads[p])/np.sqrt(self.step_cache[p] + smooth_eps) 144 | 145 | self.model[p] += dx 146 | 147 | # create output dict and return 148 | out = {} 149 | out['cost'] = cost 150 | return out 151 | 152 | 153 | """ Evaluate on the dataset[split] """ 154 | def eval(self, ds, split, params): 155 | acc = 0 156 | total = 0 157 | 158 | total_cost = 0.0 159 | smooth_cost = 1e-15 160 | perplexity = 0 161 | 162 | for i, ele in enumerate(ds.split[split]): 163 | #ele_reps = self.prepare_input_rep(ds, [ele], params) 164 | #Ys, cache = self.fwdPass(ele_reps[0], params, predict_model=True) 165 | #labels = np.array(ele_reps[0]['labels'], dtype=int) 166 | 167 | Ys, cache = self.fwdPass(ele, params, predict_model=True) 168 | 169 | maxes = np.amax(Ys, axis=1, keepdims=True) 170 | e = np.exp(Ys - maxes) # for numerical stability shift into good numerical range 171 | probs = e/np.sum(e, axis=1, keepdims=True) 172 | 173 | labels = np.array(ele['labels'], dtype=int) 174 | 175 | if np.all(np.isnan(probs)): probs = np.zeros(probs.shape) 176 | 177 | log_perplex = 0 178 | log_perplex += -np.sum(np.log2(smooth_cost + probs[range(len(labels)), 
labels])) 179 | log_perplex /= len(labels) 180 | 181 | loss_cost = 0 182 | loss_cost += -np.sum(np.log(smooth_cost + probs[range(len(labels)), labels])) 183 | 184 | perplexity += log_perplex #2**log_perplex 185 | total_cost += loss_cost 186 | 187 | pred_words_indices = np.nanargmax(probs, axis=1) 188 | for index, l in enumerate(labels): 189 | if pred_words_indices[index] == l: 190 | acc += 1 191 | 192 | total += len(labels) 193 | 194 | perplexity /= len(ds.split[split]) 195 | total_cost /= len(ds.split[split]) 196 | accuracy = 0 if total == 0 else float(acc)/total 197 | 198 | #print ("perplexity: %s, total_cost: %s, accuracy: %s" % (perplexity, total_cost, accuracy)) 199 | result = {'perplexity': perplexity, 'cost': total_cost, 'accuracy': accuracy} 200 | return result 201 | 202 | 203 | """ prediction on dataset[split] """ 204 | def predict(self, ds, split, params): 205 | inverse_word_dict = {ds.data['word_dict'][k]:k for k in ds.data['word_dict'].keys()} 206 | for i, ele in enumerate(ds.split[split]): 207 | pred_ys, pred_words = self.forward(inverse_word_dict, ele, params, predict_model=True) 208 | 209 | sentence = ' '.join(pred_words[:-1]) 210 | real_sentence = ' '.join(ele['sentence'].split(' ')[1:-1]) 211 | 212 | if params['dia_slot_val'] == 2 or params['dia_slot_val'] == 3: 213 | sentence = self.post_process(sentence, ele['slotval'], ds.data['slot_dict']) 214 | 215 | print 'test case', i 216 | print 'real:', real_sentence 217 | print 'pred:', sentence 218 | 219 | """ post_process to fill the slot """ 220 | def post_process(self, pred_template, slot_val_dict, slot_dict): 221 | sentence = pred_template 222 | suffix = "_PLACEHOLDER" 223 | 224 | for slot in slot_val_dict.keys(): 225 | slot_vals = slot_val_dict[slot] 226 | slot_placeholder = slot + suffix 227 | if slot == 'result' or slot == 'numberofpeople': continue 228 | for slot_val in slot_vals: 229 | tmp_sentence = sentence.replace(slot_placeholder, slot_val, 1) 230 | sentence = tmp_sentence 231 | 232 | if 'numberofpeople' in slot_val_dict.keys(): 233 | slot_vals = slot_val_dict['numberofpeople'] 234 | slot_placeholder = 'numberofpeople' + suffix 235 | for slot_val in slot_vals: 236 | tmp_sentence = sentence.replace(slot_placeholder, slot_val, 1) 237 | sentence = tmp_sentence 238 | 239 | for slot in slot_dict.keys(): 240 | slot_placeholder = slot + suffix 241 | tmp_sentence = sentence.replace(slot_placeholder, '') 242 | sentence = tmp_sentence 243 | 244 | return sentence 245 | -------------------------------------------------------------------------------- /deep_dialog/usersims/NLG/decoders/lstm_decoder_tanh.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | An LSTM decoder - add tanh after cell before output gate 4 | 5 | ''' 6 | 7 | from .decoder import decoder 8 | from .utils import * 9 | 10 | TEMP = 2.5 11 | 12 | class lstm_decoder_tanh(decoder): 13 | def __init__(self, diaact_input_size, input_size, hidden_size, output_size): 14 | self.model = {} 15 | # connections from diaact to hidden layer 16 | self.model['Wah'] = initWeights(diaact_input_size, 4*hidden_size) 17 | self.model['bah'] = np.zeros((1, 4*hidden_size)) 18 | 19 | # Recurrent weights: take x_t, h_{t-1}, and bias unit, and produce the 3 gates and the input to cell signal 20 | self.model['WLSTM'] = initWeights(input_size + hidden_size + 1, 4*hidden_size) 21 | # Hidden-Output Connections 22 | self.model['Wd'] = initWeights(hidden_size, output_size)*0.1 23 | self.model['bd'] = np.zeros((1, output_size)) 24 | 25 | 
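A note on the optimizer: `singleBatch` above supports vanilla SGD with momentum, RMSProp and an Adagrad-style rule, with RMSProp as the default. A minimal sketch of the per-parameter RMSProp step it applies, using a hypothetical parameter, gradient and learning rate (decay_rate and smooth_eps as in the defaults above):

```python
# Sketch of the RMSProp update applied per parameter in decoder.singleBatch
# (toy parameter and gradient; not the repository's full training loop).
import numpy as np

learning_rate, decay_rate, smooth_eps = 0.001, 0.999, 1e-8   # learning rate is hypothetical

param = np.zeros((3, 4))            # stand-in for self.model[p]
step_cache = np.zeros_like(param)   # stand-in for self.step_cache[p]
grad = np.random.randn(3, 4)        # stand-in for cg['grads'][p]

step_cache = step_cache * decay_rate + (1.0 - decay_rate) * grad ** 2
dx = -(learning_rate * grad) / np.sqrt(step_cache + smooth_eps)
param += dx
```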
self.update = ['Wah', 'bah', 'WLSTM', 'Wd', 'bd'] 26 | self.regularize = ['Wah', 'WLSTM', 'Wd'] 27 | 28 | self.step_cache = {} 29 | 30 | """ Activation Function: Sigmoid, or tanh, or ReLu """ 31 | def fwdPass(self, Xs, params, **kwargs): 32 | predict_mode = kwargs.get('predict_mode', False) 33 | feed_recurrence = params.get('feed_recurrence', 0) 34 | 35 | Ds = Xs['diaact'] 36 | Ws = Xs['words'] 37 | 38 | # diaact input layer to hidden layer 39 | Wah = self.model['Wah'] 40 | bah = self.model['bah'] 41 | Dsh = Ds.dot(Wah) + bah 42 | 43 | WLSTM = self.model['WLSTM'] 44 | n, xd = Ws.shape 45 | 46 | d = self.model['Wd'].shape[0] # size of hidden layer 47 | Hin = np.zeros((n, WLSTM.shape[0])) # xt, ht-1, bias 48 | Hout = np.zeros((n, d)) 49 | IFOG = np.zeros((n, 4*d)) 50 | IFOGf = np.zeros((n, 4*d)) # after nonlinearity 51 | Cellin = np.zeros((n, d)) 52 | Cellout = np.zeros((n, d)) 53 | 54 | for t in xrange(n): 55 | prev = np.zeros(d) if t==0 else Hout[t-1] 56 | Hin[t,0] = 1 # bias 57 | Hin[t, 1:1+xd] = Ws[t] 58 | Hin[t, 1+xd:] = prev 59 | 60 | # compute all gate activations. dots: 61 | IFOG[t] = Hin[t].dot(WLSTM) 62 | 63 | # add diaact vector here 64 | if feed_recurrence == 0: 65 | if t == 0: IFOG[t] += Dsh[0] 66 | else: 67 | IFOG[t] += Dsh[0] 68 | 69 | IFOGf[t, :3*d] = 1/(1+np.exp(-IFOG[t, :3*d])) # sigmoids; these are three gates 70 | IFOGf[t, 3*d:] = np.tanh(IFOG[t, 3*d:]) # tanh for input value 71 | 72 | Cellin[t] = IFOGf[t, :d] * IFOGf[t, 3*d:] 73 | if t>0: Cellin[t] += IFOGf[t, d:2*d]*Cellin[t-1] 74 | 75 | Cellout[t] = np.tanh(Cellin[t]) 76 | 77 | Hout[t] = IFOGf[t, 2*d:3*d] * Cellout[t] 78 | 79 | Wd = self.model['Wd'] 80 | bd = self.model['bd'] 81 | 82 | Y = Hout.dot(Wd)+bd 83 | 84 | cache = {} 85 | if not predict_mode: 86 | cache['WLSTM'] = WLSTM 87 | cache['Hout'] = Hout 88 | cache['WLSTM'] = WLSTM 89 | cache['Wd'] = Wd 90 | cache['IFOGf'] = IFOGf 91 | cache['IFOG'] = IFOG 92 | cache['Cellin'] = Cellin 93 | cache['Cellout'] = Cellout 94 | cache['Ws'] = Ws 95 | cache['Ds'] = Ds 96 | cache['Hin'] = Hin 97 | cache['Dsh'] = Dsh 98 | cache['Wah'] = Wah 99 | cache['feed_recurrence'] = feed_recurrence 100 | 101 | return Y, cache 102 | 103 | """ Forward pass on prediction """ 104 | def forward(self, dict, Xs, params, **kwargs): 105 | max_len = params.get('max_len', 30) 106 | feed_recurrence = params.get('feed_recurrence', 0) 107 | decoder_sampling = params.get('decoder_sampling', 0) 108 | 109 | Ds = Xs['diaact'] 110 | Ws = Xs['words'] 111 | 112 | # diaact input layer to hidden layer 113 | Wah = self.model['Wah'] 114 | bah = self.model['bah'] 115 | Dsh = Ds.dot(Wah) + bah 116 | 117 | WLSTM = self.model['WLSTM'] 118 | xd = Ws.shape[1] 119 | 120 | d = self.model['Wd'].shape[0] # size of hidden layer 121 | Hin = np.zeros((1, WLSTM.shape[0])) # xt, ht-1, bias 122 | Hout = np.zeros((1, d)) 123 | IFOG = np.zeros((1, 4*d)) 124 | IFOGf = np.zeros((1, 4*d)) # after nonlinearity 125 | Cellin = np.zeros((1, d)) 126 | Cellout = np.zeros((1, d)) 127 | 128 | Wd = self.model['Wd'] 129 | bd = self.model['bd'] 130 | 131 | Hin[0,0] = 1 # bias 132 | Hin[0,1:1+xd] = Ws[0] 133 | 134 | IFOG[0] = Hin[0].dot(WLSTM) 135 | IFOG[0] += Dsh[0] 136 | 137 | IFOGf[0, :3*d] = 1/(1+np.exp(-IFOG[0, :3*d])) # sigmoids; these are three gates 138 | IFOGf[0, 3*d:] = np.tanh(IFOG[0, 3*d:]) # tanh for input value 139 | 140 | Cellin[0] = IFOGf[0, :d] * IFOGf[0, 3*d:] 141 | Cellout[0] = np.tanh(Cellin[0]) 142 | Hout[0] = IFOGf[0, 2*d:3*d] * Cellout[0] 143 | 144 | pred_y = [] 145 | pred_words = [] 146 | 147 | Y = Hout.dot(Wd) + bd 148 | 
maxes = np.amax(Y, axis=1, keepdims=True) 149 | e = np.exp(Y - maxes) # for numerical stability shift into good numerical range 150 | probs = e/np.sum(e, axis=1, keepdims=True) 151 | 152 | if decoder_sampling == 0: # sampling or argmax 153 | pred_y_index = np.nanargmax(Y) 154 | else: 155 | pred_y_index = np.random.choice(Y.shape[1], 1, p=probs[0])[0] 156 | pred_y.append(pred_y_index) 157 | pred_words.append(dict[pred_y_index]) 158 | 159 | time_stamp = 0 160 | while True: 161 | if dict[pred_y_index] == 'e_o_s' or time_stamp >= max_len: break 162 | 163 | X = np.zeros(xd) 164 | X[pred_y_index] = 1 165 | Hin[0,0] = 1 # bias 166 | Hin[0,1:1+xd] = X 167 | Hin[0, 1+xd:] = Hout[0] 168 | 169 | IFOG[0] = Hin[0].dot(WLSTM) 170 | if feed_recurrence == 1: 171 | IFOG[0] += Dsh[0] 172 | 173 | IFOGf[0, :3*d] = 1/(1+np.exp(-IFOG[0, :3*d])) # sigmoids; these are three gates 174 | IFOGf[0, 3*d:] = np.tanh(IFOG[0, 3*d:]) # tanh for input value 175 | 176 | C = IFOGf[0, :d]*IFOGf[0, 3*d:] 177 | Cellin[0] = C + IFOGf[0, d:2*d]*Cellin[0] 178 | Cellout[0] = np.tanh(Cellin[0]) 179 | Hout[0] = IFOGf[0, 2*d:3*d]*Cellout[0] 180 | 181 | Y = Hout.dot(Wd) + bd 182 | maxes = np.amax(Y, axis=1, keepdims=True) 183 | e = np.exp(Y - maxes) # for numerical stability shift into good numerical range 184 | probs = e/np.sum(e, axis=1, keepdims=True) 185 | 186 | if decoder_sampling == 0: 187 | pred_y_index = np.nanargmax(Y) 188 | else: 189 | pred_y_index = np.random.choice(Y.shape[1], 1, p=probs[0])[0] 190 | pred_y.append(pred_y_index) 191 | pred_words.append(dict[pred_y_index]) 192 | 193 | time_stamp += 1 194 | 195 | return pred_y, pred_words 196 | 197 | """ Forward pass on prediction with Beam Search """ 198 | def beam_forward(self, dict, Xs, params, **kwargs): 199 | max_len = params.get('max_len', 30) 200 | feed_recurrence = params.get('feed_recurrence', 0) 201 | beam_size = params.get('beam_size', 10) 202 | decoder_sampling = params.get('decoder_sampling', 0) 203 | temp = params.get('temp', 1.) 
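Both the greedy/sampled loop above and `beam_forward` below turn the decoder scores into a distribution with a numerically stable, optionally temperature-scaled softmax, then either take the argmax or sample from it. A small sketch of that step on toy scores:

```python
# Sketch: stable (temperature-scaled) softmax over decoder scores, then the
# greedy vs. sampled choice made in forward()/beam_forward() (toy logits).
import numpy as np

Y = np.array([[2.0, 0.5, -1.0, 0.1]])   # hypothetical scores over a 4-word vocab
temp = 1.0

maxes = np.amax(Y, axis=1, keepdims=True)
e = np.exp((Y - maxes) / temp)           # shift by the max for numerical stability
probs = e / np.sum(e, axis=1, keepdims=True)

greedy_index = int(np.nanargmax(Y))                             # decoder_sampling == 0
sampled_index = int(np.random.choice(Y.shape[1], p=probs[0]))   # decoder_sampling != 0
print((greedy_index, sampled_index))
```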
204 | 205 | Ds = Xs['diaact'] 206 | Ws = Xs['words'] 207 | 208 | # diaact input layer to hidden layer 209 | Wah = self.model['Wah'] 210 | bah = self.model['bah'] 211 | Dsh = Ds.dot(Wah) + bah 212 | 213 | WLSTM = self.model['WLSTM'] 214 | xd = Ws.shape[1] 215 | 216 | d = self.model['Wd'].shape[0] # size of hidden layer 217 | Hin = np.zeros((1, WLSTM.shape[0])) # xt, ht-1, bias 218 | Hout = np.zeros((1, d)) 219 | IFOG = np.zeros((1, 4*d)) 220 | IFOGf = np.zeros((1, 4*d)) # after nonlinearity 221 | Cellin = np.zeros((1, d)) 222 | Cellout = np.zeros((1, d)) 223 | 224 | Wd = self.model['Wd'] 225 | bd = self.model['bd'] 226 | 227 | Hin[0,0] = 1 # bias 228 | Hin[0,1:1+xd] = Ws[0] 229 | 230 | IFOG[0] = Hin[0].dot(WLSTM) 231 | IFOG[0] += Dsh[0] 232 | 233 | IFOGf[0, :3*d] = 1/(1+np.exp(-IFOG[0, :3*d])) # sigmoids; these are three gates 234 | IFOGf[0, 3*d:] = np.tanh(IFOG[0, 3*d:]) # tanh for input value 235 | 236 | Cellin[0] = IFOGf[0, :d] * IFOGf[0, 3*d:] 237 | Cellout[0] = np.tanh(Cellin[0]) 238 | Hout[0] = IFOGf[0, 2*d:3*d] * Cellout[0] 239 | 240 | # keep a beam here 241 | beams = [] 242 | 243 | Y = Hout.dot(Wd) + bd 244 | maxes = np.amax(Y, axis=1, keepdims=True) 245 | e = np.exp((Y - maxes)/temp) # for numerical stability shift into good numerical range 246 | probs = e/np.sum(e, axis=1, keepdims=True) 247 | 248 | # add beam search here 249 | if decoder_sampling == 0: # no sampling 250 | beam_candidate_t = (-probs[0]).argsort()[:beam_size] 251 | else: 252 | beam_candidate_t = np.random.choice(Y.shape[1], beam_size, p=probs[0]) 253 | #beam_candidate_t = (-probs[0]).argsort()[:beam_size] 254 | for ele in beam_candidate_t: 255 | beams.append((np.log(probs[0][ele]), [ele], [dict[ele]], Hout[0], Cellin[0])) 256 | 257 | #beams.sort(key=lambda x:x[0], reverse=True) 258 | #beams.sort(reverse = True) 259 | 260 | time_stamp = 0 261 | while True: 262 | beam_candidates = [] 263 | for b in beams: 264 | log_prob = b[0] 265 | pred_y_index = b[1][-1] 266 | cell_in = b[4] 267 | hout_prev = b[3] 268 | 269 | if b[2][-1] == "e_o_s": # this beam predicted end token. 
Keep in the candidates but don't expand it out any more 270 | beam_candidates.append(b) 271 | continue 272 | 273 | X = np.zeros(xd) 274 | X[pred_y_index] = 1 275 | Hin[0,0] = 1 # bias 276 | Hin[0,1:1+xd] = X 277 | Hin[0, 1+xd:] = hout_prev 278 | 279 | IFOG[0] = Hin[0].dot(WLSTM) 280 | if feed_recurrence == 1: IFOG[0] += Dsh[0] 281 | 282 | IFOGf[0, :3*d] = 1/(1+np.exp(-IFOG[0, :3*d])) # sigmoids; these are three gates 283 | IFOGf[0, 3*d:] = np.tanh(IFOG[0, 3*d:]) # tanh for input value 284 | 285 | C = IFOGf[0, :d]*IFOGf[0, 3*d:] 286 | cell_in = C + IFOGf[0, d:2*d]*cell_in 287 | cell_out = np.tanh(cell_in) 288 | hout_prev = IFOGf[0, 2*d:3*d]*cell_out 289 | 290 | Y = hout_prev.dot(Wd) + bd 291 | maxes = np.amax(Y, axis=1, keepdims=True) 292 | e = np.exp((Y - maxes)/temp) # for numerical stability shift into good numerical range 293 | probs = e/np.sum(e, axis=1, keepdims=True) 294 | 295 | if decoder_sampling == 0: # no sampling 296 | beam_candidate_t = (-probs[0]).argsort()[:beam_size] 297 | else: 298 | beam_candidate_t = np.random.choice(Y.shape[1], beam_size, p=probs[0]) 299 | #beam_candidate_t = (-probs[0]).argsort()[:beam_size] 300 | for ele in beam_candidate_t: 301 | beam_candidates.append((log_prob+np.log(probs[0][ele]), np.append(b[1], ele), np.append(b[2], dict[ele]), hout_prev, cell_in)) 302 | 303 | beam_candidates.sort(key=lambda x:x[0], reverse=True) 304 | #beam_candidates.sort(reverse = True) # decreasing order 305 | beams = beam_candidates[:beam_size] 306 | time_stamp += 1 307 | 308 | if time_stamp >= max_len: break 309 | 310 | return beams[0][1], beams[0][2] 311 | 312 | """ Backward Pass """ 313 | def bwdPass(self, dY, cache): 314 | Wd = cache['Wd'] 315 | Hout = cache['Hout'] 316 | IFOG = cache['IFOG'] 317 | IFOGf = cache['IFOGf'] 318 | Cellin = cache['Cellin'] 319 | Cellout = cache['Cellout'] 320 | Hin = cache['Hin'] 321 | WLSTM = cache['WLSTM'] 322 | Ws = cache['Ws'] 323 | Ds = cache['Ds'] 324 | Dsh = cache['Dsh'] 325 | Wah = cache['Wah'] 326 | feed_recurrence = cache['feed_recurrence'] 327 | 328 | n,d = Hout.shape 329 | 330 | # backprop the hidden-output layer 331 | dWd = Hout.transpose().dot(dY) 332 | dbd = np.sum(dY, axis=0, keepdims = True) 333 | dHout = dY.dot(Wd.transpose()) 334 | 335 | # backprop the LSTM 336 | dIFOG = np.zeros(IFOG.shape) 337 | dIFOGf = np.zeros(IFOGf.shape) 338 | dWLSTM = np.zeros(WLSTM.shape) 339 | dHin = np.zeros(Hin.shape) 340 | dCellin = np.zeros(Cellin.shape) 341 | dCellout = np.zeros(Cellout.shape) 342 | dWs = np.zeros(Ws.shape) 343 | 344 | dDsh = np.zeros(Dsh.shape) 345 | 346 | for t in reversed(xrange(n)): 347 | dIFOGf[t,2*d:3*d] = Cellout[t] * dHout[t] 348 | dCellout[t] = IFOGf[t,2*d:3*d] * dHout[t] 349 | 350 | dCellin[t] += (1-Cellout[t]**2) * dCellout[t] 351 | 352 | if t>0: 353 | dIFOGf[t, d:2*d] = Cellin[t-1] * dCellin[t] 354 | dCellin[t-1] += IFOGf[t,d:2*d] * dCellin[t] 355 | 356 | dIFOGf[t, :d] = IFOGf[t,3*d:] * dCellin[t] 357 | dIFOGf[t,3*d:] = IFOGf[t, :d] * dCellin[t] 358 | 359 | # backprop activation functions 360 | dIFOG[t, 3*d:] = (1-IFOGf[t, 3*d:]**2) * dIFOGf[t, 3*d:] 361 | y = IFOGf[t, :3*d] 362 | dIFOG[t, :3*d] = (y*(1-y)) * dIFOGf[t, :3*d] 363 | 364 | # backprop matrix multiply 365 | dWLSTM += np.outer(Hin[t], dIFOG[t]) 366 | dHin[t] = dIFOG[t].dot(WLSTM.transpose()) 367 | 368 | if t > 0: dHout[t-1] += dHin[t,1+Ws.shape[1]:] 369 | 370 | if feed_recurrence == 0: 371 | if t == 0: dDsh[t] = dIFOG[t] 372 | else: 373 | dDsh[0] += dIFOG[t] 374 | 375 | # backprop to the diaact-hidden connections 376 | dWah = Ds.transpose().dot(dDsh) 
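The expansion loop above keeps each hypothesis as a tuple of (cumulative log-probability, token ids, words, hidden state, cell state); hypotheses that already emitted `e_o_s` are carried over without expansion, and after every step the candidate list is sorted by log-probability and truncated to `beam_size`. A stripped-down sketch of that bookkeeping, assuming a toy, history-independent next-token distribution:

```python
# Sketch of the beam bookkeeping in beam_forward: expand live beams, keep
# finished ones, sort by cumulative log prob, keep the top beam_size.
import numpy as np

probs = np.array([0.5, 0.3, 0.2])   # hypothetical 3-word vocab; token 2 = <eos>
EOS, beam_size, max_len = 2, 2, 4

beams = [(0.0, [])]                 # (cumulative log prob, token ids)
for _ in range(max_len):
    candidates = []
    for log_p, seq in beams:
        if seq and seq[-1] == EOS:  # finished beams are kept but not expanded
            candidates.append((log_p, seq))
            continue
        for tok in range(len(probs)):
            candidates.append((log_p + np.log(probs[tok]), seq + [tok]))
    candidates.sort(key=lambda x: x[0], reverse=True)
    beams = candidates[:beam_size]

print(beams[0])   # best hypothesis and its log probability
```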
377 | dbah = np.sum(dDsh, axis=0, keepdims = True) 378 | 379 | return {'Wah':dWah, 'bah':dbah, 'WLSTM':dWLSTM, 'Wd':dWd, 'bd':dbd} 380 | 381 | 382 | """ Batch data representation """ 383 | def prepare_input_rep(self, ds, batch, params): 384 | batch_reps = [] 385 | for i,x in enumerate(batch): 386 | batch_rep = {} 387 | 388 | vec = np.zeros((1, self.model['Wah'].shape[0])) 389 | vec[0][x['diaact_rep']] = 1 390 | for v in x['slotrep']: 391 | vec[0][v] = 1 392 | 393 | word_arr = x['sentence'].split(' ') 394 | word_vecs = np.zeros((len(word_arr), self.model['Wxh'].shape[0])) 395 | labels = [0] * (len(word_arr)-1) 396 | for w_index, w in enumerate(word_arr[:-1]): 397 | if w in ds.data['word_dict'].keys(): 398 | w_dict_index = ds.data['word_dict'][w] 399 | word_vecs[w_index][w_dict_index] = 1 400 | 401 | if word_arr[w_index+1] in ds.data['word_dict'].keys(): 402 | labels[w_index] = ds.data['word_dict'][word_arr[w_index+1]] 403 | 404 | batch_rep['diaact'] = vec 405 | batch_rep['words'] = word_vecs 406 | batch_rep['labels'] = labels 407 | batch_reps.append(batch_rep) 408 | return batch_reps 409 | -------------------------------------------------------------------------------- /deep_dialog/usersims/NLG/decoders/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | import math 5 | import numpy as np 6 | 7 | 8 | def initWeights(n,d): 9 | """ Initialization Strategy """ 10 | #scale_factor = 0.1 11 | scale_factor = math.sqrt(float(6)/(n + d)) 12 | return (np.random.rand(n,d)*2-1)*scale_factor 13 | 14 | def mergeDicts(d0, d1): 15 | """ for all k in d0, d0 += d1 . d's are dictionaries of key -> numpy array """ 16 | for k in d1: 17 | if k in d0: d0[k] += d1[k] 18 | else: d0[k] = d1[k] 19 | -------------------------------------------------------------------------------- /deep_dialog/usersims/NLG/eval/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $ 4 | 5 | '''Provides: 6 | 7 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 8 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 9 | score_cooked(alltest, n=4): Score a list of cooked test sentences. 10 | 11 | score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids. 12 | 13 | The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible. 14 | ''' 15 | 16 | import optparse 17 | import sys, math, re, xml.sax.saxutils 18 | sys.path.append('/fs/clip-mteval/Programs/hiero') 19 | import dataset 20 | import log 21 | 22 | # Added to bypass NIST-style pre-processing of hyp and ref files -- wade 23 | nonorm = 0 24 | 25 | preserve_case = False 26 | eff_ref_len = "shortest" 27 | 28 | normalize1 = [ 29 | ('', ''), # strip "skipped" tags 30 | (r'-\n', ''), # strip end-of-line hyphenation and join lines 31 | (r'\n', ' '), # join lines 32 | # (r'(\d)\s+(?=\d)', r'\1'), # join digits 33 | ] 34 | normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1] 35 | 36 | normalize2 = [ 37 | (r'([\{-\~\[-\` -\&\(-\+\:-\@\/])',r' \1 '), # tokenize punctuation. 
apostrophe is missing 38 | (r'([^0-9])([\.,])',r'\1 \2 '), # tokenize period and comma unless preceded by a digit 39 | (r'([\.,])([^0-9])',r' \1 \2'), # tokenize period and comma unless followed by a digit 40 | (r'([0-9])(-)',r'\1 \2 ') # tokenize dash when preceded by a digit 41 | ] 42 | normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2] 43 | 44 | def normalize(s): 45 | '''Normalize and tokenize text. This is lifted from NIST mteval-v11a.pl.''' 46 | # Added to bypass NIST-style pre-processing of hyp and ref files -- wade 47 | if (nonorm): 48 | return s.split() 49 | if type(s) is not str: 50 | s = " ".join(s) 51 | # language-independent part: 52 | for (pattern, replace) in normalize1: 53 | s = re.sub(pattern, replace, s) 54 | s = xml.sax.saxutils.unescape(s, {'"':'"'}) 55 | # language-dependent part (assuming Western languages): 56 | s = " %s " % s 57 | if not preserve_case: 58 | s = s.lower() # this might not be identical to the original 59 | for (pattern, replace) in normalize2: 60 | s = re.sub(pattern, replace, s) 61 | return s.split() 62 | 63 | def count_ngrams(words, n=4): 64 | counts = {} 65 | for k in xrange(1,n+1): 66 | for i in xrange(len(words)-k+1): 67 | ngram = tuple(words[i:i+k]) 68 | counts[ngram] = counts.get(ngram, 0)+1 69 | return counts 70 | 71 | def cook_refs(refs, n=4): 72 | '''Takes a list of reference sentences for a single segment 73 | and returns an object that encapsulates everything that BLEU 74 | needs to know about them.''' 75 | 76 | refs = [normalize(ref) for ref in refs] 77 | maxcounts = {} 78 | for ref in refs: 79 | counts = count_ngrams(ref, n) 80 | for (ngram,count) in counts.iteritems(): 81 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 82 | return ([len(ref) for ref in refs], maxcounts) 83 | 84 | def cook_test(test, (reflens, refmaxcounts), n=4): 85 | '''Takes a test sentence and returns an object that 86 | encapsulates everything that BLEU needs to know about it.''' 87 | 88 | test = normalize(test) 89 | result = {} 90 | result["testlen"] = len(test) 91 | 92 | # Calculate effective reference sentence length. 
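`cook_refs` above stores, for every n-gram, the maximum count seen in any reference, and `cook_test` below clips the candidate's n-gram counts against those maxima -- BLEU's modified n-gram precision. A toy sketch of that clipping for unigrams:

```python
# Sketch of clipped (modified) n-gram counts, as in cook_refs/cook_test,
# shown for unigrams on toy sentences.
from collections import Counter

def count_ngrams(words, n=1):
    counts = Counter()
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            counts[tuple(words[i:i + k])] += 1
    return counts

refs = [['the', 'cat', 'is', 'on', 'the', 'mat']]
test = ['the', 'the', 'the', 'cat']

maxcounts = Counter()
for ref in refs:
    for ngram, c in count_ngrams(ref).items():
        maxcounts[ngram] = max(maxcounts[ngram], c)

correct = sum(min(maxcounts[ng], c) for ng, c in count_ngrams(test).items())
print((correct, len(test)))   # 3 correct out of 4: 'the' is clipped at 2
```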
93 | 94 | if eff_ref_len == "shortest": 95 | result["reflen"] = min(reflens) 96 | elif eff_ref_len == "average": 97 | result["reflen"] = float(sum(reflens))/len(reflens) 98 | elif eff_ref_len == "closest": 99 | min_diff = None 100 | for reflen in reflens: 101 | if min_diff is None or abs(reflen-len(test)) < min_diff: 102 | min_diff = abs(reflen-len(test)) 103 | result['reflen'] = reflen 104 | 105 | result["guess"] = [max(len(test)-k+1,0) for k in xrange(1,n+1)] 106 | 107 | result['correct'] = [0]*n 108 | counts = count_ngrams(test, n) 109 | for (ngram, count) in counts.iteritems(): 110 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 111 | 112 | return result 113 | 114 | def score_cooked(allcomps, n=4): 115 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 116 | for comps in allcomps: 117 | for key in ['testlen','reflen']: 118 | totalcomps[key] += comps[key] 119 | for key in ['guess','correct']: 120 | for k in xrange(n): 121 | totalcomps[key][k] += comps[key][k] 122 | logbleu = 0.0 123 | for k in xrange(n): 124 | if totalcomps['correct'][k] == 0: 125 | return 0.0 126 | log.write("%d-grams: %f\n" % (k,float(totalcomps['correct'][k])/totalcomps['guess'][k])) 127 | logbleu += math.log(totalcomps['correct'][k])-math.log(totalcomps['guess'][k]) 128 | logbleu /= float(n) 129 | log.write("Effective reference length: %d test length: %d\n" % (totalcomps['reflen'], totalcomps['testlen'])) 130 | logbleu += min(0,1-float(totalcomps['reflen'])/totalcomps['testlen']) 131 | return math.exp(logbleu) 132 | 133 | def score_set(set, testid, refids, n=4): 134 | alltest = [] 135 | for seg in set.segs(): 136 | try: 137 | test = seg.versions[testid].words 138 | except KeyError: 139 | log.write("Warning: missing test sentence\n") 140 | continue 141 | try: 142 | refs = [seg.versions[refid].words for refid in refids] 143 | except KeyError: 144 | log.write("Warning: missing reference sentence, %s\n" % seg.id) 145 | refs = cook_refs(refs, n) 146 | alltest.append(cook_test(test, refs, n)) 147 | log.write("%d sentences\n" % len(alltest)) 148 | return score_cooked(alltest, n) 149 | 150 | if __name__ == "__main__": 151 | import psyco 152 | psyco.full() 153 | 154 | import getopt 155 | raw_test = False 156 | (opts,args) = getopt.getopt(sys.argv[1:], "rc", []) 157 | for (opt,parm) in opts: 158 | if opt == "-r": 159 | raw_test = True 160 | elif opt == "-c": 161 | preserve_case = True 162 | 163 | s = dataset.Dataset() 164 | if args[0] == '-': 165 | infile = sys.stdin 166 | else: 167 | infile = args[0] 168 | if raw_test: 169 | (root, testids) = s.read_raw(infile, docid='whatever', sysid='testsys') 170 | else: 171 | (root, testids) = s.read(infile) 172 | print "Test systems: %s" % ", ".join(testids) 173 | (root, refids) = s.read(args[1]) 174 | print "Reference systems: %s" % ", ".join(refids) 175 | 176 | for testid in testids: 177 | print "BLEU score: ", score_set(s, testid, refids) 178 | 179 | 180 | -------------------------------------------------------------------------------- /deep_dialog/usersims/NLG/eval/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | 3 | # $Id$ 4 | use strict; 5 | 6 | my $lowercase = 0; 7 | if ($ARGV[0] eq "-lc") { 8 | $lowercase = 1; 9 | shift; 10 | } 11 | 12 | my $stem = $ARGV[0]; 13 | if (!defined $stem) { 14 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 15 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 16 | 
exit(1); 17 | } 18 | 19 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 20 | 21 | my @REF; 22 | my $ref=0; 23 | while(-e "$stem$ref") { 24 | &add_to_ref("$stem$ref",\@REF); 25 | $ref++; 26 | } 27 | &add_to_ref($stem,\@REF) if -e $stem; 28 | die("ERROR: could not find reference file $stem") unless scalar @REF; 29 | 30 | sub add_to_ref { 31 | my ($file,$REF) = @_; 32 | my $s=0; 33 | open(REF,$file) or die "Can't read $file"; 34 | while() { 35 | chop; 36 | push @{$$REF[$s++]}, $_; 37 | } 38 | close(REF); 39 | } 40 | 41 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 42 | my $s=0; 43 | while() { 44 | chop; 45 | $_ = lc if $lowercase; 46 | my @WORD = split; 47 | my %REF_NGRAM = (); 48 | my $length_translation_this_sentence = scalar(@WORD); 49 | my ($closest_diff,$closest_length) = (9999,9999); 50 | foreach my $reference (@{$REF[$s]}) { 51 | # print "$s $_ <=> $reference\n"; 52 | $reference = lc($reference) if $lowercase; 53 | my @WORD = split(' ',$reference); 54 | my $length = scalar(@WORD); 55 | my $diff = abs($length_translation_this_sentence-$length); 56 | if ($diff < $closest_diff) { 57 | $closest_diff = $diff; 58 | $closest_length = $length; 59 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 60 | } elsif ($diff == $closest_diff) { 61 | $closest_length = $length if $length < $closest_length; 62 | # from two references with the same closeness to me 63 | # take the *shorter* into account, not the "first" one. 64 | } 65 | for(my $n=1;$n<=4;$n++) { 66 | my %REF_NGRAM_N = (); 67 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 68 | my $ngram = "$n"; 69 | for(my $w=0;$w<$n;$w++) { 70 | $ngram .= " ".$WORD[$start+$w]; 71 | } 72 | $REF_NGRAM_N{$ngram}++; 73 | } 74 | foreach my $ngram (keys %REF_NGRAM_N) { 75 | if (!defined($REF_NGRAM{$ngram}) || 76 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 77 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 78 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 79 | } 80 | } 81 | } 82 | } 83 | $length_translation += $length_translation_this_sentence; 84 | $length_reference += $closest_length; 85 | for(my $n=1;$n<=4;$n++) { 86 | my %T_NGRAM = (); 87 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 88 | my $ngram = "$n"; 89 | for(my $w=0;$w<$n;$w++) { 90 | $ngram .= " ".$WORD[$start+$w]; 91 | } 92 | $T_NGRAM{$ngram}++; 93 | } 94 | foreach my $ngram (keys %T_NGRAM) { 95 | $ngram =~ /^(\d+) /; 96 | my $n = $1; 97 | # my $corr = 0; 98 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 99 | $TOTAL[$n] += $T_NGRAM{$ngram}; 100 | if (defined($REF_NGRAM{$ngram})) { 101 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 102 | $CORRECT[$n] += $T_NGRAM{$ngram}; 103 | # $corr = $T_NGRAM{$ngram}; 104 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 105 | } 106 | else { 107 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 108 | # $corr = $REF_NGRAM{$ngram}; 109 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 110 | } 111 | } 112 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 113 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 114 | } 115 | } 116 | $s++; 117 | } 118 | my $brevity_penalty = 1; 119 | my $bleu = 0; 120 | 121 | my @bleu=(); 122 | 123 | for(my $n=1;$n<=4;$n++) { 124 | if (defined ($TOTAL[$n])){ 125 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 126 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 127 | }else{ 128 | $bleu[$n]=0; 129 | } 130 | } 131 | 132 | if ($length_reference==0){ 133 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 134 | exit(1); 135 | } 136 | 137 | #if ($length_translation<$length_reference) { 138 | # $brevity_penalty = exp(1-$length_reference/$length_translation); 139 | #} 140 | 141 | #$bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 142 | # my_log( $bleu[2] ) + 143 | # my_log( $bleu[3] ) + 144 | # my_log( $bleu[4] ) ) / 4) ; 145 | 146 | my $bleu_1 = $brevity_penalty * exp((my_log( $bleu[1] ))); 147 | 148 | my $bleu_2 = $brevity_penalty * exp((my_log( $bleu[1] ) + 149 | my_log( $bleu[2] ) ) / 2) ; 150 | 151 | my $bleu_3 = $brevity_penalty * exp((my_log( $bleu[1] ) + 152 | my_log( $bleu[2] ) + 153 | my_log( $bleu[3] ) ) / 3) ; 154 | 155 | my $bleu_4 = $brevity_penalty * exp((my_log( $bleu[1] ) + 156 | my_log( $bleu[2] ) + 157 | my_log( $bleu[3] ) + 158 | my_log( $bleu[4] ) ) / 4) ; 159 | 160 | printf "BLEU = %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu_1, 162 | 100*$bleu_2, 163 | 100*$bleu_3, 164 | 100*$bleu_4, 165 | $brevity_penalty, 166 | $length_translation / $length_reference, 167 | $length_translation, 168 | $length_reference; 169 | 170 | sub my_log { 171 | return -9999999999 unless $_[0]; 172 | return log($_[0]); 173 | } 174 | 175 | 176 | -------------------------------------------------------------------------------- /deep_dialog/usersims/NLG/fileio/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_set import * -------------------------------------------------------------------------------- /deep_dialog/usersims/__init__.py: -------------------------------------------------------------------------------- 1 | from .usersim_rule import * 2 | from .template_nlg import * 3 | from .s2s_nlg import * 4 | from .user_cmd import * 5 | -------------------------------------------------------------------------------- /deep_dialog/usersims/s2s_nlg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | Takes user action and produces NL utterance using Xiujun's NLG 4 | Backs off to template nlg 5 | ''' 6 | 7 | import cPickle as pkl 8 | import random 9 | import copy 10 | import sys 11 | 12 | from NLG import predict 13 | from NLG.decoders.lstm_decoder_tanh import lstm_decoder_tanh 14 | from NLG.decoders.decoder import decoder 15 | 16 | BEAM_SIZE = 3 17 | SAMPLING = 1 18 | 19 | class S2SNLG: 20 | def __init__(self, template_file, slot_file, model_file, temp): 21 | self.templates = pkl.load(open(template_file, 'rb')) 22 | self._read_slots(slot_file) 23 | self._load_model(model_file, temp) 24 | 25 | def _load_model(self, model_path, temp): 26 | model_params = pkl.load(open(model_path, 'rb')) 27 | hidden_size = model_params['model']['Wd'].shape[0] 28 | output_size = model_params['model']['Wd'].shape[1] 29 | 30 | model_params['params']['beam_size'] = BEAM_SIZE 31 | model_params['params']['decoder_sampling'] = SAMPLING 32 | 
model_params['params']['temp'] = temp 33 | 34 | if model_params['params']['model'] == 'lstm_tanh': # lstm_tanh 35 | diaact_input_size = model_params['model']['Wah'].shape[0] 36 | input_size = model_params['model']['WLSTM'].shape[0] - hidden_size - 1 37 | self.rnnmodel = lstm_decoder_tanh(diaact_input_size, input_size, hidden_size, output_size) 38 | self.rnnmodel.model = copy.deepcopy(model_params['model']) 39 | self.model_params = model_params 40 | 41 | def _read_slots(self, fil): 42 | f = open(fil,'r') 43 | self.slots = [] 44 | for line in f: 45 | self.slots.append(line.rstrip()) 46 | 47 | def generate(self, act, request_slots, inform_slots): 48 | if all([r in self.slots for r in request_slots.keys()]) and \ 49 | all([i in self.slots for i in inform_slots.keys()]): 50 | return self.generate_from_nlg(act, request_slots, inform_slots) 51 | else: 52 | return self.generate_from_template(act, request_slots, inform_slots) 53 | 54 | def generate_from_nlg(self, act, request_slots, inform_slots): 55 | act_string = act + '(' 56 | for s,v in request_slots.iteritems(): 57 | act_string += '%s=%s;' % (s,v) if v!='UNK' else '%s;' %s 58 | i_slots = {k:v for k,v in inform_slots.iteritems() if v is not None} 59 | for s,v in i_slots.iteritems(): 60 | act_string += '%s=%s;' % (s,v) 61 | act_string = act_string.rstrip(';') 62 | act_string += ')' 63 | sent = predict.generate(self.model_params, self.rnnmodel, act_string) 64 | try: 65 | out = unicode(sent) 66 | except UnicodeDecodeError: 67 | out = unicode(sent.decode('utf8')) 68 | return out 69 | 70 | def generate_from_template(self, act, request_slots, inform_slots): 71 | n_r = len(request_slots.keys()) 72 | i_slots = {k:v for k,v in inform_slots.iteritems() if v is not None} 73 | n_i = len(i_slots.keys()) 74 | key = '%s_%d_%d' % (act, n_r, n_i) 75 | 76 | temp = random.choice(self.templates[key]) 77 | sent = self._fill_slots(temp, request_slots, i_slots) 78 | 79 | return unicode(sent) 80 | 81 | def _fill_slots(self, temp, request_slots, i_slots): 82 | reqs = request_slots.keys() 83 | infs = i_slots.keys() 84 | random.shuffle(reqs) 85 | random.shuffle(infs) 86 | 87 | for i,k in enumerate(reqs): 88 | temp = temp.replace('@rslot%d'%i, k) 89 | 90 | for i,k in enumerate(infs): 91 | temp = temp.replace('@islot%d'%i, k) 92 | temp = temp.replace('@ival%d'%i, i_slots[k]) 93 | 94 | return temp 95 | 96 | if __name__=='__main__': 97 | temp_file = '../data/templates.p' 98 | slot_file = 'NLG/data/slot_set.txt' 99 | model_file = 'NLG/checkpoints/nlg_infobot/lstm_tanh_[1470015675.73]_115_120_0.657.p' 100 | 101 | acts = ['inform', 'request'] 102 | slots = ['actor', 'director', 'release_year', 'genre', 'mpaa_rating', 'critic_rating'] 103 | 104 | nlg = S2SNLG(temp_file, slot_file, model_file, 2.0) 105 | 106 | for i in range(10000): 107 | a = random.choice(acts) 108 | if a=='inform': 109 | i_slots = [random.choice(slots)] 110 | inform_slots = {} 111 | for s in i_slots: 112 | inform_slots[s] = u'blah' 113 | request_slots = {} 114 | print a, inform_slots, request_slots 115 | print nlg.generate(a, inform_slots, request_slots) 116 | else: 117 | request_slots = {} 118 | request_slots['moviename'] = 'UNK' 119 | inform_slots = {} 120 | i_slots = random.sample(slots, 2) 121 | for s in i_slots: 122 | inform_slots[s] = u'blah' 123 | print a, inform_slots, request_slots 124 | print nlg.generate(a, inform_slots, request_slots) 125 | -------------------------------------------------------------------------------- /deep_dialog/usersims/template_nlg.py: 
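For reference, `S2SNLG.generate` above backs off to templates whenever a request or inform slot falls outside the NLG model's slot set; otherwise `generate_from_nlg` serializes the dialog act into a string such as `request(moviename;director=...)` and hands it to the seq2seq predictor. A small standalone sketch of that serialization, with hypothetical slot values:

```python
# Sketch of the act-string serialization done by S2SNLG.generate_from_nlg
# (hypothetical slots/values; predict.generate consumes strings of this form).
def act_to_string(act, request_slots, inform_slots):
    s = act + '('
    for k, v in request_slots.items():
        s += '%s=%s;' % (k, v) if v != 'UNK' else '%s;' % k
    for k, v in inform_slots.items():
        if v is not None:
            s += '%s=%s;' % (k, v)
    return s.rstrip(';') + ')'

print(act_to_string('request', {'moviename': 'UNK'}, {'director': 'matthew saville'}))
# request(moviename;director=matthew saville)
```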
-------------------------------------------------------------------------------- 1 | ''' 2 | 3 | Takes user action and templates from file and produces NL utterance. 4 | ''' 5 | 6 | import cPickle as pkl 7 | import random 8 | 9 | class TemplateNLG: 10 | def __init__(self, template_file): 11 | self.templates = pkl.load(open(template_file, 'rb')) 12 | 13 | def generate(self, act, request_slots, inform_slots): 14 | n_r = len(request_slots.keys()) 15 | i_slots = {k:v for k,v in inform_slots.iteritems() if v is not None} 16 | n_i = len(i_slots.keys()) 17 | key = '%s_%d_%d' % (act, n_r, n_i) 18 | 19 | temp = random.choice(self.templates[key]) 20 | sent = self._fill_slots(temp, request_slots, i_slots) 21 | 22 | return unicode(sent) 23 | 24 | def _fill_slots(self, temp, request_slots, i_slots): 25 | reqs = request_slots.keys() 26 | infs = i_slots.keys() 27 | random.shuffle(reqs) 28 | random.shuffle(infs) 29 | 30 | for i,k in enumerate(reqs): 31 | temp = temp.replace('@rslot%d'%i, k) 32 | 33 | for i,k in enumerate(infs): 34 | temp = temp.replace('@islot%d'%i, k) 35 | temp = temp.replace('@ival%d'%i, i_slots[k]) 36 | 37 | return temp 38 | -------------------------------------------------------------------------------- /deep_dialog/usersims/user_cmd.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | a rule-based user simulator 4 | 5 | ''' 6 | 7 | import argparse, json, time 8 | import random 9 | import copy 10 | import nltk 11 | import cPickle as pkl 12 | 13 | from deep_dialog import dialog_config 14 | from deep_dialog.tools import to_tokens 15 | from collections import defaultdict 16 | 17 | DOMAIN_NAME = 'movie' 18 | 19 | GENERIC = ['I dont know', 'I cannot remember', 'I am not sure'] 20 | 21 | def weighted_choice(choices, weights): 22 | total = sum(weights) 23 | r = random.uniform(0, total) 24 | upto = 0 25 | for c, w in zip(choices,weights): 26 | if upto + w >= r: 27 | return c 28 | upto += w 29 | assert False, "shouldnt get here" 30 | 31 | class CmdUser: 32 | def __init__(self, movie_dict=None, act_set=None, slot_set=None, 33 | start_set=None, max_turn=20, err_prob=0., db=None, 34 | dk_prob=0., sub_prob=0., max_first_turn=5, 35 | fdict_path=None): 36 | self.max_turn = dialog_config.MAX_TURN 37 | self.movie_dict = movie_dict 38 | self.act_set = act_set 39 | self.slot_set = slot_set 40 | self.start_set = start_set 41 | self.err_prob = err_prob 42 | self.database = db 43 | self.dk_prob = dk_prob 44 | self.sub_prob = sub_prob 45 | self.max_first_turn = max_first_turn 46 | self.grams = self._load_vocab(fdict_path) 47 | self.N = 2 48 | 49 | def _load_vocab(self, path): 50 | if path is None: return set() 51 | else: return pkl.load(open(path,'rb')) 52 | 53 | def _vocab_search(self, text): 54 | tokens = to_tokens(text) 55 | for i in range(len(tokens)): 56 | for t in range(self.N): 57 | if i-t<0: continue 58 | ngram = '_'.join(tokens[i-t:i+1]) 59 | if ngram in self.grams: 60 | return True 61 | return False 62 | 63 | ''' show target entity and known slots (corrupted) to user 64 | and get NL input ''' 65 | def prompt_input(self, agent_act, turn): 66 | print '' 67 | print 'turn ', str(turn) 68 | print 'agent action: ', agent_act 69 | print 'target ', DOMAIN_NAME, ': ', self.database.labels[self.goal['target']] 70 | print 'known slots: ', ' '.join( 71 | ['%s={ %s }' %(k,' , '.join(vv for vv in v)) 72 | for k,v in self.state['inform_slots_noisy'].iteritems()]) 73 | inp = raw_input('your input: ') 74 | if not self._vocab_search(inp): return random.choice(GENERIC) 
75 | else: return inp 76 | 77 | ''' display agent results at end of dialog ''' 78 | def display_results(self, ranks, reward, turns): 79 | print '' 80 | print 'agent results: ', ', '.join([self.database.labels[ii] for ii in ranks[:5]]) 81 | print 'target movie rank = ', ranks.index(self.goal['target']) + 1 82 | if reward > 0: print 'successful dialog!' 83 | else: print 'failed dialog' 84 | print 'number of turns = ', str(turns) 85 | 86 | ''' randomly sample a start state ''' 87 | def _sample_action(self): 88 | self.state = {} 89 | 90 | self.state['diaact'] = '' 91 | self.state['turn'] = 0 92 | self.state['inform_slots'] = {} 93 | self.state['request_slots'] = {} 94 | self.state['prev_diaact'] = 'UNK' 95 | 96 | self.corrupt() 97 | sent = self.prompt_input('Hi! I am Info-Bot. I can help you search for movies if you tell me their attributes!', 0).lower() 98 | if sent=='quit': episode_over=True 99 | else: episode_over=False 100 | 101 | self.state['nl_sentence'] = sent 102 | self.state['episode_over'] = episode_over 103 | self.state['reward'] = 0 104 | self.state['goal'] = self.goal['target'] 105 | 106 | return episode_over, self.state 107 | 108 | ''' sample a goal ''' 109 | def _sample_goal(self): 110 | if self.start_set is not None: 111 | self.goal = random.choice(self.start_set) # sample user's goal from the dataset 112 | else: 113 | # sample a DB record as target 114 | self.goal = {} 115 | self.goal['request_slots'] = {} 116 | self.goal['request_slots'][DOMAIN_NAME] = 'UNK' 117 | self.goal['target'] = random.randint(0,self.database.N-1) 118 | self.goal['inform_slots'] = {} 119 | known_slots = [s for i,s in enumerate(dialog_config.inform_slots) 120 | if self.database.tuples[self.goal['target']][i]!='UNK'] 121 | care_about = random.sample(known_slots, int(self.dk_prob*len(known_slots))) 122 | for i,s in enumerate(self.database.slots): 123 | if s not in dialog_config.inform_slots: continue 124 | val = self.database.tuples[self.goal['target']][i] 125 | if s in care_about and val!='UNK': 126 | self.goal['inform_slots'][s] = val 127 | else: 128 | self.goal['inform_slots'][s] = None 129 | if all([v==None for v in self.goal['inform_slots'].values()]): 130 | while True: 131 | s = random.choice(self.goal['inform_slots'].keys()) 132 | i = self.database.slots.index(s) 133 | val = self.database.tuples[self.goal['target']][i] 134 | if val!='UNK': 135 | self.goal['inform_slots'][s] = val 136 | break 137 | 138 | def print_goal(self): 139 | print 'User target = ', ', '.join(['%s:%s' %(s,v) for s,v in \ 140 | zip(['movie']+self.database.slots, \ 141 | [self.database.labels[self.goal['target']]] + \ 142 | self.database.tuples[self.goal['target']])]) 143 | print 'User information = ', ', '.join(['%s:%s' %(s,v) for s,v in \ 144 | self.goal['inform_slots'].iteritems() if v is not None]), '\n' 145 | 146 | ''' initialization ''' 147 | def initialize_episode(self): 148 | self._sample_goal() 149 | 150 | # first action 151 | episode_over, user_action = self._sample_action() 152 | assert (episode_over != 1),' but we just started' 153 | return user_action 154 | 155 | ''' update state: state is sys_action ''' 156 | def next(self, state): 157 | self.state['turn'] += 1 158 | reward = 0 159 | episode_over = False 160 | self.state['prev_diaact'] = self.state['diaact'] 161 | self.state['inform_slots'].clear() 162 | self.state['request_slots'].clear() 163 | 164 | act = state['diaact'] 165 | if act == 'inform': 166 | episode_over = True 167 | goal_rank = state['target'].index(self.goal['target']) 168 | if goal_rank < 
dialog_config.SUCCESS_MAX_RANK: 169 | reward = dialog_config.SUCCESS_DIALOG_REWARD*\ 170 | (1.-float(goal_rank)/dialog_config.SUCCESS_MAX_RANK) 171 | self.state['diaact'] = 'thanks' 172 | else: 173 | reward = dialog_config.FAILED_DIALOG_REWARD 174 | self.state['diaact'] = 'deny' 175 | self.display_results(state['target'], reward, self.state['turn']) 176 | else: 177 | slot = state['request_slots'].keys()[0] 178 | agent_act = act + ' ' + slot 179 | sent = self.prompt_input(agent_act, self.state['turn']).lower() 180 | if sent=='quit' or self.state['turn'] >= self.max_turn: episode_over=True 181 | reward = 0 182 | self.state['nl_sentence'] = sent 183 | 184 | self.state['episode_over'] = episode_over 185 | self.state['reward'] = reward 186 | 187 | return self.state, episode_over, 0 188 | 189 | ''' user may make mistakes ''' 190 | def corrupt(self): 191 | self.state['inform_slots_noisy'] = {} 192 | for slot in self.goal['inform_slots'].keys(): 193 | self.state['inform_slots_noisy'][slot] = set() 194 | if self.goal['inform_slots'][slot] is not None: 195 | cset = set([self.goal['inform_slots'][slot]]) 196 | prob_sub = random.random() 197 | if prob_sub < self.sub_prob: # substitute value 198 | cset.add(random.choice(self.movie_dict.dict[slot])) 199 | for item in cset: 200 | prob_err = random.random() 201 | if prob_err < self.err_prob: # corrupt value 202 | self.state['inform_slots_noisy'][slot].update( 203 | self._corrupt_value(item)) 204 | #else: 205 | # self.state['inform_slots_noisy'][slot].add(item) 206 | self.state['inform_slots_noisy'][slot].add(item) 207 | 208 | def _corrupt_value(self, val): 209 | def _is_int(s): 210 | try: 211 | int(s) 212 | return True 213 | except ValueError: 214 | return False 215 | 216 | def _is_float(s): 217 | try: 218 | float(s) 219 | return True 220 | except ValueError: 221 | return False 222 | 223 | tokens = nltk.word_tokenize(val) 224 | if len(tokens)>1: 225 | tokens.pop(random.randrange(len(tokens))) 226 | out = set([' '.join(tokens)]) 227 | else: 228 | t = tokens[0] 229 | out = set() 230 | if _is_int(t): 231 | pert = round(random.gauss(0,0.5)) 232 | if pert>0: out.add('%d' %(int(t)+pert)) 233 | out.add(t) 234 | elif _is_float(t): 235 | pert = random.gauss(0,0.5) 236 | if pert>0.05: out.add('%.1f' %(float(t)+pert)) 237 | out.add(t) 238 | else: 239 | out.add(t) 240 | return out 241 | 242 | ''' user may make mistakes 243 | def corrupt(self): 244 | self.state['inform_slots_noisy'] = {} 245 | for slot in self.goal['inform_slots'].keys(): 246 | if self.goal['inform_slots'][slot]==None: 247 | self.state['inform_slots_noisy'][slot] = None 248 | else: 249 | prob_sub = random.random() 250 | if prob_sub < self.sub_prob: # substitute value 251 | self.state['inform_slots_noisy'][slot] = \ 252 | random.choice(self.movie_dict.dict[slot]) 253 | else: 254 | self.state['inform_slots_noisy'][slot] = self.goal['inform_slots'][slot] 255 | prob_err = random.random() 256 | if prob_err < self.err_prob: # corrupt value 257 | self.state['inform_slots_noisy'][slot] = \ 258 | self._corrupt_value(self.state['inform_slots_noisy'][slot]) 259 | 260 | def _corrupt_value(self, val): 261 | def _is_int(s): 262 | try: 263 | int(s) 264 | return True 265 | except ValueError: 266 | return False 267 | 268 | def _is_float(s): 269 | try: 270 | float(s) 271 | return True 272 | except ValueError: 273 | return False 274 | 275 | tokens = nltk.word_tokenize(val) 276 | if len(tokens)>1: tokens.pop(random.randrange(len(tokens))) 277 | out = [] 278 | for t in tokens: 279 | if _is_int(t): 280 | 
out.append(str(int(random.gauss(int(t),0.5)))) 281 | elif _is_float(t): 282 | out.append('%.1f' %random.gauss(float(t),0.5)) 283 | else: 284 | out.append(t) 285 | return ' '.join([o for o in out]) 286 | ''' 287 | 288 | ''' user state representation ''' 289 | def stateVector(self, action): 290 | vec = [0]*(len(self.act_set.dict) + len(self.slot_set.slot_ids)*2) 291 | 292 | if action['diaact'] in self.act_set.dict.keys(): vec[self.act_set.dict[action['diaact']]] = 1 293 | for slot in action['slots'].keys(): 294 | slot_id = self.slot_set.slot_ids[slot] * 2 + len(self.act_set.dict) 295 | slot_id += 1 296 | if action['slots'][slot] == 'UNK': vec[slot_id] =1 297 | 298 | return vec 299 | 300 | ''' print the state ''' 301 | def print_state(self, action): 302 | stateStr = 'Turn %d user action: %s, history slots: %s, inform_slots: %s, request slots: %s, rest_slots: %s' % (action['turn'], action['diaact'], action['history_slots'], action['inform_slots'], action['request_slots'], action['rest_slots']) 303 | print stateStr 304 | 305 | 306 | 307 | def main(params): 308 | user_sim = RuleSimulator() 309 | user_sim.init() 310 | 311 | 312 | 313 | if __name__ == "__main__": 314 | parser = argparse.ArgumentParser() 315 | 316 | args = parser.parse_args() 317 | params = vars(args) 318 | 319 | print 'User Simulator Parameters: ' 320 | print json.dumps(params, indent=2) 321 | 322 | main(params) 323 | -------------------------------------------------------------------------------- /deep_dialog/usersims/usersim_rule.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | a rule-based user simulator 4 | 5 | ''' 6 | 7 | import argparse, json, time 8 | import random 9 | import copy 10 | import nltk 11 | 12 | from deep_dialog import dialog_config 13 | 14 | DOMAIN_NAME = 'movie' 15 | 16 | def weighted_choice(choices, weights): 17 | total = sum(weights) 18 | r = random.uniform(0, total) 19 | upto = 0 20 | for c, w in zip(choices,weights): 21 | if upto + w >= r: 22 | return c 23 | upto += w 24 | assert False, "shouldnt get here" 25 | 26 | class RuleSimulator: 27 | def __init__(self, movie_dict=None, act_set=None, slot_set=None, 28 | start_set=None, max_turn=20, nlg=None, err_prob=0., db=None, 29 | dk_prob=0., sub_prob=0., max_first_turn=5): 30 | self.max_turn = dialog_config.MAX_TURN 31 | self.movie_dict = movie_dict 32 | self.act_set = act_set 33 | self.slot_set = slot_set 34 | self.start_set = start_set 35 | self.nlg = nlg 36 | self.err_prob = err_prob 37 | self.database = db 38 | self.dk_prob = dk_prob 39 | self.sub_prob = sub_prob 40 | self.max_first_turn = max_first_turn 41 | 42 | ''' randomly sample a start state ''' 43 | def _sample_action(self): 44 | self.state = {} 45 | 46 | self.state['diaact'] = random.choice(dialog_config.start_dia_acts.keys()) 47 | self.state['turn'] = 0 48 | self.state['inform_slots'] = {} 49 | self.state['request_slots'] = {} 50 | self.state['prev_diaact'] = 'UNK' 51 | 52 | if (len(self.goal['inform_slots']) + len(self.goal['request_slots'])) > 0: 53 | if len(self.goal['inform_slots']) > 0: 54 | care_about = [s for s,v in self.goal['inform_slots'].iteritems() if v is not None] 55 | known_slots = random.sample(care_about, 56 | random.randint(1,min(self.max_first_turn,len(care_about)))) 57 | for s in known_slots: 58 | self.state['inform_slots'][s] = self.goal['inform_slots'][s] 59 | 60 | if len(self.goal['request_slots']) > 0: 61 | request_slot = random.choice(self.goal['request_slots'].keys()) 62 | self.state['request_slots'][request_slot] = 'UNK' 
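        # If a closing act ('thanks'/'closing') is sampled the episode ends immediately; otherwise the
        # informed slot values are corrupted below (to simulate user mistakes) before the NLG renders
        # the user's first natural-language utterance.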
63 | 64 | if (self.state['diaact'] in ['thanks','closing']): episode_over = True 65 | else: episode_over = False 66 | 67 | if not episode_over: 68 | self.corrupt() 69 | 70 | sent = self.nlg.generate(self.state['diaact'],self.state['request_slots'], 71 | self.state['inform_slots_noisy']) if self.nlg is not None else '' 72 | self.state['nl_sentence'] = sent 73 | self.state['episode_over'] = episode_over 74 | self.state['reward'] = 0 75 | 76 | return episode_over, self.state 77 | 78 | ''' sample a goal ''' 79 | def _sample_goal(self): 80 | if self.start_set is not None: 81 | self.goal = random.choice(self.start_set) # sample user's goal from the dataset 82 | else: 83 | # sample a DB record as target 84 | self.goal = {} 85 | self.goal['request_slots'] = {} 86 | self.goal['request_slots'][DOMAIN_NAME] = 'UNK' 87 | self.goal['target'] = random.randint(0,self.database.N-1) 88 | self.goal['inform_slots'] = {} 89 | known_slots = [s for i,s in enumerate(dialog_config.inform_slots) 90 | if self.database.tuples[self.goal['target']][i]!='UNK'] 91 | care_about = random.sample(known_slots, int(self.dk_prob*len(known_slots))) 92 | for i,s in enumerate(self.database.slots): 93 | if s not in dialog_config.inform_slots: continue 94 | val = self.database.tuples[self.goal['target']][i] 95 | if s in care_about and val!='UNK': 96 | self.goal['inform_slots'][s] = val 97 | else: 98 | self.goal['inform_slots'][s] = None 99 | if all([v==None for v in self.goal['inform_slots'].values()]): 100 | while True: 101 | s = random.choice(self.goal['inform_slots'].keys()) 102 | i = self.database.slots.index(s) 103 | val = self.database.tuples[self.goal['target']][i] 104 | if val!='UNK': 105 | self.goal['inform_slots'][s] = val 106 | break 107 | 108 | def print_goal(self): 109 | print 'User target = ', ', '.join(['%s:%s' %(s,v) for s,v in \ 110 | zip(['movie']+self.database.slots, \ 111 | [self.database.labels[self.goal['target']]] + \ 112 | self.database.tuples[self.goal['target']])]) 113 | print 'User information = ', ', '.join(['%s:%s' %(s,v) for s,v in \ 114 | self.goal['inform_slots'].iteritems() if v is not None]), '\n' 115 | 116 | ''' initialization ''' 117 | def initialize_episode(self): 118 | self._sample_goal() 119 | 120 | # first action 121 | episode_over, user_action = self._sample_action() 122 | assert (episode_over != 1),' but we just started' 123 | return user_action 124 | 125 | ''' update state: state is sys_action ''' 126 | def next(self, state): 127 | self.state['turn'] += 1 128 | reward = 0 129 | episode_over = False 130 | self.state['prev_diaact'] = self.state['diaact'] 131 | self.state['inform_slots'].clear() 132 | self.state['request_slots'].clear() 133 | self.state['inform_slots_noisy'].clear() 134 | 135 | if (self.max_turn > 0 and self.state['turn'] >= self.max_turn): 136 | reward = dialog_config.FAILED_DIALOG_REWARD 137 | episode_over = True 138 | self.state['diaact'] = 'deny' 139 | else: 140 | act = state['diaact'] 141 | if act == 'inform': 142 | episode_over = True 143 | goal_rank = state['target'].index(self.goal['target']) 144 | if goal_rank < dialog_config.SUCCESS_MAX_RANK: 145 | reward = dialog_config.SUCCESS_DIALOG_REWARD*\ 146 | (1.-float(goal_rank)/dialog_config.SUCCESS_MAX_RANK) 147 | self.state['diaact'] = 'thanks' 148 | else: 149 | reward = dialog_config.FAILED_DIALOG_REWARD 150 | self.state['diaact'] = 'deny' 151 | elif act == 'request': 152 | slot = state['request_slots'].keys()[0] 153 | if slot in self.goal['inform_slots']: 154 | self.state['inform_slots'][slot] = 
self.goal['inform_slots'][slot] 155 | else: 156 | self.state['inform_slots'][slot] = None 157 | self.state['diaact'] = 'inform' 158 | reward = dialog_config.PER_TURN_REWARD 159 | 160 | if not episode_over: 161 | self.corrupt() 162 | 163 | sent = self.nlg.generate(self.state['diaact'],self.state['request_slots'], 164 | self.state['inform_slots_noisy']) if self.nlg is not None else '' 165 | self.state['nl_sentence'] = sent 166 | self.state['episode_over'] = episode_over 167 | self.state['reward'] = reward 168 | 169 | return self.state, episode_over, reward 170 | 171 | ''' user may make mistakes ''' 172 | def corrupt(self): 173 | self.state['inform_slots_noisy'] = {} 174 | for slot in self.state['inform_slots'].keys(): 175 | if self.state['inform_slots'][slot]==None: 176 | self.state['inform_slots_noisy'][slot] = None 177 | else: 178 | prob_sub = random.random() 179 | if prob_sub < self.sub_prob: # substitute value 180 | self.state['inform_slots_noisy'][slot] = \ 181 | random.choice(self.movie_dict.dict[slot]) 182 | else: 183 | self.state['inform_slots_noisy'][slot] = self.state['inform_slots'][slot] 184 | prob_err = random.random() 185 | if prob_err < self.err_prob: # corrupt value 186 | self.state['inform_slots_noisy'][slot] = \ 187 | self._corrupt_value(self.state['inform_slots_noisy'][slot]) 188 | 189 | def _corrupt_value(self, val): 190 | def _is_int(s): 191 | try: 192 | int(s) 193 | return True 194 | except ValueError: 195 | return False 196 | 197 | def _is_float(s): 198 | try: 199 | float(s) 200 | return True 201 | except ValueError: 202 | return False 203 | 204 | tokens = nltk.word_tokenize(val) 205 | if len(tokens)>1: tokens.pop(random.randrange(len(tokens))) 206 | out = [] 207 | for t in tokens: 208 | if _is_int(t): 209 | out.append(str(int(random.gauss(int(t),0.5)))) 210 | elif _is_float(t): 211 | out.append('%.1f' %random.gauss(float(t),0.5)) 212 | else: 213 | out.append(t) 214 | return ' '.join([o for o in out]) 215 | 216 | ''' user state representation ''' 217 | def stateVector(self, action): 218 | vec = [0]*(len(self.act_set.dict) + len(self.slot_set.slot_ids)*2) 219 | 220 | if action['diaact'] in self.act_set.dict.keys(): vec[self.act_set.dict[action['diaact']]] = 1 221 | for slot in action['slots'].keys(): 222 | slot_id = self.slot_set.slot_ids[slot] * 2 + len(self.act_set.dict) 223 | slot_id += 1 224 | if action['slots'][slot] == 'UNK': vec[slot_id] =1 225 | 226 | return vec 227 | 228 | ''' print the state ''' 229 | def print_state(self, action): 230 | stateStr = 'Turn %d user action: %s, history slots: %s, inform_slots: %s, request slots: %s, rest_slots: %s' % (action['turn'], action['diaact'], action['history_slots'], action['inform_slots'], action['request_slots'], action['rest_slots']) 231 | print stateStr 232 | 233 | 234 | 235 | def main(params): 236 | user_sim = RuleSimulator() 237 | user_sim.init() 238 | 239 | 240 | 241 | if __name__ == "__main__": 242 | parser = argparse.ArgumentParser() 243 | 244 | args = parser.parse_args() 245 | params = vars(args) 246 | 247 | print 'User Simulator Parameters: ' 248 | print json.dumps(params, indent=2) 249 | 250 | main(params) 251 | -------------------------------------------------------------------------------- /interact.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | WELCOME="""\nHello and welcome to the InfoBot! Please take a moment to read the instructions below on how to interact with the system. 
5 | 6 | BACKGROUND: The InfoBot helps users search a database for entities (in our case movies) based on their attributes (in our case any of actor, director, release-year, critic-rating, mpaa-rating). 7 | 8 | INSTRUCTIONS: In each interaction a movie will be selected at random from the database and presented to you, along with some of its attributes. To simulate a real-world scenario where the user may not know all attribute values perfectly, multiple noisy values may be presented (separated by ','). For example, the dialog may start as follows: 9 | 10 | turn 0 11 | agent action: Hi! I am Info-Bot. I can help you search for movies if you tell me their attributes! 12 | target movie : louis c.k.: shameless 13 | known slots: release_year={ 2007 } critic_rating={ 6.5 , 8.7 } actor={} director={} mpaa_rating={ tv-ma } genre={} 14 | your input: 15 | 16 | Only some of the slot values will be provided. In the above example critic-rating may be either 6.5 or 8.7; please select one value when informing the agent. At this stage you must initiate the dialog by asking the InfoBot for a movie which matches some of the provided attributes. You may specify all attributes in one go or only a subset of them. Please frame your inputs in natural language and try to provide a diverse variety of inputs. For example: 17 | 18 | your input: which movie has critic rating 6.5? 19 | 20 | In each subsequent turn, the agent will either request an attribute or inform results from the database. In the latter case the dialog will end. A typical turn may look like: 21 | 22 | turn 1 23 | agent action: request actor 24 | target movie : louis c.k.: shameless 25 | known slots: release_year={ 2007 } critic_rating={ 6.5 , 8.7 } actor={} director={} mpaa_rating={ tv-ma } genre={} 26 | your input: 27 | 28 | Here the agent is requesting the actor of the movie. Since the actor is not in one of the known slots, you may respond by saying: 29 | 30 | your input: i dont know 31 | 32 | This is just an example; you may respond any way you like. Be creative! 33 | 34 | At the end of the dialog, the agent will inform the top 5 matches from the database, which will be checked to see whether they contain the correct movie: 35 | 36 | agent results: night catches us, spider man 3, precious, she's out of my league, pineapple 37 | target movie rank = 169 38 | failed dialog 39 | number of turns = 4 40 | 41 | This is it. After this a new dialog will be initiated. 42 | 43 | Type 'quit' to end the current dialog (it will be considered a failure).
Press Ctrl-C at any time to exit the application.""" 44 | 45 | import argparse, json, shutil, sys, os, random, copy 46 | import numpy as np 47 | import cPickle as pkl 48 | import datetime 49 | import importlib 50 | 51 | agent_map = {'rule-no' : 'nl-rule-no', 52 | 'rl-no' : 'simple-rl-no', 53 | 'rule-hard' : 'nl-rule-hard', 54 | 'rl-hard' : 'simple-rl-hard', 55 | 'rule-soft' : 'nl-rule-soft', 56 | 'rl-soft' : 'simple-rl-soft', 57 | 'e2e-soft' : 'e2e-rl-soft', 58 | } 59 | 60 | sys.setrecursionlimit(10000) 61 | 62 | """ Conduct dialogs between InfoBot agents and real users 63 | """ 64 | 65 | parser = argparse.ArgumentParser() 66 | 67 | parser.add_argument('--agent', dest='agent', type=str, default='rl-soft', 68 | help='Agent to run -- (rule-no / rl-no / rule-hard / rl-hard / rule-soft / rl-soft / e2e-soft') 69 | 70 | args = parser.parse_args() 71 | params = vars(args) 72 | 73 | params['N'] = 1000 74 | params['db'] = 'imdb-M' 75 | params['act_set'] = './data/dia_acts.txt' 76 | params['max_turn'] = 20 77 | params['err_prob'] = 0.5 78 | params['dontknow_prob'] = 0.5 79 | params['sub_prob'] = 0.05 80 | params['max_first_turn'] = 5 81 | config = importlib.import_module('settings.config_'+params['db']) 82 | agent_params = config.agent_params 83 | dataset_params = config.dataset_params 84 | for k,v in dataset_params[params['db']].iteritems(): 85 | params[k] = v 86 | 87 | max_turn = params['max_turn'] 88 | err_prob = params['err_prob'] 89 | dk_prob = params['dontknow_prob'] 90 | N = params['N'] 91 | 92 | datadir = './data/' + params['dataset'] 93 | db_full_path = datadir + '/db.txt' 94 | db_inc_path = datadir + '/incomplete_db_%.2f.txt' %params['unk'] 95 | dict_path = datadir + '/dicts.json' 96 | slot_path = datadir + '/slot_set.txt' 97 | corpus_path = './data/corpora/' + params['dataset'] + '_corpus.txt' 98 | 99 | from deep_dialog.dialog_system import DialogManager, MovieDict, DictReader, Database 100 | from deep_dialog.agents import AgentActRule, AgentNLRuleSoft, AgentNLRuleHard, AgentNLRuleNoDB, AgentSimpleRLAllAct, AgentSimpleRLAllActHardDB, AgentSimpleRLAllActNoDB, AgentE2ERLAllAct 101 | from deep_dialog.usersims import CmdUser 102 | from deep_dialog.objects import SlotReader 103 | 104 | act_set = DictReader() 105 | act_set.load_dict_from_file(params['act_set']) 106 | 107 | slot_set = SlotReader(slot_path) 108 | 109 | movie_kb = MovieDict(dict_path) 110 | 111 | db_full = Database(db_full_path, movie_kb, name=params['dataset']) 112 | db_inc = Database(db_inc_path, movie_kb, name='incomplete%.2f_'%params['unk']+params['dataset']) 113 | 114 | user_sim = CmdUser(movie_kb, act_set, slot_set, None, max_turn, err_prob, db_full, \ 115 | dk_prob, sub_prob=params['sub_prob'], max_first_turn=params['max_first_turn'], 116 | fdict_path = 'data/'+params['db']+'/fdict_2.p') 117 | 118 | # load all agents 119 | print WELCOME 120 | print "Loading agents... 
This may take a few minutes" 121 | agent_type = agent_map[params['agent']] 122 | for k,v in agent_params[agent_type].iteritems(): 123 | params[k] = v 124 | params['model_name'] = 'best_'+agent_type+'_imdb.m' 125 | 126 | if agent_type == 'simple-rl-soft': 127 | agent = AgentSimpleRLAllAct(movie_kb, act_set, slot_set, db_inc, train=False, _reload=True, 128 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 129 | inputtype=params['input'], 130 | pol_start=params['pol_start'], lr=params['lr'], upd=params['upd'], 131 | tr=params['tr'], ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 132 | name=params['model_name']) 133 | elif agent_type == 'simple-rl-hard': 134 | agent = AgentSimpleRLAllActHardDB(movie_kb, act_set, slot_set, db_inc, train=False, 135 | _reload=True, 136 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 137 | inputtype=params['input'], 138 | pol_start=params['pol_start'], lr=params['lr'], upd=params['upd'], 139 | ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 140 | name=params['model_name']) 141 | elif agent_type == 'simple-rl-no': 142 | agent = AgentSimpleRLAllActNoDB(movie_kb, act_set, slot_set, db_inc, train=False, 143 | _reload=True, 144 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 145 | inputtype=params['input'], 146 | pol_start=params['pol_start'], lr=params['lr'], upd=params['upd'], 147 | ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 148 | name=params['model_name']) 149 | elif agent_type == 'e2e-rl-soft': 150 | agent = AgentE2ERLAllAct(movie_kb, act_set, slot_set, db_inc, corpus_path, train=False, 151 | _reload=True, n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 152 | lr=params['lr'], N=params['featN'], 153 | inputtype=params['input'], sl=params['sl'], rl=params['rl'], 154 | pol_start=params['pol_start'], tr=params['tr'], ts=params['ts'], frac=params['frac'], 155 | max_req=params['max_req'], upd=params['upd'], name=params['model_name']) 156 | else: 157 | print "Invalid Agent" 158 | sys.exit() 159 | 160 | uname = raw_input("Please Enter User Name: ").lower() 161 | uid = hash(uname) 162 | 163 | cdir = "sessions/"+str(uid)+'_'+datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')+"/" 164 | if not os.path.exists(cdir): os.makedirs(cdir) 165 | f = open(os.path.join(cdir,'credentials'), 'w') 166 | f.write(uname) 167 | f.close() 168 | try: 169 | for i in range(N): 170 | print "--------------------------------------------------------------------------------" 171 | print "Dialog %d" %i 172 | dia = [] 173 | curr_agent = agent 174 | dia.append(curr_agent) 175 | dialog_manager = DialogManager(curr_agent, user_sim, db_full, db_inc, movie_kb, verbose=False) 176 | utt = dialog_manager.initialize_episode() 177 | dia.append(copy.deepcopy(utt)) 178 | total_reward = 0 179 | while(True): 180 | episode_over, reward, utt, agact = dialog_manager.next_turn() 181 | dia.append(agact) 182 | dia.append(copy.deepcopy(utt)) 183 | total_reward += reward 184 | if episode_over: 185 | break 186 | pkl.dump(dia, open(cdir+str(i)+".p",'w')) 187 | except KeyboardInterrupt: 188 | sys.exit() 189 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Lasagne==0.2.dev1 2 | nltk==3.1 3 | numpy==1.12.1 4 | Theano==0.9.0.dev1 5 | dataset==0.8.0 6 | psyco==1.6 7 | -------------------------------------------------------------------------------- 
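The settings modules below are plain Python files that expose two dictionaries, `agent_params` (hyper-parameters per agent type) and `dataset_params` (per KB split). Both `sim.py` and `train.py` merge the relevant entries into one flat parameter dictionary; a minimal sketch of that pattern is shown here (the `db` and `agent_type` values are just example choices, everything else comes from the config files):

```python
import importlib

db = 'imdb-M'                  # KB split: imdb-S / imdb-M / imdb-L / imdb-XL
agent_type = 'simple-rl-soft'  # internal agent name (see agent_map in sim.py / train.py)

config = importlib.import_module('settings.config_' + db)

params = {}
params.update(config.dataset_params[db])        # dataset name and 'unk', which selects incomplete_db_<unk>.txt
params.update(config.agent_params[agent_type])  # hyper-parameters for the chosen agent (nhid, lr, batch, ...)
```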
/settings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiuLab/KB-InfoBot/f472695fa083020825f799919c90a37235a5bb28/settings/__init__.py -------------------------------------------------------------------------------- /settings/config_imdb-L.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: @t-bhdhi 3 | Created on August 5, 2016 4 | 5 | This file specifies the default parameter settings for all agents, user simulator and the database. 6 | ''' 7 | 8 | agent_params = {} 9 | 10 | agent_params['act-rule'] = {} 11 | agent_params['act-rule']['tr'] = 1.0 12 | agent_params['act-rule']['ts'] = 1.0 13 | agent_params['act-rule']['max_req'] = 1 14 | agent_params['act-rule']['frac'] = 0.5 15 | agent_params['act-rule']['upd'] = 10 16 | 17 | agent_params['nl-rule-no'] = {} 18 | agent_params['nl-rule-no']['ts'] = 1.0 19 | agent_params['nl-rule-no']['max_req'] = 4 20 | agent_params['nl-rule-no']['frac'] = 0.5 21 | agent_params['nl-rule-no']['upd'] = 05 22 | 23 | agent_params['nl-rule-hard'] = {} 24 | agent_params['nl-rule-hard']['ts'] = 5.0 25 | agent_params['nl-rule-hard']['max_req'] = 2 26 | agent_params['nl-rule-hard']['frac'] = 0.1 27 | agent_params['nl-rule-hard']['upd'] = 05 28 | 29 | agent_params['nl-rule-soft'] = {} 30 | agent_params['nl-rule-soft']['tr'] = 1.0 31 | agent_params['nl-rule-soft']['ts'] = 1.0 32 | agent_params['nl-rule-soft']['max_req'] = 1 33 | agent_params['nl-rule-soft']['frac'] = 0.5 34 | agent_params['nl-rule-soft']['upd'] = 10 35 | 36 | agent_params['simple-rl-soft'] = {} 37 | agent_params['simple-rl-soft']['tr'] = 1.0 38 | agent_params['simple-rl-soft']['ts'] = 1.0 39 | agent_params['simple-rl-soft']['max_req'] = 1 40 | agent_params['simple-rl-soft']['frac'] = 0.5 41 | agent_params['simple-rl-soft']['upd'] = 10 42 | agent_params['simple-rl-soft']['input'] = 'entropy' 43 | agent_params['simple-rl-soft']['pol_start'] = 0 44 | agent_params['simple-rl-soft']['nhid'] = 50 45 | agent_params['simple-rl-soft']['lr'] = 0.05 46 | agent_params['simple-rl-soft']['batch'] = 128 47 | agent_params['simple-rl-soft']['ment'] = 0. 48 | 49 | agent_params['simple-rl-hard'] = {} 50 | agent_params['simple-rl-hard']['ts'] = 5.0 51 | agent_params['simple-rl-hard']['max_req'] = 2 52 | agent_params['simple-rl-hard']['frac'] = 0.1 53 | agent_params['simple-rl-hard']['upd'] = 05 54 | agent_params['simple-rl-hard']['input'] = 'entropy' 55 | agent_params['simple-rl-hard']['pol_start'] = 0 56 | agent_params['simple-rl-hard']['nhid'] = 50 57 | agent_params['simple-rl-hard']['lr'] = 0.05 58 | agent_params['simple-rl-hard']['batch'] = 128 59 | agent_params['simple-rl-hard']['ment'] = 0. 60 | 61 | agent_params['simple-rl-no'] = {} 62 | agent_params['simple-rl-no']['ts'] = 1.0 63 | agent_params['simple-rl-no']['max_req'] = 4 64 | agent_params['simple-rl-no']['frac'] = 0.5 65 | agent_params['simple-rl-no']['upd'] = 05 66 | agent_params['simple-rl-no']['input'] = 'entropy' 67 | agent_params['simple-rl-no']['pol_start'] = 0 68 | agent_params['simple-rl-no']['nhid'] = 50 69 | agent_params['simple-rl-no']['lr'] = 0.05 70 | agent_params['simple-rl-no']['batch'] = 128 71 | agent_params['simple-rl-no']['ment'] = 0. 
72 | 73 | agent_params['e2e-rl-soft'] = {} 74 | agent_params['e2e-rl-soft']['tr'] = 1.0 75 | agent_params['e2e-rl-soft']['ts'] = 1.0 76 | agent_params['e2e-rl-soft']['max_req'] = 1 77 | agent_params['e2e-rl-soft']['frac'] = 0.5 78 | agent_params['e2e-rl-soft']['upd'] = 10 79 | agent_params['e2e-rl-soft']['input'] = 'entropy' 80 | agent_params['e2e-rl-soft']['pol_start'] = 500 81 | agent_params['e2e-rl-soft']['nhid'] = 100 82 | agent_params['e2e-rl-soft']['lr'] = 0.05 83 | agent_params['e2e-rl-soft']['featN'] = 2 84 | agent_params['e2e-rl-soft']['batch'] = 128 85 | agent_params['e2e-rl-soft']['ment'] = 0. 86 | agent_params['e2e-rl-soft']['sl'] = 'e2e' 87 | agent_params['e2e-rl-soft']['rl'] = 'e2e' 88 | 89 | dataset_params = {} 90 | 91 | dataset_params['imdb-S'] = {} 92 | dataset_params['imdb-S']['dataset'] = 'imdb-S' 93 | dataset_params['imdb-S']['unk'] = 0.20 94 | 95 | dataset_params['imdb-M'] = {} 96 | dataset_params['imdb-M']['dataset'] = 'imdb-M' 97 | dataset_params['imdb-M']['unk'] = 0.20 98 | 99 | dataset_params['imdb-L'] = {} 100 | dataset_params['imdb-L']['dataset'] = 'imdb-L' 101 | dataset_params['imdb-L']['unk'] = 0.20 102 | 103 | dataset_params['imdb-XL'] = {} 104 | dataset_params['imdb-XL']['dataset'] = 'imdb-XL' 105 | dataset_params['imdb-XL']['unk'] = 0.20 106 | -------------------------------------------------------------------------------- /settings/config_imdb-M.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: @t-bhdhi 3 | Created on August 5, 2016 4 | 5 | This file specifies the default parameter settings for all agents, user simulator and the database. 6 | ''' 7 | 8 | agent_params = {} 9 | 10 | agent_params['nl-rule-no'] = {} 11 | agent_params['nl-rule-no']['ts'] = 0.5 12 | agent_params['nl-rule-no']['max_req'] = 4 13 | agent_params['nl-rule-no']['frac'] = 0.5 14 | agent_params['nl-rule-no']['upd'] = 05 15 | 16 | agent_params['nl-rule-hard'] = {} 17 | agent_params['nl-rule-hard']['ts'] = 5.0 18 | agent_params['nl-rule-hard']['max_req'] = 1 19 | agent_params['nl-rule-hard']['frac'] = 0.1 20 | agent_params['nl-rule-hard']['upd'] = 10 21 | 22 | agent_params['nl-rule-soft'] = {} 23 | agent_params['nl-rule-soft']['tr'] = 1.0 24 | agent_params['nl-rule-soft']['ts'] = 1.0 25 | agent_params['nl-rule-soft']['max_req'] = 1 26 | agent_params['nl-rule-soft']['frac'] = 0.5 27 | agent_params['nl-rule-soft']['upd'] = 10 28 | 29 | agent_params['simple-rl-soft'] = {} 30 | agent_params['simple-rl-soft']['tr'] = 0.5 31 | agent_params['simple-rl-soft']['ts'] = 1.0 32 | agent_params['simple-rl-soft']['max_req'] = 2 33 | agent_params['simple-rl-soft']['frac'] = 0.5 34 | agent_params['simple-rl-soft']['upd'] = 5 35 | agent_params['simple-rl-soft']['input'] = 'entropy' 36 | agent_params['simple-rl-soft']['pol_start'] = 0 37 | agent_params['simple-rl-soft']['nhid'] = 50 38 | agent_params['simple-rl-soft']['lr'] = 0.05 39 | agent_params['simple-rl-soft']['batch'] = 128 40 | agent_params['simple-rl-soft']['ment'] = 0. 
41 | 42 | agent_params['simple-rl-hard'] = {} 43 | agent_params['simple-rl-hard']['ts'] = 0.5 44 | agent_params['simple-rl-hard']['max_req'] = 2 45 | agent_params['simple-rl-hard']['frac'] = 0.5 46 | agent_params['simple-rl-hard']['upd'] = 5 47 | agent_params['simple-rl-hard']['input'] = 'entropy' 48 | agent_params['simple-rl-hard']['pol_start'] = 0 49 | agent_params['simple-rl-hard']['nhid'] = 50 50 | agent_params['simple-rl-hard']['lr'] = 0.05 51 | agent_params['simple-rl-hard']['batch'] = 128 52 | agent_params['simple-rl-hard']['ment'] = 0. 53 | 54 | agent_params['simple-rl-no'] = {} 55 | agent_params['simple-rl-no']['ts'] = 1.0 56 | agent_params['simple-rl-no']['max_req'] = 2 57 | agent_params['simple-rl-no']['frac'] = 0.5 58 | agent_params['simple-rl-no']['upd'] = 5 59 | agent_params['simple-rl-no']['input'] = 'entropy' 60 | agent_params['simple-rl-no']['pol_start'] = 0 61 | agent_params['simple-rl-no']['nhid'] = 50 62 | agent_params['simple-rl-no']['lr'] = 0.05 63 | agent_params['simple-rl-no']['batch'] = 128 64 | agent_params['simple-rl-no']['ment'] = 0. 65 | 66 | agent_params['e2e-rl-soft'] = {} 67 | agent_params['e2e-rl-soft']['tr'] = 0.5 68 | agent_params['e2e-rl-soft']['ts'] = 1.0 69 | agent_params['e2e-rl-soft']['max_req'] = 2 70 | agent_params['e2e-rl-soft']['frac'] = 0.5 71 | agent_params['e2e-rl-soft']['upd'] = 5 72 | agent_params['e2e-rl-soft']['input'] = 'entropy' 73 | agent_params['e2e-rl-soft']['pol_start'] = 500 74 | agent_params['e2e-rl-soft']['nhid'] = 100 75 | agent_params['e2e-rl-soft']['lr'] = 0.05 76 | agent_params['e2e-rl-soft']['featN'] = 2 77 | agent_params['e2e-rl-soft']['batch'] = 128 78 | agent_params['e2e-rl-soft']['ment'] = 0. 79 | agent_params['e2e-rl-soft']['sl'] = 'e2e' 80 | agent_params['e2e-rl-soft']['rl'] = 'e2e' 81 | 82 | dataset_params = {} 83 | 84 | dataset_params['imdb-S'] = {} 85 | dataset_params['imdb-S']['dataset'] = 'imdb-S' 86 | dataset_params['imdb-S']['unk'] = 0.20 87 | 88 | dataset_params['imdb-M'] = {} 89 | dataset_params['imdb-M']['dataset'] = 'imdb-M' 90 | dataset_params['imdb-M']['unk'] = 0.20 91 | 92 | dataset_params['imdb-L'] = {} 93 | dataset_params['imdb-L']['dataset'] = 'imdb-L' 94 | dataset_params['imdb-L']['unk'] = 0.20 95 | 96 | dataset_params['imdb-XL'] = {} 97 | dataset_params['imdb-XL']['dataset'] = 'imdb-XL' 98 | dataset_params['imdb-XL']['unk'] = 0.20 99 | -------------------------------------------------------------------------------- /settings/config_imdb-S.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: @t-bhdhi 3 | Created on August 5, 2016 4 | 5 | This file specifies the default parameter settings for all agents, user simulator and the database. 
6 | ''' 7 | 8 | agent_params = {} 9 | 10 | agent_params['act-rule'] = {} 11 | agent_params['act-rule']['tr'] = 1.0 12 | agent_params['act-rule']['ts'] = 1.0 13 | agent_params['act-rule']['max_req'] = 1 14 | agent_params['act-rule']['frac'] = 0.5 15 | agent_params['act-rule']['upd'] = 10 16 | 17 | agent_params['nl-rule-no'] = {} 18 | agent_params['nl-rule-no']['ts'] = 0.5 19 | agent_params['nl-rule-no']['max_req'] = 2 20 | agent_params['nl-rule-no']['frac'] = 0.5 21 | agent_params['nl-rule-no']['upd'] = 05 22 | 23 | agent_params['nl-rule-hard'] = {} 24 | agent_params['nl-rule-hard']['ts'] = 0.5 25 | agent_params['nl-rule-hard']['max_req'] = 2 26 | agent_params['nl-rule-hard']['frac'] = 0.5 27 | agent_params['nl-rule-hard']['upd'] = 05 28 | 29 | agent_params['nl-rule-soft'] = {} 30 | agent_params['nl-rule-soft']['tr'] = 5.0 31 | agent_params['nl-rule-soft']['ts'] = 0.5 32 | agent_params['nl-rule-soft']['max_req'] = 1 33 | agent_params['nl-rule-soft']['frac'] = 0.5 34 | agent_params['nl-rule-soft']['upd'] = 10 35 | 36 | agent_params['simple-rl-soft'] = {} 37 | agent_params['simple-rl-soft']['tr'] = 5.0 38 | agent_params['simple-rl-soft']['ts'] = 0.5 39 | agent_params['simple-rl-soft']['max_req'] = 1 40 | agent_params['simple-rl-soft']['frac'] = 0.5 41 | agent_params['simple-rl-soft']['upd'] = 10 42 | agent_params['simple-rl-soft']['input'] = 'entropy' 43 | agent_params['simple-rl-soft']['pol_start'] = 0 44 | agent_params['simple-rl-soft']['nhid'] = 50 45 | agent_params['simple-rl-soft']['lr'] = 0.05 46 | agent_params['simple-rl-soft']['batch'] = 128 47 | agent_params['simple-rl-soft']['ment'] = 0. 48 | 49 | agent_params['simple-rl-hard'] = {} 50 | agent_params['simple-rl-hard']['ts'] = 0.5 51 | agent_params['simple-rl-hard']['max_req'] = 2 52 | agent_params['simple-rl-hard']['frac'] = 0.5 53 | agent_params['simple-rl-hard']['upd'] = 05 54 | agent_params['simple-rl-hard']['input'] = 'entropy' 55 | agent_params['simple-rl-hard']['pol_start'] = 0 56 | agent_params['simple-rl-hard']['nhid'] = 50 57 | agent_params['simple-rl-hard']['lr'] = 0.05 58 | agent_params['simple-rl-hard']['batch'] = 128 59 | agent_params['simple-rl-hard']['ment'] = 0. 60 | 61 | agent_params['simple-rl-no'] = {} 62 | agent_params['simple-rl-no']['ts'] = 0.5 63 | agent_params['simple-rl-no']['max_req'] = 2 64 | agent_params['simple-rl-no']['frac'] = 0.5 65 | agent_params['simple-rl-no']['upd'] = 05 66 | agent_params['simple-rl-no']['input'] = 'entropy' 67 | agent_params['simple-rl-no']['pol_start'] = 0 68 | agent_params['simple-rl-no']['nhid'] = 50 69 | agent_params['simple-rl-no']['lr'] = 0.05 70 | agent_params['simple-rl-no']['batch'] = 128 71 | agent_params['simple-rl-no']['ment'] = 0. 72 | 73 | agent_params['e2e-rl-soft'] = {} 74 | agent_params['e2e-rl-soft']['tr'] = 5.0 75 | agent_params['e2e-rl-soft']['ts'] = 0.5 76 | agent_params['e2e-rl-soft']['max_req'] = 1 77 | agent_params['e2e-rl-soft']['frac'] = 0.5 78 | agent_params['e2e-rl-soft']['upd'] = 10 79 | agent_params['e2e-rl-soft']['input'] = 'entropy' 80 | agent_params['e2e-rl-soft']['pol_start'] = 500 81 | agent_params['e2e-rl-soft']['nhid'] = 100 82 | agent_params['e2e-rl-soft']['lr'] = 0.05 83 | agent_params['e2e-rl-soft']['featN'] = 2 84 | agent_params['e2e-rl-soft']['batch'] = 128 85 | agent_params['e2e-rl-soft']['ment'] = 0. 
86 | agent_params['e2e-rl-soft']['sl'] = 'e2e' 87 | agent_params['e2e-rl-soft']['rl'] = 'e2e' 88 | 89 | dataset_params = {} 90 | 91 | dataset_params['imdb-S'] = {} 92 | dataset_params['imdb-S']['dataset'] = 'imdb-S' 93 | dataset_params['imdb-S']['unk'] = 0.20 94 | 95 | dataset_params['imdb-M'] = {} 96 | dataset_params['imdb-M']['dataset'] = 'imdb-M' 97 | dataset_params['imdb-M']['unk'] = 0.20 98 | 99 | dataset_params['imdb-L'] = {} 100 | dataset_params['imdb-L']['dataset'] = 'imdb-L' 101 | dataset_params['imdb-L']['unk'] = 0.20 102 | 103 | dataset_params['imdb-XL'] = {} 104 | dataset_params['imdb-XL']['dataset'] = 'imdb-XL' 105 | dataset_params['imdb-XL']['unk'] = 0.20 106 | -------------------------------------------------------------------------------- /settings/config_imdb-XL.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: @t-bhdhi 3 | Created on August 5, 2016 4 | 5 | This file specifies the default parameter settings for all agents, user simulator and the database. 6 | ''' 7 | 8 | agent_params = {} 9 | 10 | agent_params['act-rule'] = {} 11 | agent_params['act-rule']['tr'] = 1.0 12 | agent_params['act-rule']['ts'] = 1.0 13 | agent_params['act-rule']['max_req'] = 1 14 | agent_params['act-rule']['frac'] = 0.5 15 | agent_params['act-rule']['upd'] = 10 16 | 17 | agent_params['nl-rule-no'] = {} 18 | agent_params['nl-rule-no']['ts'] = 0.5 19 | agent_params['nl-rule-no']['max_req'] = 4 20 | agent_params['nl-rule-no']['frac'] = 0.5 21 | agent_params['nl-rule-no']['upd'] = 05 22 | 23 | agent_params['nl-rule-hard'] = {} 24 | agent_params['nl-rule-hard']['ts'] = 0.5 25 | agent_params['nl-rule-hard']['max_req'] = 1 26 | agent_params['nl-rule-hard']['frac'] = 0.5 27 | agent_params['nl-rule-hard']['upd'] = 05 28 | 29 | agent_params['nl-rule-soft'] = {} 30 | agent_params['nl-rule-soft']['tr'] = 1.0 31 | agent_params['nl-rule-soft']['ts'] = 0.5 32 | agent_params['nl-rule-soft']['max_req'] = 1 33 | agent_params['nl-rule-soft']['frac'] = 0.5 34 | agent_params['nl-rule-soft']['upd'] = 05 35 | 36 | agent_params['simple-rl-soft'] = {} 37 | agent_params['simple-rl-soft']['tr'] = 1.0 38 | agent_params['simple-rl-soft']['ts'] = 0.5 39 | agent_params['simple-rl-soft']['max_req'] = 1 40 | agent_params['simple-rl-soft']['frac'] = 0.5 41 | agent_params['simple-rl-soft']['upd'] = 05 42 | agent_params['simple-rl-soft']['input'] = 'entropy' 43 | agent_params['simple-rl-soft']['pol_start'] = 0 44 | agent_params['simple-rl-soft']['nhid'] = 50 45 | agent_params['simple-rl-soft']['lr'] = 0.05 46 | agent_params['simple-rl-soft']['batch'] = 128 47 | agent_params['simple-rl-soft']['ment'] = 0. 48 | 49 | agent_params['simple-rl-hard'] = {} 50 | agent_params['simple-rl-hard']['ts'] = 0.5 51 | agent_params['simple-rl-hard']['max_req'] = 1 52 | agent_params['simple-rl-hard']['frac'] = 0.5 53 | agent_params['simple-rl-hard']['upd'] = 05 54 | agent_params['simple-rl-hard']['input'] = 'entropy' 55 | agent_params['simple-rl-hard']['pol_start'] = 0 56 | agent_params['simple-rl-hard']['nhid'] = 50 57 | agent_params['simple-rl-hard']['lr'] = 0.05 58 | agent_params['simple-rl-hard']['batch'] = 128 59 | agent_params['simple-rl-hard']['ment'] = 0. 
60 | 61 | agent_params['simple-rl-no'] = {} 62 | agent_params['simple-rl-no']['ts'] = 0.5 63 | agent_params['simple-rl-no']['max_req'] = 4 64 | agent_params['simple-rl-no']['frac'] = 0.5 65 | agent_params['simple-rl-no']['upd'] = 05 66 | agent_params['simple-rl-no']['input'] = 'entropy' 67 | agent_params['simple-rl-no']['pol_start'] = 0 68 | agent_params['simple-rl-no']['nhid'] = 50 69 | agent_params['simple-rl-no']['lr'] = 0.05 70 | agent_params['simple-rl-no']['batch'] = 128 71 | agent_params['simple-rl-no']['ment'] = 0. 72 | 73 | agent_params['e2e-rl-soft'] = {} 74 | agent_params['e2e-rl-soft']['tr'] = 1.0 75 | agent_params['e2e-rl-soft']['ts'] = 0.5 76 | agent_params['e2e-rl-soft']['max_req'] = 1 77 | agent_params['e2e-rl-soft']['frac'] = 0.5 78 | agent_params['e2e-rl-soft']['upd'] = 05 79 | agent_params['e2e-rl-soft']['input'] = 'entropy' 80 | agent_params['e2e-rl-soft']['pol_start'] = 500 81 | agent_params['e2e-rl-soft']['nhid'] = 100 82 | agent_params['e2e-rl-soft']['lr'] = 0.05 83 | agent_params['e2e-rl-soft']['featN'] = 2 84 | agent_params['e2e-rl-soft']['batch'] = 128 85 | agent_params['e2e-rl-soft']['ment'] = 0. 86 | agent_params['e2e-rl-soft']['sl'] = 'e2e' 87 | agent_params['e2e-rl-soft']['rl'] = 'e2e' 88 | 89 | dataset_params = {} 90 | 91 | dataset_params['imdb-S'] = {} 92 | dataset_params['imdb-S']['dataset'] = 'imdb-S' 93 | dataset_params['imdb-S']['unk'] = 0.20 94 | 95 | dataset_params['imdb-M'] = {} 96 | dataset_params['imdb-M']['dataset'] = 'imdb-M' 97 | dataset_params['imdb-M']['unk'] = 0.20 98 | 99 | dataset_params['imdb-L'] = {} 100 | dataset_params['imdb-L']['dataset'] = 'imdb-L' 101 | dataset_params['imdb-L']['unk'] = 0.20 102 | 103 | dataset_params['imdb-XL'] = {} 104 | dataset_params['imdb-XL']['dataset'] = 'imdb-XL' 105 | dataset_params['imdb-XL']['unk'] = 0.20 106 | -------------------------------------------------------------------------------- /sim.py: -------------------------------------------------------------------------------- 1 | import argparse, json, shutil, io, time, importlib 2 | import numpy as np 3 | 4 | from collections import Counter 5 | 6 | agent_map = {'rule-no' : 'nl-rule-no', 7 | 'rl-no' : 'simple-rl-no', 8 | 'rule-hard' : 'nl-rule-hard', 9 | 'rl-hard' : 'simple-rl-hard', 10 | 'rule-soft' : 'nl-rule-soft', 11 | 'rl-soft' : 'simple-rl-soft', 12 | 'e2e-soft' : 'e2e-rl-soft', 13 | } 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--agent', dest='agent_type', type=str, default='rule-soft', 18 | help='agent to use (rule-no / rl-no / rule-hard / rl-hard / rule-soft / rl-soft / e2e-soft)') 19 | parser.add_argument('--N', dest='N', type=int, default=5000, help='Number of simulations') 20 | parser.add_argument('--db', dest='db', type=str, default='imdb-M', 21 | help='imdb-(S/M/L/XL) -- This is the KB split to use, e.g. 
imdb-M') 22 | parser.add_argument('--max_turn', dest='max_turn', default=20, type=int, 23 | help='maximum length of each dialog (default=20, 0=no maximum length)') 24 | parser.add_argument('--err_prob', dest='err_prob', default=0.5, type=float, 25 | help='the probability of the user simulator corrupting a slot value') 26 | parser.add_argument('--dontknow_prob', dest='dontknow_prob', default=0.5, type=float, 27 | help='the probability that user simulator does not know a slot value') 28 | parser.add_argument('--sub_prob', dest='sub_prob', default=0.05, type=float, 29 | help='the probability that user simulator substitutes a slot value') 30 | parser.add_argument('--nlg_temp', dest='nlg_temp', type=float, default=1., 31 | help='Natural Language Generator softmax temperature (to control noise)') 32 | parser.add_argument('--max_first_turn', dest='max_first_turn', type=int, default=5, 33 | help='Maximum number of slots informed by user in first turn') 34 | parser.add_argument('--model_name', dest='model_name', type=str, default='pretrained', 35 | help='model name to evaluate (This should be the same as what you gave for training). Pass "pretrained" to use pretrained models.') 36 | 37 | args = parser.parse_args() 38 | params = vars(args) 39 | 40 | params['act_set'] = './data/dia_acts.txt' 41 | params['template_path'] = './data/templates.p' 42 | params['nlg_slots_path'] = './data/nlg_slot_set.txt' 43 | params['nlg_model_path'] = './data/pretrained/lstm_tanh_[1470015675.73]_115_120_0.657.p' 44 | 45 | config = importlib.import_module('settings.config_'+params['db']) 46 | agent_params = config.agent_params 47 | dataset_params = config.dataset_params 48 | for k,v in dataset_params[params['db']].iteritems(): 49 | params[k] = v 50 | for k,v in agent_params[agent_map[params['agent_type']]].iteritems(): 51 | params[k] = v 52 | 53 | print 'Dialog Parameters: ' 54 | print json.dumps(params, indent=2) 55 | 56 | 57 | max_turn = params['max_turn'] 58 | err_prob = params['err_prob'] 59 | dk_prob = params['dontknow_prob'] 60 | template_path = params['template_path'] 61 | agent_type = agent_map[params['agent_type']] 62 | N = params['N'] 63 | save_path = None 64 | 65 | datadir = './data/' + params['dataset'] 66 | db_full_path = datadir + '/db.txt' 67 | db_inc_path = datadir + '/incomplete_db_%.2f.txt' %params['unk'] 68 | dict_path = datadir + '/dicts.json' 69 | slot_path = datadir + '/slot_set.txt' 70 | corpus_path = './data/corpora/' + params['dataset'] + '_corpus.txt' 71 | 72 | from deep_dialog.dialog_system import DialogManager, MovieDict, DictReader, Database 73 | from deep_dialog.agents import AgentNLRuleSoft, AgentNLRuleHard, AgentNLRuleNoDB 74 | from deep_dialog.agents import AgentSimpleRLAllAct, AgentSimpleRLAllActHardDB 75 | from deep_dialog.agents import AgentSimpleRLAllActNoDB, AgentE2ERLAllAct 76 | from deep_dialog.usersims import RuleSimulator, TemplateNLG, S2SNLG 77 | from deep_dialog.objects import SlotReader 78 | 79 | act_set = DictReader() 80 | act_set.load_dict_from_file(params['act_set']) 81 | 82 | slot_set = SlotReader(slot_path) 83 | 84 | movie_kb = MovieDict(dict_path) 85 | 86 | db_full = Database(db_full_path, movie_kb, name=params['dataset']) 87 | db_inc = Database(db_inc_path, movie_kb, name='incomplete%.2f_'%params['unk']+params['dataset']) 88 | 89 | nlg = S2SNLG(template_path, params['nlg_slots_path'], params['nlg_model_path'], params['nlg_temp']) 90 | user_sim = RuleSimulator(movie_kb, act_set, slot_set, None, max_turn, nlg, err_prob, db_full, \ 91 | 1.-dk_prob, 
sub_prob=params['sub_prob'], max_first_turn=params['max_first_turn']) 92 | 93 | if params['model_name']=='pretrained': 94 | params['model_name'] = 'best_'+agent_type+'_imdb.m' 95 | if agent_type == 'act-rule': 96 | agent = AgentActRule(movie_kb, act_set, slot_set, db_inc, 97 | upd=params['upd'], tr=params['tr'], ts=params['ts'], 98 | frac=params['frac'], max_req=params['max_req']) 99 | elif agent_type == 'simple-rl-soft': 100 | agent = AgentSimpleRLAllAct(movie_kb, act_set, slot_set, db_inc, train=False, _reload=True, 101 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 102 | inputtype=params['input'], 103 | pol_start=params['pol_start'], lr=params['lr'], upd=params['upd'], tr=params['tr'], 104 | ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 105 | name=params['model_name']) 106 | elif agent_type == 'simple-rl-hard': 107 | agent = AgentSimpleRLAllActHardDB(movie_kb, act_set, slot_set, db_inc, train=False, 108 | _reload=True, 109 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 110 | inputtype=params['input'], 111 | pol_start=params['pol_start'], lr=params['lr'], upd=params['upd'], 112 | ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 113 | name=params['model_name']) 114 | elif agent_type == 'simple-rl-no': 115 | agent = AgentSimpleRLAllActNoDB(movie_kb, act_set, slot_set, db_inc, train=False, 116 | _reload=True, 117 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 118 | inputtype=params['input'], 119 | pol_start=params['pol_start'], lr=params['lr'], upd=params['upd'], 120 | ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 121 | name=params['model_name']) 122 | elif agent_type == 'e2e-rl-soft': 123 | agent = AgentE2ERLAllAct(movie_kb, act_set, slot_set, db_inc, corpus_path, train=False, 124 | _reload=True, pol_start=params['pol_start'], sl=params['sl'], rl=params['rl'], 125 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], lr=params['lr'], 126 | N=params['featN'], 127 | inputtype=params['input'], tr=params['tr'], ts=params['ts'], frac=params['frac'], 128 | max_req=params['max_req'], upd=params['upd'], name=params['model_name']) 129 | elif agent_type=='nl-rule-hard': 130 | agent = AgentNLRuleHard(movie_kb, act_set, slot_set, db_inc, corpus_path, 131 | ts=params['ts'], frac=params['frac'], 132 | max_req=params['max_req'], upd=params['upd']) 133 | elif agent_type=='nl-rule-soft': 134 | agent = AgentNLRuleSoft(movie_kb, act_set, slot_set, db_inc, corpus_path, 135 | tr=params['tr'], ts=params['ts'], frac=params['frac'], 136 | max_req=params['max_req'], upd=params['upd']) 137 | else: 138 | agent = AgentNLRuleNoDB(movie_kb, act_set, slot_set, db_inc, corpus_path, 139 | ts=params['ts'], frac=params['frac'], 140 | max_req=params['max_req'], upd=params['upd']) 141 | 142 | dialog_manager = DialogManager(agent, user_sim, db_full, db_inc, movie_kb, verbose=False) 143 | 144 | all_rewards = np.zeros((N,)) 145 | all_success = np.zeros((N,)) 146 | all_turns = np.zeros((N,)) 147 | if save_path is not None: fs = io.open(save_path, 'w') 148 | tst = time.time() 149 | 150 | for i in range(N): 151 | current_reward = 0 152 | current_success = False 153 | ua = dialog_manager.initialize_episode() 154 | utt = ua['nl_sentence'] 155 | if save_path is not None: fs.write(utt+'\n') 156 | t = 0 157 | while(True): 158 | t += 1 159 | episode_over, reward, ua, sa = dialog_manager.next_turn() 160 | utt = ua['nl_sentence'] 161 | if save_path is not None: fs.write(utt+'\n') 162 | current_reward += reward 
163 | if episode_over: 164 | if reward > 0: 165 | print ("Successful Dialog! Total reward = {}".format(current_reward)) 166 | current_success = True 167 | else: 168 | print ("Failed Dialog! Total reward = {}".format(current_reward)) 169 | break 170 | all_rewards[i] = current_reward 171 | all_success[i] = 1 if current_success else 0 172 | all_turns[i] = t 173 | if save_path is not None: fs.close() 174 | time_elapsed = time.time()-tst 175 | nn = np.sqrt(N) 176 | print("Overall: {} times, (mean/std) {} / {} reward, {} / {} success rate, {} / {} turns, {} time elapsed".format(N, 177 | np.mean(all_rewards), np.std(all_rewards)/nn, np.mean(all_success), 178 | np.std(all_success)/nn, 179 | np.mean(all_turns), np.std(all_turns)/nn, time_elapsed)) 180 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse, json, shutil, time, sys 2 | import numpy as np 3 | import importlib 4 | 5 | from collections import Counter 6 | 7 | agent_map = {'rule-no' : 'nl-rule-no', 8 | 'rl-no' : 'simple-rl-no', 9 | 'rule-hard' : 'nl-rule-hard', 10 | 'rl-hard' : 'simple-rl-hard', 11 | 'rule-soft' : 'nl-rule-soft', 12 | 'rl-soft' : 'simple-rl-soft', 13 | 'e2e-soft' : 'e2e-rl-soft', 14 | } 15 | 16 | EVALF = 100 17 | 18 | parser = argparse.ArgumentParser() 19 | 20 | parser.add_argument('--agent', dest='agent_type', type=str, default='rule-soft', 21 | help='agent to use (rl-no / rl-hard / rl-soft / e2e-soft)') 22 | parser.add_argument('--db', dest='db', type=str, default='imdb-M', 23 | help='imdb-(S/M/L/XL) -- This is the KB split to use, e.g. imdb-M') 24 | parser.add_argument('--model_name', dest='model_name', type=str, default='no_name', 25 | help='model name to save') 26 | parser.add_argument('--N', dest='N', type=int, default=500000, help='Number of simulations') 27 | parser.add_argument('--max_turn', dest='max_turn', default=20, type=int, 28 | help='maximum length of each dialog (default=20, 0=no maximum length)') 29 | parser.add_argument('--nlg_temp', dest='nlg_temp', type=float, default=1., 30 | help='Natural Language Generator softmax temperature (to control noise)') 31 | parser.add_argument('--max_first_turn', dest='max_first_turn', type=int, default=5, 32 | help='Maximum number of slots informed by user in first turn') 33 | parser.add_argument('--err_prob', dest='err_prob', default=0.5, type=float, 34 | help='the probability of the user simulator corrupting a slot value') 35 | parser.add_argument('--dontknow_prob', dest='dontknow_prob', default=0.5, type=float, 36 | help='the probability that user simulator does not know a slot value') 37 | parser.add_argument('--sub_prob', dest='sub_prob', default=0.05, type=float, 38 | help='the probability that user simulator substitutes a slot value') 39 | parser.add_argument('--reload', dest='reload', type=int, default=0, 40 | help='Reload previously saved model (0-no, 1-yes)') 41 | 42 | args = parser.parse_args() 43 | params = vars(args) 44 | 45 | params['act_set'] = './data/dia_acts.txt' 46 | params['template_path'] = './data/templates.p' 47 | params['nlg_slots_path'] = './data/nlg_slot_set.txt' 48 | params['nlg_model_path'] = './data/pretrained/lstm_tanh_[1470015675.73]_115_120_0.657.p' 49 | 50 | config = importlib.import_module('settings.config_'+params['db']) 51 | agent_params = config.agent_params 52 | dataset_params = config.dataset_params 53 | for k,v in dataset_params[params['db']].iteritems(): 54 | params[k] = v 55 | for k,v in 
agent_params[agent_map[params['agent_type']]].iteritems(): 56 | params[k] = v 57 | 58 | print 'Dialog Parameters: ' 59 | print json.dumps(params, indent=2) 60 | 61 | max_turn = params['max_turn'] 62 | err_prob = params['err_prob'] 63 | dk_prob = params['dontknow_prob'] 64 | template_path = params['template_path'] 65 | agent_type = agent_map[params['agent_type']] 66 | N = params['N'] 67 | _reload = bool(params['reload']) 68 | 69 | datadir = './data/' + params['dataset'] 70 | db_full_path = datadir + '/db.txt' 71 | db_inc_path = datadir + '/incomplete_db_%.2f.txt' %params['unk'] 72 | dict_path = datadir + '/dicts.json' 73 | slot_path = datadir + '/slot_set.txt' 74 | corpus_path = './data/corpora/' + params['dataset'] + '_corpus.txt' 75 | 76 | from deep_dialog.dialog_system import DialogManager, MovieDict, DictReader, Database 77 | from deep_dialog.agents import AgentSimpleRLAllAct, AgentSimpleRLAllActHardDB 78 | from deep_dialog.agents import AgentSimpleRLAllActNoDB, AgentE2ERLAllAct 79 | from deep_dialog.usersims import RuleSimulator, TemplateNLG, S2SNLG 80 | from deep_dialog.objects import SlotReader 81 | from deep_dialog import dialog_config 82 | 83 | act_set = DictReader() 84 | act_set.load_dict_from_file(params['act_set']) 85 | 86 | slot_set = SlotReader(slot_path) 87 | 88 | movie_kb = MovieDict(dict_path) 89 | 90 | db_full = Database(db_full_path, movie_kb, name=params['dataset']) 91 | db_inc = Database(db_inc_path, movie_kb, name='incomplete%.2f_'%params['unk']+params['dataset']) 92 | 93 | nlg = S2SNLG(template_path, params['nlg_slots_path'], params['nlg_model_path'], 94 | params['nlg_temp']) 95 | user_sim = RuleSimulator(movie_kb, act_set, slot_set, None, max_turn, nlg, err_prob, db_full, \ 96 | 1.-dk_prob, sub_prob=params['sub_prob'], max_first_turn=params['max_first_turn']) 97 | 98 | if agent_type == 'simple-rl-soft': 99 | agent = AgentSimpleRLAllAct(movie_kb, act_set, slot_set, db_inc, _reload=_reload, 100 | n_hid=params['nhid'], 101 | batch=params['batch'], ment=params['ment'], inputtype=params['input'], 102 | pol_start=params['pol_start'], 103 | lr=params['lr'], upd=params['upd'], tr=params['tr'], ts=params['ts'], 104 | frac=params['frac'], max_req=params['max_req'], name=params['model_name']) 105 | agent_eval = AgentSimpleRLAllAct(movie_kb, act_set, slot_set, db_inc, train=False, 106 | _reload=False, n_hid=params['nhid'], 107 | batch=params['batch'], ment=params['ment'], inputtype=params['input'], 108 | pol_start=params['pol_start'], 109 | lr=params['lr'], upd=params['upd'], tr=params['tr'], ts=params['ts'], 110 | frac=params['frac'], max_req=params['max_req'], name=params['model_name']) 111 | elif agent_type == 'simple-rl-no': 112 | agent = AgentSimpleRLAllActNoDB(movie_kb, act_set, slot_set, db_inc, _reload=_reload, 113 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 114 | inputtype=params['input'], 115 | pol_start=params['pol_start'], lr=params['lr'], upd=params['upd'], 116 | ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 117 | name=params['model_name']) 118 | agent_eval = AgentSimpleRLAllActNoDB(movie_kb, act_set, slot_set, db_inc, train=False, 119 | _reload=False, 120 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 121 | inputtype=params['input'], 122 | pol_start=params['pol_start'], lr=params['lr'], upd=params['upd'], 123 | ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 124 | name=params['model_name']) 125 | elif agent_type == 'simple-rl-hard': 126 | agent = AgentSimpleRLAllActHardDB(movie_kb, 
act_set, slot_set, db_inc, _reload=_reload, 127 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 128 | inputtype=params['input'], 129 | pol_start=params['pol_start'], lr=params['lr'], upd=params['upd'], 130 | ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 131 | name=params['model_name']) 132 | agent_eval = AgentSimpleRLAllActHardDB(movie_kb, act_set, slot_set, db_inc, train=False, 133 | _reload=False, 134 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 135 | inputtype=params['input'], 136 | pol_start=params['pol_start'], lr=params['lr'], upd=params['upd'], 137 | ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 138 | name=params['model_name']) 139 | elif agent_type == 'e2e-rl-soft': 140 | agent = AgentE2ERLAllAct(movie_kb, act_set, slot_set, db_inc, corpus_path, _reload=_reload, 141 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 142 | inputtype=params['input'], sl=params['sl'], 143 | rl=params['rl'], pol_start=params['pol_start'], lr=params['lr'], N=params['featN'], 144 | tr=params['tr'], ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 145 | upd=params['upd'], name=params['model_name']) 146 | agent_eval = AgentE2ERLAllAct(movie_kb, act_set, slot_set, db_inc, corpus_path, train=False, 147 | _reload=False, 148 | n_hid=params['nhid'], batch=params['batch'], ment=params['ment'], 149 | inputtype=params['input'], sl=params['sl'], 150 | rl=params['rl'], pol_start=params['pol_start'], lr=params['lr'], N=params['featN'], 151 | tr=params['tr'], ts=params['ts'], frac=params['frac'], max_req=params['max_req'], 152 | upd=params['upd'], name=params['model_name']) 153 | else: 154 | print "Invalid agent!" 155 | sys.exit() 156 | 157 | dialog_manager = DialogManager(agent, user_sim, db_full, db_inc, movie_kb, verbose=False) 158 | dialog_manager_eval = DialogManager(agent_eval, user_sim, db_full, db_inc, movie_kb, 159 | verbose=False) 160 | 161 | def eval_agent(ite, max_perf, best=False): 162 | num_iter = 2000 163 | nn = np.sqrt(num_iter) 164 | if best: agent_eval.load_model(dialog_config.MODEL_PATH+'best_'+agent_eval._name) 165 | else: agent_eval.load_model(dialog_config.MODEL_PATH+agent_eval._name) 166 | all_rewards = np.zeros((num_iter,)) 167 | all_success = np.zeros((num_iter,)) 168 | all_turns = np.zeros((num_iter,)) 169 | for i in range(num_iter): 170 | current_reward = 0 171 | current_success = False 172 | utt = dialog_manager_eval.initialize_episode() 173 | t = 0 174 | while(True): 175 | t += 1 176 | episode_over, reward, utt, sact = dialog_manager_eval.next_turn() 177 | current_reward += reward 178 | if episode_over: 179 | if reward > 0: 180 | current_success = True 181 | break 182 | all_rewards[i] = current_reward 183 | all_success[i] = 1 if current_success else 0 184 | all_turns[i] = t 185 | curr_perf = np.mean(all_rewards) 186 | print("EVAL {}: {} / {} reward {} / {} success rate {} / {} turns".format(ite, \ 187 | curr_perf, np.std(all_rewards)/nn, \ 188 | np.mean(all_success), np.std(all_success)/nn, \ 189 | np.mean(all_turns), np.std(all_turns)/nn)) 190 | if curr_perf>max_perf and not best: 191 | max_perf=curr_perf 192 | agent_eval.save_model(dialog_config.MODEL_PATH+'best_'+agent_eval._name) 193 | return max_perf 194 | 195 | print("Starting training") 196 | mp = -10. 
197 | for i in range(N): 198 | if i%(EVALF*params['batch'])==0: 199 | mp = eval_agent(i,mp) 200 | utt = dialog_manager.initialize_episode() 201 | while(True): 202 | episode_over, reward, utt, sact = dialog_manager.next_turn() 203 | if episode_over: 204 | break 205 | perf = eval_agent('BEST',mp,best=True) 206 | --------------------------------------------------------------------------------
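For reference, a typical invocation of `train.py` for the soft-KB RL agent on the Medium split, using only the argparse options defined above (the model name here is illustrative, and every value shown matches the script defaults except `--agent`), might look like:

```sh
$ python train.py --agent rl-soft --db imdb-M --model_name rl_soft_imdb-M \
      --N 500000 --max_turn 20 --err_prob 0.5 --dontknow_prob 0.5 --sub_prob 0.05
```

As the training loop above shows, the policy is evaluated every `EVALF * batch` simulated dialogs on 2000 held-out dialogs; whenever the mean evaluation reward improves, the current checkpoint is re-saved under `dialog_config.MODEL_PATH` with a `best_` prefix, and that best checkpoint is evaluated once more after the final training episode.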